""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine

    This module provides a framework for the use of DNS Service Discovery
    using IP multicast.

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""

import struct
import sys
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from .._dns import (
    DNSAddress,
    DNSHinfo,
    DNSNsec,
    DNSPointer,
    DNSQuestion,
    DNSRecord,
    DNSService,
    DNSText,
)
from .._exceptions import IncomingDecodeError
from .._logger import log
from .._utils.time import current_time_millis
from ..const import (
    _FLAGS_QR_MASK,
    _FLAGS_QR_QUERY,
    _FLAGS_QR_RESPONSE,
    _FLAGS_TC,
    _TYPE_A,
    _TYPE_AAAA,
    _TYPE_CNAME,
    _TYPE_HINFO,
    _TYPE_NSEC,
    _TYPE_PTR,
    _TYPE_SRV,
    _TYPE_TXT,
    _TYPES,
)

DNS_COMPRESSION_HEADER_LEN = 1
DNS_COMPRESSION_POINTER_LEN = 2
MAX_DNS_LABELS = 128
MAX_NAME_LENGTH = 253

# Exceptions treated as "malformed packet" while decoding.
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)

# Exception strings that have already been logged once with a traceback.
# NOTE(review): the stored value is the full ``sys.exc_info()`` tuple, which keeps the
# traceback (and the frames/packet data it references) alive for the life of the
# process. The decode error messages in this module embed the packet source and
# offsets, so the set of keys can keep growing with malformed traffic.
_seen_logs: Dict[str, Union[int, tuple]] = {}
# Aliases used in annotations below; presumably kept so a compiled (Cython) build
# can type these parameters — TODO confirm against any accompanying .pxd file.
_str = str
_int = int


