plugin.audio.librespot/resources/lib/deps/zeroconf/_protocol/outgoing.py

499 lines
18 KiB
Python
Raw Normal View History

2024-02-21 06:17:59 +00:00
""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
This module provides a framework for the use of DNS Service Discovery
using IP multicast.
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
USA
"""
import enum
import logging
from struct import Struct
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from .._dns import DNSPointer, DNSQuestion, DNSRecord
from .._exceptions import NamePartTooLongException
from .._logger import log
from ..const import (
_CLASS_UNIQUE,
_DNS_HOST_TTL,
_DNS_OTHER_TTL,
_DNS_PACKET_HEADER_LEN,
_FLAGS_QR_MASK,
_FLAGS_QR_QUERY,
_FLAGS_QR_RESPONSE,
_FLAGS_TC,
_MAX_MSG_ABSOLUTE,
_MAX_MSG_TYPICAL,
)
from .incoming import DNSIncoming
# Aliases for builtin and project types; the annotations in this module refer
# to these names instead of the originals.
# NOTE(review): presumably kept so a compiled-extension declaration file can
# bind them to C types -- confirm before removing.
str_ = str
float_ = float
int_ = int
bytes_ = bytes
DNSQuestion_ = DNSQuestion
DNSRecord_ = DNSRecord

# Pre-bound packers for big-endian unsigned 8/16/32-bit integers.
PACK_BYTE = Struct('>B').pack
PACK_SHORT = Struct('>H').pack
PACK_LONG = Struct('>L').pack

# Shorts below this bound are served from SHORT_LOOKUP instead of being packed.
SHORT_CACHE_MAX = 128

# Pre-packed encodings: every byte value, the small shorts, and the 32-bit
# values for the two TTL constants and zero.
BYTE_TABLE = tuple(map(PACK_BYTE, range(256)))
SHORT_LOOKUP = tuple(map(PACK_SHORT, range(SHORT_CACHE_MAX)))
LONG_LOOKUP = {ttl: PACK_LONG(ttl) for ttl in (_DNS_OTHER_TTL, _DNS_HOST_TTL, 0)}
class State(enum.Enum):
    """Lifecycle of an outgoing message: still being built, or already packed."""

    init = 0
    finished = 1
# Plain integer values of the State members, so the state can be compared
# without going through the enum.
STATE_INIT, STATE_FINISHED = State.init.value, State.finished.value

# Module-level shortcuts used to guard debug logging.
LOGGING_DEBUG = logging.DEBUG
LOGGING_IS_ENABLED_FOR = log.isEnabledFor
class DNSOutgoing:
    """Object representation of an outgoing packet"""

    __slots__ = (
        'flags',
        'finished',
        'id',
        'multicast',
        'packets_data',
        'names',
        'data',
        'size',
        'allow_long',
        'state',
        'questions',
        'answers',
        'authorities',
        'additionals',
    )

    def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
        """Create an empty outgoing message.

        flags is the header flags word written into every packet,
        multicast selects whether the packet id is written as 0 (True)
        or as id_ (False), and id_ is the id used for non-multicast packets.
        """
        self.flags = flags
        self.finished = False
        self.id = id_
        self.multicast = multicast
        self.packets_data: List[bytes] = []

        # these 3 are per-packet -- see also _reset_for_next_packet()
        self.names: Dict[str, int] = {}
        self.data: List[bytes] = []
        self.size: int = _DNS_PACKET_HEADER_LEN
        self.allow_long: bool = True

        self.state = STATE_INIT

        self.questions: List[DNSQuestion] = []
        self.answers: List[Tuple[DNSRecord, float]] = []
        self.authorities: List[DNSPointer] = []
        self.additionals: List[DNSRecord] = []

    def is_query(self) -> bool:
        """Returns true if this is a query."""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY

    def is_response(self) -> bool:
        """Returns true if this is a response."""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE

    def _reset_for_next_packet(self) -> None:
        """Clear the per-packet state so the next packet starts empty."""
        self.names = {}
        self.data = []
        self.size = _DNS_PACKET_HEADER_LEN
        self.allow_long = True

    def __repr__(self) -> str:
        return '<DNSOutgoing:{%s}>' % ', '.join(
            [
                'multicast=%s' % self.multicast,
                'flags=%s' % self.flags,
                'questions=%s' % self.questions,
                'answers=%s' % self.answers,
                'authorities=%s' % self.authorities,
                'additionals=%s' % self.additionals,
            ]
        )

    def add_question(self, record: DNSQuestion) -> None:
        """Adds a question"""
        self.questions.append(record)

    def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
        """Adds an answer unless the record is suppressed by the incoming message"""
        if not record.suppressed_by(inp):
            self.add_answer_at_time(record, 0.0)

    def add_answer_at_time(self, record: Optional[DNSRecord], now: float_) -> None:
        """Adds an answer if it does not expire by a certain time

        A now of 0 skips the expiry check; a record of None is ignored.
        """
        # NOTE(review): the alias looks redundant in pure Python; presumably a
        # compiled-extension declaration refers to it -- confirm before removing.
        now_double = now
        if record is not None and (now_double == 0 or not record.is_expired(now_double)):
            self.answers.append((record, now))

    def add_authorative_answer(self, record: DNSPointer) -> None:
        """Adds an authoritative answer"""
        self.authorities.append(record)

    def add_additional_answer(self, record: DNSRecord) -> None:
        """Adds an additional answer

        From: RFC 6763, DNS-Based Service Discovery, February 2013

        12. DNS Additional Record Generation

           DNS has an efficiency feature whereby a DNS server may place
           additional records in the additional section of the DNS message.
           These additional records are records that the client did not
           explicitly request, but the server has reasonable grounds to expect
           that the client might request them shortly, so including them can
           save the client from having to issue additional queries.

           This section recommends which additional records SHOULD be generated
           to improve network efficiency, for both Unicast and Multicast DNS-SD
           responses.

        12.1. PTR Records

           When including a DNS-SD Service Instance Enumeration or Selective
           Instance Enumeration (subtype) PTR record in a response packet, the
           server/responder SHOULD include the following additional records:

           o  The SRV record(s) named in the PTR rdata.
           o  The TXT record(s) named in the PTR rdata.
           o  All address records (type "A" and "AAAA") named in the SRV rdata.

        12.2. SRV Records

           When including an SRV record in a response packet, the
           server/responder SHOULD include the following additional records:

           o  All address records (type "A" and "AAAA") named in the SRV rdata.
        """
        self.additionals.append(record)

    def _write_byte(self, value: int_) -> None:
        """Writes a single byte to the packet"""
        self.data.append(BYTE_TABLE[value])
        self.size += 1

    def _get_short(self, value: int_) -> bytes:
        """Convert an unsigned short to 2 bytes."""
        return SHORT_LOOKUP[value] if value < SHORT_CACHE_MAX else PACK_SHORT(value)

    def _insert_short_at_start(self, value: int_) -> None:
        """Inserts an unsigned short at the start of the packet

        The size counter is not touched: it is initialised to the header
        length, and this method is only used to emit the header fields.
        """
        self.data.insert(0, self._get_short(value))

    def _replace_short(self, index: int_, value: int_) -> None:
        """Replaces an unsigned short in a certain position in the packet"""
        self.data[index] = self._get_short(value)

    def write_short(self, value: int_) -> None:
        """Writes an unsigned short to the packet"""
        self.data.append(self._get_short(value))
        self.size += 2

    def _write_int(self, value: Union[float, int]) -> None:
        """Writes an unsigned integer to the packet"""
        value_as_int = int(value)
        # Common values are pre-packed in LONG_LOOKUP.
        long_bytes = LONG_LOOKUP.get(value_as_int)
        if long_bytes is not None:
            self.data.append(long_bytes)
        else:
            self.data.append(PACK_LONG(value_as_int))
        self.size += 4

    def write_string(self, value: bytes_) -> None:
        """Writes a string to the packet"""
        if TYPE_CHECKING:
            assert isinstance(value, bytes)
        self.data.append(value)
        self.size += len(value)

    def _write_utf(self, s: str_) -> None:
        """Writes one name label: a length byte followed by its UTF-8 bytes

        Raises NamePartTooLongException if the encoded label is longer than
        63 bytes.
        """
        utfstr = s.encode('utf-8')
        length = len(utfstr)
        # RFC 1035 limits a label to 63 octets: the two high bits of the
        # length octet are reserved, so 64 (0x40) is not a valid label length.
        if length > 63:
            raise NamePartTooLongException
        self._write_byte(length)
        self.write_string(utfstr)

    def write_character_string(self, value: bytes) -> None:
        """Writes a length-prefixed byte string to the packet

        Raises NamePartTooLongException if value is longer than 255 bytes.
        """
        if TYPE_CHECKING:
            assert isinstance(value, bytes)
        length = len(value)
        # The length is stored in a single byte, so 255 is the largest
        # encodable value; 256 would index past the end of BYTE_TABLE.
        if length > 255:
            raise NamePartTooLongException
        self._write_byte(length)
        self.write_string(value)

    def write_name(self, name: str_) -> None:
        """
        Write names to packet

        18.14. Name Compression

        When generating Multicast DNS messages, implementations SHOULD use
        name compression wherever possible to compress the names of resource
        records, by replacing some or all of the resource record name with a
        compact two-byte reference to an appearance of that data somewhere
        earlier in the message [RFC1035].
        """
        # A trailing dot is dropped so both spellings share one table entry.
        if name.endswith('.'):
            name = name[:-1]

        # Offsets are never 0 (size starts at the header length), so 0 can
        # double as the "not seen yet" marker.
        index = self.names.get(name, 0)
        if index:
            self._write_link_to_name(index)
            return

        start_size = self.size
        # split name into each label
        labels = name.split('.')
        # Write each new label or a pointer to the existing one in the packet
        self.names[name] = start_size
        self._write_utf(labels[0])

        name_length = 0
        for count in range(1, len(labels)):
            partial_name = '.'.join(labels[count:])
            index = self.names.get(partial_name, 0)
            if index:
                self._write_link_to_name(index)
                return
            if name_length == 0:
                # Encoded once, lazily, and reused for the remaining suffixes.
                name_length = len(name.encode('utf-8'))
            # The suffix begins where the full name begins plus the bytes of
            # everything in front of it.
            self.names[partial_name] = start_size + name_length - len(partial_name.encode('utf-8'))
            self._write_utf(labels[count])

        # this is the end of a name
        self._write_byte(0)

    def _write_link_to_name(self, index: int_) -> None:
        """Writes a two-byte compression pointer to the name stored at index"""
        # If part of the name already exists in the packet,
        # create a pointer to it
        self._write_byte((index >> 8) | 0xC0)
        self._write_byte(index & 0xFF)

    def _write_question(self, question: DNSQuestion_) -> bool:
        """Writes a question to the packet

        Returns True if it fit, or False if it was rolled back because
        the packet would exceed the size limit.
        """
        start_data_length = len(self.data)
        start_size = self.size
        self.write_name(question.name)
        self.write_short(question.type)
        self._write_record_class(question)
        return self._check_data_limit_or_rollback(start_data_length, start_size)

    def _write_record_class(self, record: Union[DNSQuestion_, DNSRecord_]) -> None:
        """Write out the record class including the unique/unicast (QU) bit."""
        class_ = record.class_
        if record.unique is True and self.multicast:
            self.write_short(class_ | _CLASS_UNIQUE)
        else:
            self.write_short(class_)

    def _write_ttl(self, record: DNSRecord_, now: float_) -> None:
        """Write out the record ttl.

        The full ttl is written when now is 0, the remaining ttl otherwise.
        """
        self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now))

    def _write_record(self, record: DNSRecord_, now: float_) -> bool:
        """Writes a record (answer, authoritative answer, additional) to
        the packet.

        Returns True on success, or False if the record was rolled back
        because it does not fit in the packet.
        """
        start_data_length = len(self.data)
        start_size = self.size
        self.write_name(record.name)
        self.write_short(record.type)
        self._write_record_class(record)
        self._write_ttl(record, now)

        index = len(self.data)
        self.write_short(0)  # Will get replaced with the actual size
        record.write(self)

        # The record data length is the total size of every chunk appended
        # after the placeholder short.
        length = sum(len(chunk) for chunk in self.data[index + 1 :])
        # Here we replace the 0 length short we wrote
        # before with the actual length
        self._replace_short(index, length)
        return self._check_data_limit_or_rollback(start_data_length, start_size)

    def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int_) -> bool:
        """Check data limit, if we go over, then rollback and return False."""
        # Only the first entry of a packet may use the larger absolute limit.
        len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL
        self.allow_long = False

        if self.size <= len_limit:
            return True

        if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG):  # pragma: no branch
            log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
        del self.data[start_data_length:]
        self.size = start_size

        # Forget names registered by the rolled-back entry so that later
        # entries do not point at data that is no longer in the packet.
        start_size_int = start_size
        rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
        for name in rollback_names:
            del self.names[name]
        return False

    def _write_questions_from_offset(self, questions_offset: int_) -> int:
        """Write questions starting at the offset until one does not fit.

        Returns the number of questions written.
        """
        questions_written = 0
        for question in self.questions[questions_offset:]:
            if not self._write_question(question):
                break
            questions_written += 1
        return questions_written

    def _write_answers_from_offset(self, answer_offset: int_) -> int:
        """Write answers starting at the offset until one does not fit.

        Returns the number of answers written.
        """
        answers_written = 0
        for answer, time_ in self.answers[answer_offset:]:
            if not self._write_record(answer, time_):
                break
            answers_written += 1
        return answers_written

    def _write_records_from_offset(self, records: Sequence[DNSRecord], offset: int_) -> int:
        """Write records starting at the offset until one does not fit.

        Records are written with their full ttl. Returns the number of
        records written.
        """
        records_written = 0
        for record in records[offset:]:
            if not self._write_record(record, 0):
                break
            records_written += 1
        return records_written

    def _has_more_to_add(
        self, questions_offset: int_, answer_offset: int_, authority_offset: int_, additional_offset: int_
    ) -> bool:
        """Check if any questions, answers, authorities, or additionals are still unwritten."""
        return (
            questions_offset < len(self.questions)
            or answer_offset < len(self.answers)
            or authority_offset < len(self.authorities)
            or additional_offset < len(self.additionals)
        )

    def packets(self) -> List[bytes]:
        """Returns a list of bytestrings containing the packets' bytes

        No further parts should be added to the packet once this
        is done.  The packets are each restricted to _MAX_MSG_TYPICAL
        or less in length, except for the case of a single answer which
        will be written out to a single oversized packet no more than
        _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP
        fragmentation potentially)."""
        packets_data = self.packets_data

        # The result is built once; later calls return the same list.
        if self.state == STATE_FINISHED:
            return packets_data

        questions_offset = 0
        answer_offset = 0
        authority_offset = 0
        additional_offset = 0
        # we have to at least write out the question
        debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG) is True
        has_more_to_add = True

        while has_more_to_add:
            if debug_enable:
                log.debug(
                    "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d",
                    questions_offset,
                    answer_offset,
                    authority_offset,
                    additional_offset,
                )
                log.debug(
                    "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d",
                    len(self.questions),
                    len(self.answers),
                    len(self.authorities),
                    len(self.additionals),
                )

            questions_written = self._write_questions_from_offset(questions_offset)
            answers_written = self._write_answers_from_offset(answer_offset)
            authorities_written = self._write_records_from_offset(self.authorities, authority_offset)
            additionals_written = self._write_records_from_offset(self.additionals, additional_offset)

            # Must be sampled before the header shorts are inserted below,
            # because those make self.data non-empty unconditionally.
            made_progress = bool(self.data)

            # The header is built back to front: each insert goes to index 0,
            # so the fields are inserted in reverse of their wire order.
            self._insert_short_at_start(additionals_written)
            self._insert_short_at_start(authorities_written)
            self._insert_short_at_start(answers_written)
            self._insert_short_at_start(questions_written)

            questions_offset += questions_written
            answer_offset += answers_written
            authority_offset += authorities_written
            additional_offset += additionals_written
            if debug_enable:
                log.debug(
                    "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d",
                    questions_offset,
                    answer_offset,
                    authority_offset,
                    additional_offset,
                )

            has_more_to_add = self._has_more_to_add(
                questions_offset, answer_offset, authority_offset, additional_offset
            )

            if has_more_to_add and self.is_query():
                # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2
                if debug_enable:  # pragma: no branch
                    log.debug("Setting TC flag")
                self._insert_short_at_start(self.flags | _FLAGS_TC)
            else:
                self._insert_short_at_start(self.flags)

            if self.multicast:
                self._insert_short_at_start(0)
            else:
                self._insert_short_at_start(self.id)

            packets_data.append(b''.join(self.data))

            if not made_progress:
                # Generating an empty packet is not a desirable outcome, but currently
                # too many internals rely on this behavior.  So, we'll just return an
                # empty packet and log a warning until this can be refactored at a later
                # date.
                log.warning("packets() made no progress adding records; returning")
                break

            if has_more_to_add:
                self._reset_for_next_packet()

        self.state = STATE_FINISHED
        return packets_data