mirror of git://git.psyced.org/git/pypsyc
210 lines
7.0 KiB
Python
210 lines
7.0 KiB
Python
"""
|
|
pypsyc.server.routing
|
|
~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
This module defines the following classes:
|
|
- `ServerCircuit`
|
|
- `_TreeNode`
|
|
- `Routing`
|
|
|
|
:copyright: 2010 by Manuel Jacob
|
|
:license: MIT
|
|
"""
|
|
import logging
|
|
|
|
from twisted.internet import reactor
|
|
from twisted.internet.protocol import Factory
|
|
|
|
from pypsyc.core.mmp import Circuit, Uni
|
|
from pypsyc.core.psyc import PSYCPacket, PSYCObject
|
|
from pypsyc.protocol import Error, check_response
|
|
from pypsyc.util import schedule, resolve_hostname, DNSError, connect
|
|
|
|
|
|
# Logger for this module; its name is the module's dotted import path.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class InvalidTargetError(Error):
    """The target uni of an address verification is not acceptable.

    Raised when the uni a peer asks to be verified against is not this
    server's root uni. It is also raised when a peer answers our own
    verification request with ``_error_invalid_target``.
    """
|
|
|
|
class InvalidSourceError(Error):
    """The source uni of an address verification is not acceptable.

    Raised when the claimed source host does not resolve to the address
    the circuit is connected from. It is also raised when a peer answers
    our own verification request with ``_error_invalid_source``.
    """
|
|
|
|
class ServerCircuit(Circuit):
    """Circuit between this server and a remote PSYC server.

    Incoming packets are filtered against ``allowed_sources`` and then
    handed to the factory for singlecast/multicast routing, or handled
    locally by ``self.psyc``.
    """

    def __init__(self):
        Circuit.__init__(self)
        # Unis this circuit may carry packets for. Filled in by the address
        # verification handshake, in either direction.
        self.allowed_sources = []
        self.psyc = PSYCObject(self.send)
        # NOTE(review): presumably makes PSYCObject dispatch incoming methods
        # to the handle_* methods below - confirm in PSYCObject.
        self.psyc.add_handler(self)

    def packet_received(self, header, content):
        """Handle a packet that was received.

        Context packets are forwarded to the factory's singlecast (with
        target) or multicast (without target) router.

        Other packets must either carry a source previously verified for
        this circuit, or get this circuit set as their source. They are
        then routed to their target or, without a target, handled locally.

        :raises NotImplementedError: for context packets that carry a source.
        """
        if header.context:
            if header.source:
                raise NotImplementedError
            if header.target:
                self.factory.route_singlecast(header, content)
            else:
                self.factory.route_multicast(header, content)
            return

        if header.source:
            source = header.source
            allowed = any(source == uni or source.is_descendant_of(uni)
                          for uni in self.allowed_sources)
            if not allowed:
                # Log instead of dropping silently so misbehaving or
                # unverified peers can be diagnosed.
                log.warning("Dropped packet from unverified source %s",
                            source)
                return
        else:
            header.source = self

        if header.target:
            self.factory.route_singlecast(header, content)
        else:
            self.psyc.handle_packet(header, content)

    def handle_request_verification(self, packet):
        """Answer a peer's ``_request_verification``.

        The check itself is delegated to ``factory.verify_address``. On
        success the source uni is added to ``allowed_sources``.

        :returns: a PSYCPacket with mc ``_echo_verification`` on success,
            or ``_error_invalid_source`` / ``_error_invalid_target``.
        """
        source_uni = Uni(packet.cvars['_uni_source'])
        target_uni = packet.cvars['_uni_target']
        try:
            self.factory.verify_address(self, source_uni, target_uni)
        except InvalidSourceError:
            return PSYCPacket(mc='_error_invalid_source')
        except InvalidTargetError:
            return PSYCPacket(mc='_error_invalid_target')
        self.allowed_sources.append(source_uni)
        return PSYCPacket(mc='_echo_verification')

    def request_verification(self, source_uni, target_uni):
        """Ask the peer to verify that we may send as source_uni to target_uni.

        On success target_uni is added to ``allowed_sources``.

        :raises InvalidSourceError: if the peer answers _error_invalid_source.
        :raises InvalidTargetError: if the peer answers _error_invalid_target.
        Other answers are passed to ``check_response``, which presumably
        raises unless the answer is ``_echo_verification``.
        """
        # NOTE(review): the reply is used directly, so sendmsg appears to
        # wait for the answer - confirm in PSYCObject.
        r = self.psyc.sendmsg(self, mc='_request_verification',
                              _uni_source=source_uni, _uni_target=target_uni)
        if r.mc == '_error_invalid_source':
            raise InvalidSourceError
        if r.mc == '_error_invalid_target':
            raise InvalidTargetError
        check_response(r, '_echo_verification')
        self.allowed_sources.append(target_uni)

    def connectionLost(self, reason):
        """Remove the routing state that was set up for this circuit."""
        routing = self.factory
        for uni in self.allowed_sources:
            parts = uni.into_parts()
            if len(parts) == 1:
                host = parts[0]
                # Only drop the route if it still points at this circuit.
                # It may already be gone, or a newer circuit to the same
                # host may have replaced it.
                if routing.srouting_table.get(host) is self:
                    del routing.srouting_table[host]
            else:
                entity = routing.root.children[parts[1]]
                entity.packages['person'].unlink(parts[2])

        # Forget this circuit in the outgoing-connection cache so that
        # later route setup does not reuse a dead circuit.
        for addr, circuit in list(routing.circuits.items()):
            if circuit is self:
                del routing.circuits[addr]
