"""
    pypsyc.server.routing
    ~~~~~~~~~~~~~~~~~~~~~

    This module defines the following classes:
    - `ServerCircuit`
    - `_TreeNode`
    - `Routing`

    :copyright: 2010 by Manuel Jacob
    :license: MIT
"""
import logging

from twisted.internet import reactor
from twisted.internet.protocol import Factory

from pypsyc.core.mmp import Circuit, Uni
from pypsyc.core.psyc import PSYCPacket, PSYCObject
from pypsyc.protocol import Error, check_response
from pypsyc.util import schedule, resolve_hostname, DNSError, connect


log = logging.getLogger(__name__)


class InvalidTargetError(Error):
    """The target UNI of a verification request is not this server."""
    pass


class InvalidSourceError(Error):
    """The claimed source of a verification request could not be verified."""
    pass


class ServerCircuit(Circuit):
    """A circuit to a peer (server or client) as seen by this server.

    `allowed_sources` collects the UNIs this circuit may send packets on
    behalf of; packets claiming any other source are dropped.
    """

    def __init__(self):
        Circuit.__init__(self)
        self.allowed_sources = []
        self.psyc = PSYCObject(self.send)
        # Register this circuit so that its handle_* methods (e.g.
        # handle_request_verification) receive PSYC packets addressed to it.
        # NOTE(review): dispatch-by-method-name is assumed from the naming;
        # confirm against PSYCObject.add_handler.
        self.psyc.add_handler(self)

    def packet_received(self, header, content):
        """Handle a packet that was received.

        Context (multicast) packets are passed to the factory's routing;
        a context packet that carries a source is not supported yet.
        Non-context packets must come from an allowed source (or get this
        circuit as their source when none is given) and are then routed to
        their target, or handled locally when they have no target.
        """
        if header.context:
            if header.source:
                raise NotImplementedError
            if header.target:
                self.factory.route_singlecast(header, content)
            else:
                self.factory.route_multicast(header, content)
        else:
            if header.source:
                # Accept only sources that equal, or live below, a UNI this
                # circuit has been verified for.
                if not any((header.source == i or
                            header.source.is_descendant_of(i))
                           for i in self.allowed_sources):
                    return
            else:
                header.source = self
            if header.target:
                self.factory.route_singlecast(header, content)
            else:
                self.psyc.handle_packet(header, content)

    def handle_request_verification(self, packet):
        """Answer a peer's `_request_verification` packet.

        Returns `_error_invalid_source` / `_error_invalid_target` when the
        factory rejects the addresses, otherwise allows the source UNI on
        this circuit and returns `_echo_verification`.
        """
        source_uni = Uni(packet.cvars['_uni_source'])
        target_uni = packet.cvars['_uni_target']
        try:
            self.factory.verify_address(self, source_uni, target_uni)
        except InvalidSourceError:
            return PSYCPacket(mc='_error_invalid_source')
        except InvalidTargetError:
            return PSYCPacket(mc='_error_invalid_target')
        self.allowed_sources.append(source_uni)
        return PSYCPacket(mc='_echo_verification')

    def request_verification(self, source_uni, target_uni):
        """Ask the peer to verify that we may send as `source_uni`.

        On success `target_uni` is allowed as a source on this circuit.

        :raises InvalidSourceError: the peer rejected our source UNI.
        :raises InvalidTargetError: the peer rejected the target UNI.
        """
        # sendmsg's return value is used as the response packet, so this
        # call presumably blocks until the peer answers -- confirm.
        r = self.psyc.sendmsg(self, mc='_request_verification',
                              _uni_source=source_uni, _uni_target=target_uni)
        if r.mc == '_error_invalid_source':
            raise InvalidSourceError
        if r.mc == '_error_invalid_target':
            raise InvalidTargetError
        check_response(r, '_echo_verification')
        self.allowed_sources.append(target_uni)

    def connectionLost(self, reason):
        """Tear down all routing state that depends on this circuit."""
        # NOTE(review): the base class' connectionLost is not called here (as
        # in the original); confirm Circuit does not need it.
        factory = self.factory
        for uni in self.allowed_sources:
            parts = uni.into_parts()
            if len(parts) == 1:
                host = parts[0]
                # Only drop the route if it still goes through this circuit:
                # a newer circuit to the same host may already have replaced
                # it, and the entry may already be gone.
                if factory.srouting_table.get(host) is self:
                    del factory.srouting_table[host]
            else:
                entity = factory.root.children[parts[1]]
                entity.packages['person'].unlink(parts[2])
        # Forget this circuit in the outgoing-connection cache, otherwise a
        # later _add_route() would try to reuse the dead connection.
        for addr, circuit in list(factory.circuits.items()):
            if circuit is self:
                del factory.circuits[addr]


class _TreeNode(object):
    """A node in the server's entity tree, registered under its parent."""

    def __init__(self, parent=None, name=''):
        self._parent = parent
        self.children = {}
        if parent is not None:
            self._root = parent._root
            parent.children[name] = self
        else:
            self._root = self