class DNSIncoming:
    """Object representation of an incoming DNS packet.

    The header and question section are parsed eagerly by the constructor.
    The answer, authority and additional sections are parsed on demand by
    :meth:`answers`, or eagerly when the packet carries no questions.
    Decode failures are logged at debug level instead of being raised.
    ``valid`` is only set to True once the initial parse completes.
    """

    __slots__ = (
        "_did_read_others",
        'flags',
        'offset',
        'data',
        'view',
        '_data_len',
        '_name_cache',
        '_questions',
        '_answers',
        'id',
        '_num_questions',
        '_num_answers',
        '_num_authorities',
        '_num_additionals',
        'valid',
        'now',
        'scope_id',
        'source',
        '_has_qu_question',
    )

    def __init__(
        self,
        data: bytes,
        source: Optional[Tuple[str, int]] = None,
        scope_id: Optional[int] = None,
        now: Optional[float] = None,
    ) -> None:
        """Constructor from string holding bytes of packet.

        :param data: the raw packet bytes.
        :param source: where the packet came from; stored as ``self.source``
            and included in log and error messages.
        :param scope_id: passed to every AAAA record decoded from this packet.
        :param now: timestamp passed to every decoded record; when falsy,
            ``current_time_millis()`` is used instead.
        """
        self.flags = 0
        self.offset = 0
        self.data = data
        # Indexing ``bytes`` yields ints, so the header/rdata readers can index
        # this directly.
        self.view = data
        self._data_len = len(data)
        # Maps a name's starting offset to its decoded labels, so compression
        # pointers to an already-decoded name can reuse the result.
        self._name_cache: Dict[int, List[str]] = {}
        self._questions: List[DNSQuestion] = []
        self._answers: List[DNSRecord] = []
        self.id = 0
        self._num_questions = 0
        self._num_answers = 0
        self._num_authorities = 0
        self._num_additionals = 0
        self.valid = False
        self._did_read_others = False
        self.now = now or current_time_millis()
        self.source = source
        self.scope_id = scope_id
        self._has_qu_question = False
        try:
            self._initial_parse()
        except DECODE_EXCEPTIONS:
            self._log_exception_debug(
                'Received invalid packet from %s at offset %d while unpacking %r',
                self.source,
                self.offset,
                self.data,
            )

    def is_query(self) -> bool:
        """Returns true if this is a query."""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY

    def is_response(self) -> bool:
        """Returns true if this is a response."""
        return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE

    def has_qu_question(self) -> bool:
        """Returns true if any question is a QU question."""
        return self._has_qu_question

    @property
    def truncated(self) -> bool:
        """Returns true if the TC (truncated) flag is set on this packet."""
        return (self.flags & _FLAGS_TC) == _FLAGS_TC

    @property
    def questions(self) -> List[DNSQuestion]:
        """Questions in the packet."""
        return self._questions

    @property
    def num_questions(self) -> int:
        """Number of questions in the packet."""
        return self._num_questions

    @property
    def num_answers(self) -> int:
        """Number of answers in the packet."""
        return self._num_answers

    @property
    def num_authorities(self) -> int:
        """Number of authorities in the packet."""
        return self._num_authorities

    @property
    def num_additionals(self) -> int:
        """Number of additionals in the packet."""
        return self._num_additionals

    def _initial_parse(self) -> None:
        """Parse the data needed to initialize the packet object.

        Reads the header and questions. The remaining sections are read here
        only when there are no questions; otherwise they are deferred to
        :meth:`answers`.
        """
        self._read_header()
        self._read_questions()
        if not self._num_questions:
            self._read_others()
        self.valid = True

    @classmethod
    def _log_exception_debug(cls, *logger_data: Any) -> None:
        """Log the exception currently being handled at debug level.

        The traceback is attached only the first time a given exception
        message is seen; later occurrences log without ``exc_info``.
        """
        log_exc_info = False
        exc_info = sys.exc_info()
        exc_str = str(exc_info[1])
        if exc_str not in _seen_logs:
            # log the trace only on the first time
            _seen_logs[exc_str] = exc_info
            log_exc_info = True
        log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info)

    def answers(self) -> List[DNSRecord]:
        """Answers in the packet.

        On first call, parses the answer, authority and additional sections
        (if not already parsed). Records from all three sections are returned
        together. Decode errors are logged, not raised; records decoded
        before the error are still returned.
        """
        if not self._did_read_others:
            try:
                self._read_others()
            except DECODE_EXCEPTIONS:
                self._log_exception_debug(
                    'Received invalid packet from %s at offset %d while unpacking %r',
                    self.source,
                    self.offset,
                    self.data,
                )
        return self._answers

    def is_probe(self) -> bool:
        """Returns true if this is a probe (the packet has an authority section)."""
        return self._num_authorities > 0

    def __repr__(self) -> str:
        # Note: calling answers() here triggers lazy parsing of the record sections.
        return '<DNSIncoming:{%s}>' % ', '.join(
            [
                'id=%s' % self.id,
                'flags=%s' % self.flags,
                'truncated=%s' % self.truncated,
                'n_q=%s' % self._num_questions,
                'n_ans=%s' % self._num_answers,
                'n_auth=%s' % self._num_authorities,
                'n_add=%s' % self._num_additionals,
                'questions=%s' % self._questions,
                'answers=%s' % self.answers(),
            ]
        )

    def _read_header(self) -> None:
        """Reads header portion of packet.

        :raises IndexError: if the packet is shorter than the 12-byte header.
        """
        view = self.view
        offset = self.offset
        self.offset += 12
        # The header has 6 unsigned shorts in network order
        self.id = view[offset] << 8 | view[offset + 1]
        self.flags = view[offset + 2] << 8 | view[offset + 3]
        self._num_questions = view[offset + 4] << 8 | view[offset + 5]
        self._num_answers = view[offset + 6] << 8 | view[offset + 7]
        self._num_authorities = view[offset + 8] << 8 | view[offset + 9]
        self._num_additionals = view[offset + 10] << 8 | view[offset + 11]

    def _read_questions(self) -> None:
        """Reads questions section of packet.

        Also records whether any question is a QU question.
        """
        view = self.view
        questions = self._questions
        for _ in range(self._num_questions):
            name = self._read_name()
            offset = self.offset
            self.offset += 4
            # The question has 2 unsigned shorts in network order
            type_ = view[offset] << 8 | view[offset + 1]
            class_ = view[offset + 2] << 8 | view[offset + 3]
            question = DNSQuestion(name, type_, class_)
            if question.unique:  # QU questions use the same bit as unique
                self._has_qu_question = True
            questions.append(question)

    def _read_character_string(self) -> str:
        """Reads a length-prefixed character string from the packet.

        The bytes are decoded as UTF-8, with invalid sequences replaced.

        :raises IndexError: if the length byte is past the end of the packet.
        :raises IncomingDecodeError: if the packet ends before the declared
            number of bytes.
        """
        length = self.view[self.offset]
        self.offset += 1
        raw = self.data[self.offset : self.offset + length]
        if len(raw) != length:
            raise IncomingDecodeError(
                f"Truncated character string: wanted {length} bytes at offset {self.offset} from {self.source}"
            )
        self.offset += length
        return raw.decode('utf-8', 'replace')

    def _read_string(self, length: _int) -> bytes:
        """Reads a string of a given length from the packet.

        :raises IncomingDecodeError: if the packet ends before ``length`` bytes
            have been read. Returning a short slice would let truncated
            records be built from partial data.
        """
        info = self.data[self.offset : self.offset + length]
        if len(info) != length:
            raise IncomingDecodeError(
                f"Truncated data: wanted {length} bytes at offset {self.offset} from {self.source}"
            )
        self.offset += length
        return info

    def _read_others(self) -> None:
        """Reads the answers, authorities and additionals section of the packet.

        All three sections are collected into ``self._answers``. A record
        whose rdata fails to decode is skipped using its declared rdata
        length. A failure while reading a record's owner name or fixed header
        propagates to the caller.
        """
        self._did_read_others = True
        view = self.view
        n = self._num_answers + self._num_authorities + self._num_additionals
        for _ in range(n):
            domain = self._read_name()
            offset = self.offset
            self.offset += 10
            # type_, class_ and length are unsigned shorts in network order
            # ttl is an unsigned long in network order https://www.rfc-editor.org/errata/eid2130
            type_ = view[offset] << 8 | view[offset + 1]
            class_ = view[offset + 2] << 8 | view[offset + 3]
            ttl = view[offset + 4] << 24 | view[offset + 5] << 16 | view[offset + 6] << 8 | view[offset + 7]
            length = view[offset + 8] << 8 | view[offset + 9]
            end = self.offset + length
            rec = None
            try:
                rec = self._read_record(domain, type_, class_, ttl, length)
            except DECODE_EXCEPTIONS:
                # Skip records that fail to decode if we know the length
                # If the packet is really corrupt read_name and the unpack
                # above would fail and hit the exception catch in read_others
                self.offset = end
                log.debug(
                    'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r',
                    domain,
                    _TYPES.get(type_, type_),
                    self.offset,
                    self.data,
                    exc_info=True,
                )
            if rec is not None:
                self._answers.append(rec)

    def _read_record(
        self, domain: _str, type_: _int, class_: _int, ttl: _int, length: _int
    ) -> Optional[DNSRecord]:
        """Read known records types and skip unknown ones.

        :param length: the record's declared rdata length. It is used to read
            TXT data, to bound the NSEC bitmap, and to skip unknown types.
        :returns: the decoded record, or None for an unsupported type.
        """
        if type_ == _TYPE_A:
            return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now)
        if type_ in (_TYPE_CNAME, _TYPE_PTR):
            return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
        if type_ == _TYPE_TXT:
            return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now)
        if type_ == _TYPE_SRV:
            view = self.view
            offset = self.offset
            self.offset += 6
            # The SRV record has 3 unsigned shorts in network order
            priority = view[offset] << 8 | view[offset + 1]
            weight = view[offset + 2] << 8 | view[offset + 3]
            port = view[offset + 4] << 8 | view[offset + 5]
            return DNSService(
                domain,
                type_,
                class_,
                ttl,
                priority,
                weight,
                port,
                self._read_name(),
                self.now,
            )
        if type_ == _TYPE_HINFO:
            return DNSHinfo(
                domain,
                type_,
                class_,
                ttl,
                self._read_character_string(),
                self._read_character_string(),
                self.now,
            )
        if type_ == _TYPE_AAAA:
            return DNSAddress(domain, type_, class_, ttl, self._read_string(16), self.scope_id, self.now)
        if type_ == _TYPE_NSEC:
            name_start = self.offset
            # Argument order matters: the next-domain name is read first, then
            # the bitmap runs from there to the end of the rdata.
            return DNSNsec(
                domain,
                type_,
                class_,
                ttl,
                self._read_name(),
                self._read_bitmap(name_start + length),
                self.now,
            )
        # Try to ignore types we don't know about
        # Skip the payload for the resource record so the next
        # records can be parsed correctly
        self.offset += length
        return None

    def _read_bitmap(self, end: _int) -> List[int]:
        """Reads an NSEC bitmap from the packet.

        Each block is a window number byte, a bitmap length byte, then that
        many bitmap bytes. The most significant bit of the first byte
        corresponds to type ``window * 256``.

        :param end: offset at which the bitmap ends.
        :returns: the record type numbers whose bits are set.
        """
        rdtypes = []
        view = self.view
        while self.offset < end:
            offset = self.offset
            offset_plus_one = offset + 1
            offset_plus_two = offset + 2
            window = view[offset]
            bitmap_length = view[offset_plus_one]
            bitmap_end = offset_plus_two + bitmap_length
            for i, byte in enumerate(self.data[offset_plus_two:bitmap_end]):
                for bit in range(0, 8):
                    if byte & (0x80 >> bit):
                        rdtypes.append(bit + window * 256 + i * 8)
            self.offset += 2 + bitmap_length
        return rdtypes

    def _read_name(self) -> str:
        """Reads a domain name from the packet.

        :returns: the dot-joined labels with a trailing ``"."``.
        :raises IncomingDecodeError: if the name is malformed or longer than
            ``MAX_NAME_LENGTH`` characters (counting the trailing dot).
        """
        labels: List[str] = []
        seen_pointers: Set[int] = set()
        original_offset = self.offset
        self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers)
        self._name_cache[original_offset] = labels
        name = ".".join(labels) + "."
        if len(name) > MAX_NAME_LENGTH:
            raise IncomingDecodeError(
                f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH} from {self.source}"
            )
        return name

    def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int:
        """Append the labels of the name encoded at ``off`` to ``labels``.

        Compression pointers are followed recursively. The labels at each
        pointer target are cached in ``self._name_cache``.

        :returns: the offset just past this name's encoding: past the
            terminating zero byte, or past the two-byte pointer that ends it.
        :raises IncomingDecodeError: on an unknown label type; on a pointer
            that is out of range, points to itself, or was already followed;
            when more than ``MAX_DNS_LABELS`` labels accumulate; or when the
            data ends before the name does.
        """
        # This is a tight loop that is called frequently, small optimizations can make a difference.
        view = self.view
        while off < self._data_len:
            length = view[off]
            if length == 0:
                return off + DNS_COMPRESSION_HEADER_LEN
            if length < 0x40:
                # Plain label: a length byte followed by that many bytes of text.
                label_idx = off + DNS_COMPRESSION_HEADER_LEN
                labels.append(self.data[label_idx : label_idx + length].decode('utf-8', 'replace'))
                off += DNS_COMPRESSION_HEADER_LEN + length
                continue
            if length < 0xC0:
                # Top two bits 01 or 10: neither a plain label nor a pointer.
                raise IncomingDecodeError(
                    f"DNS compression type {length} is unknown at {off} from {self.source}"
                )
            # We have a DNS compression pointer
            link_data = view[off + 1]
            link = (length & 0x3F) * 256 + link_data
            link_py_int = link
            # Valid targets are 0 .. data_len - 1; a target equal to data_len
            # is already past the end of the packet.
            if link >= self._data_len:
                raise IncomingDecodeError(
                    f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}"
                )
            if link == off:
                raise IncomingDecodeError(
                    f"DNS compression pointer at {off} points to itself from {self.source}"
                )
            if link_py_int in seen_pointers:
                # Following the same pointer twice would loop forever.
                raise IncomingDecodeError(
                    f"DNS compression pointer at {off} was seen again from {self.source}"
                )
            linked_labels = self._name_cache.get(link_py_int)
            if not linked_labels:
                linked_labels = []
                seen_pointers.add(link_py_int)
                self._decode_labels_at_offset(link, linked_labels, seen_pointers)
                self._name_cache[link_py_int] = linked_labels
            labels.extend(linked_labels)
            if len(labels) > MAX_DNS_LABELS:
                raise IncomingDecodeError(
                    f"Maximum dns labels reached while processing pointer at {off} from {self.source}"
                )
            # A pointer always terminates the name at this position.
            return off + DNS_COMPRESSION_POINTER_LEN
        raise IncomingDecodeError(f"Corrupt packet received while decoding name from {self.source}")