|
|
|
|
|
|
class _TreeNode(object):
    """Node of the server's entity tree.

    Each node keeps a reference to its parent and to the tree's root. It
    registers itself in its parent's ``children`` mapping under ``name``.
    A node created without a parent is its own root.
    """

    def __init__(self, parent=None, name=''):
        self._parent = parent
        self.children = {}  # name -> child node
        # Compare with None explicitly: a subclass defining __len__ or
        # __nonzero__/__bool__ could otherwise make a real parent look
        # absent.
        if parent is not None:
            self._root = parent._root
            parent.children[name] = self
        else:
            self._root = self
|
|
|
|
|
|
class Routing(Factory, object):
    """This class handles routing and circuit management.

    It is the Twisted factory for `ServerCircuit`. It keeps the tables used
    to route packets either into the local entity tree (``self.root``) or
    over circuits to other servers.
    """

    protocol = ServerCircuit

    def __init__(self, hostname, interface):
        self.hostname = hostname  # host part of unis that are routed locally
        self.interface = interface  # local address to listen on / bind to

        self.circuits = {}  # (ip, port) from resolve_hostname -> circuit
        self.queues = {}  # hostname -> [(header, content)] awaiting a route
        self.srouting_table = {}  # hostname -> circuit
        self.mrouting_table = {}  # context -> list of circuits

    def init(self, root):
        """Attach the root node of the local entity tree."""
        self.root = root

    def route_singlecast(self, header, content):
        """Route the packet to the right target.

        Local targets are looked up in the entity tree below ``self.root``.
        Unknown local targets are answered with ``_error_unknown_target``.

        Remote targets are sent over the host's circuit. If there is no
        circuit yet, the packet is queued while ``_add_route`` sets one up.
        Remote packets without ``_source`` are dropped and logged.
        """
        parts = header.target.into_parts()
        host = parts[0]
        if host == self.hostname:
            node = self.root
            try:
                for name in parts[1:]:
                    node = node.children[name]
            except KeyError:
                self._error(header, '_error_unknown_target')
            else:
                node.handle_packet(header, content)
        elif header.get('_source'):
            # Only the table lookup may signal "no route". A KeyError raised
            # inside send() must not be mistaken for a missing route.
            try:
                circuit = self.srouting_table[host]
            except KeyError:
                if host not in self.queues:
                    # First packet for this host: start building a route.
                    self.queues[host] = []
                    schedule(self._add_route, host)
                self.queues[host].append((header, content))
            else:
                circuit.send(header, content)
        else:
            client = header.source.transport.client
            log.error("Dropped packet without _source from %s", client)

    def _error(self, header, error_mc, message=None):
        """Send error_mc back to the packet's source, relaying its _tag."""
        tag = header.get('_tag')
        self.root.sendmsg(header.source, None, tag and {'_tag_relay': tag},
                          mc=error_mc, _uni=header.target, data=message)

    def route_multicast(self, header, content):
        """Send a context packet to every circuit registered for its context.

        Packets for a context with no registered circuits are dropped and
        logged at debug level instead of raising KeyError.
        """
        circuits = self.mrouting_table.get(header.context)
        if not circuits:
            log.debug("No circuits for context %s, dropping packet",
                      header.context)
            return
        # Materialize content once; presumably it may be a one-shot
        # iterable that has to be sent to every circuit.
        content = list(content)
        for circuit in circuits:
            circuit.send(header, content)

    def listen(self, port):
        """Accept server connections on port at self.interface."""
        reactor.listenTCP(port, self, interface=self.interface)

    def _add_route(self, host):
        """
        Add a route to host.

        First check if we have an open connection to the target host.
        If so, do address verification.
        If not, try to connect to the server and then do address verification.

        On success, the packets queued for host are sent over the circuit.
        On failure, each of them is answered with
        ``_failure_unsuccessful_delivery``.
        """
        try:
            addr = resolve_hostname(host)
        except DNSError:
            self._unsuccessful_delivery(host, "Can't resolve '%s'" % host)
            return

        try:
            circuit = self.circuits[addr]
        except KeyError:
            try:
                ip, port = addr
                circuit = connect(ip, port, self,
                                  bindAddress=(self.interface, 0))
            except Exception:
                self._unsuccessful_delivery(host, "Can't connect '%s'" % host)
                return
            self.circuits[addr] = circuit

        # Built outside the try: the except clause below refers to it.
        target_uni = Uni('psyc://%s/' % host)
        try:
            circuit.request_verification(self.root.uni, target_uni)
        except Error:
            self._unsuccessful_delivery(host, "Can't verify %s" % target_uni)
            return

        self.srouting_table[host] = circuit
        for header, content in self.queues.pop(host, ()):
            circuit.send(header, content)

    def _unsuccessful_delivery(self, host, message):
        """Log the failure and bounce every packet queued for host.

        Must be called from inside an ``except`` clause, because it uses
        log.exception.
        """
        log.exception(message)
        for header, content in self.queues.pop(host, ()):
            self._error(header, '_failure_unsuccessful_delivery', message)

    def verify_address(self, circuit, source_uni, target_uni):
        """Check a peer's verification request and record its route.

        :raises InvalidTargetError: if target_uni is not our root uni.
        :raises InvalidSourceError: if the source host cannot be resolved, or
            does not resolve to the address the circuit is connected from.
        """
        if target_uni != self.root.uni:
            raise InvalidTargetError
        source_host = source_uni.into_parts()[0]
        try:
            source_ip = resolve_hostname(source_host)[0]
        except DNSError:
            # An unresolvable claimed hostname cannot be verified. Report it
            # as an invalid source instead of letting DNSError escape.
            raise InvalidSourceError
        if source_ip != circuit.transport.client[0]:
            raise InvalidSourceError
        self.srouting_table[source_host] = circuit

    def connection_lost(self, circuit, error):
        """Hook for lost circuits; currently does nothing."""
        pass
|