class Routing(Factory, object):
    """This class handles routing and circuit management."""
    protocol = ServerCircuit

    def __init__(self, hostname, interface):
        self.hostname = hostname
        self.interface = interface
        self.circuits = {} # (ip, port) -> circuit (outgoing connections)
        self.queues = {} # hostname -> list of (header, content) awaiting a route
        self.srouting_table = {} # hostname -> circuit
        self.mrouting_table = {} # context -> list of circuits

    def init(self, root):
        """Set the root node of the local entity tree."""
        self.root = root

    def route_singlecast(self, header, content):
        """Route the packet to the right target.

        Local targets are looked up in the entity tree. Remote targets are
        sent over the known circuit for their host; when there is none, the
        packet is queued and a route is set up asynchronously. Remote
        packets without `_source` are dropped.
        """
        parts = header.target.into_parts()
        host = parts[0]
        if host == self.hostname:
            node = self.root
            try:
                for i in parts[1:]:
                    node = node.children[i]
            except KeyError:
                self._error(header, '_error_unknown_target')
            else:
                node.handle_packet(header, content)
        elif header.get('_source'):
            try:
                self.srouting_table[host].send(header, content)
            except KeyError:
                # The first packet for an unrouted host creates the queue and
                # schedules route setup; later packets just wait in the queue.
                if host not in self.queues:
                    self.queues[host] = []
                    schedule(self._add_route, host)
                self.queues[host].append((header, content))
        else:
            client = header.source.transport.client
            log.error("Dropped packet without _source from %s", client)

    def _error(self, header, error_mc, message=None):
        """Send an error packet with method `error_mc` back to the source."""
        tag = header.get('_tag')
        # Relay the original _tag (if any) so the sender can match the error.
        relay = {'_tag_relay': tag} if tag else None
        self.root.sendmsg(header.source, None, relay,
                          mc=error_mc, _uni=header.target, data=message)

    def route_multicast(self, header, content):
        """Send a context packet to every circuit registered for it.

        A context without registered circuits has nobody to deliver to, so
        the packet is dropped instead of raising KeyError.
        """
        # Materialize once: content is sent to several circuits.
        content = list(content)
        for circuit in self.mrouting_table.get(header.context, ()):
            circuit.send(header, content)

    def listen(self, port):
        """Accept incoming TCP connections on `port` on our interface."""
        reactor.listenTCP(port, self, interface=self.interface)

    def _add_route(self, host):
        """
        Add a route to host.

        Resolve the host, then reuse an open connection to it if there is
        one, otherwise connect to it. Then do address verification. On
        success, the route is stored and the queued packets are sent. On
        any failure, every queued packet is answered with an error.
        """
        try:
            addr = resolve_hostname(host)
        except DNSError:
            self._unsuccessful_delivery(host, "Can't resolve '%s'" % host)
            return

        try:
            circuit = self.circuits[addr]
        except KeyError:
            try:
                ip, p = addr
                circuit = connect(ip, p, self,
                                  bindAddress=(self.interface, 0))
            except Exception:
                # Best effort: any connect failure bounces the queued packets.
                self._unsuccessful_delivery(host, "Can't connect '%s'" % host)
                return
            self.circuits[addr] = circuit

        # Built outside the try so the except clause can always refer to it.
        target_uni = Uni('psyc://%s/' % host)
        try:
            circuit.request_verification(self.root.uni, target_uni)
        except Error:
            self._unsuccessful_delivery(host, "Can't verify %s" % target_uni)
        else:
            self.srouting_table[host] = circuit
            for header, content in self.queues.pop(host):
                circuit.send(header, content)

    def _unsuccessful_delivery(self, host, message):
        """Log the failure and bounce every packet queued for `host`.

        Must be called from inside an `except` block: it uses
        `log.exception`, which logs the exception currently being handled.
        """
        log.exception(message)
        for header, content in self.queues.pop(host):
            self._error(header, '_failure_unsuccessful_delivery', message)

    def verify_address(self, circuit, source_uni, target_uni):
        """Check a peer's verification request and route its host over `circuit`.

        :raises InvalidTargetError: `target_uni` is not our root UNI.
        :raises InvalidSourceError: the source host does not resolve, or it
            resolves to an IP other than the peer's.
        """
        if target_uni != self.root.uni:
            raise InvalidTargetError
        source_host = source_uni.into_parts()[0]
        try:
            source_ip = resolve_hostname(source_host)[0]
        except DNSError:
            # An unresolvable host cannot be verified; report it as an invalid
            # source instead of letting the DNS error escape the handler.
            raise InvalidSourceError
        if source_ip != circuit.transport.client[0]:
            raise InvalidSourceError
        self.srouting_table[source_host] = circuit

    def connection_lost(self, circuit, error):
        """Hook for a lost circuit; intentionally does nothing here.

        NOTE(review): nothing in this module calls it; presumably invoked
        by Circuit -- confirm.
        """
        pass