mirror of git://git.psyced.org/git/pypsyc
158 lines
4.4 KiB
Python
158 lines
4.4 KiB
Python
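"""
    pypsyc.util
    ~~~~~~~~~~~

    Utility helpers: greenlet scheduling, blocking DNS SRV resolution,
    blocking TCP connection setup, and a simple observer event type.

    :copyright: 2010 by Manuel Jacob
    :license: MIT
"""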
|
|
import logging
|
|
|
|
from greenlet import greenlet, getcurrent
|
|
from twisted.internet import reactor, error
|
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
|
from twisted.internet.tcp import Connector
|
|
|
|
|
|
# The greenlet that is current when this module is imported; schedule()
# uses it as the parent of every greenlet it creates.
scheduler = getcurrent()
|
|
|
|
def schedule(f, *args, **kwds):
    """Start ``f(*args, **kwds)`` in a fresh greenlet parented to ``scheduler``.

    When called from the scheduler greenlet itself, the new greenlet is
    switched to immediately.  From any other greenlet, the start is handed
    to the reactor with a zero delay instead of switching directly.
    """
    task = greenlet(f, scheduler)
    if getcurrent() is not scheduler:
        reactor.callLater(0, task.switch, *args, **kwds)
    else:
        task.switch(*args, **kwds)
|
|
|
|
|
|
class Waiter(object):
    """Hand a callback-delivered result to a greenlet that calls get().

    If ``callback``/``errback`` fires before ``get()``, the value or
    exception is stored and ``get()`` returns/raises it directly.
    Otherwise ``get()`` records the current greenlet and switches to its
    parent; a later ``callback``/``errback`` then switches (or throws)
    back into the waiting greenlet.
    """

    def __init__(self):
        # Greenlet blocked in get(), or None while nobody is waiting.
        self.greenlet = None
        # Result stored by callback() when it fires before get().
        self.value = None
        # Exception stored by errback() when it fires before get().
        self.exception = None
        # True once callback() stored a result; needed because None is a
        # legitimate result and must not be mistaken for "not fired yet".
        self._has_value = False

    def callback(self, *args, **kwds):
        """Deliver a result.

        Before get(): exactly one positional argument is expected and is
        stored (keyword arguments are ignored in that case).  After get():
        all arguments are passed to ``greenlet.switch`` of the waiter.
        """
        if self.greenlet is None:
            self.value, = args
            self._has_value = True
        else:
            assert getcurrent() is not self.greenlet
            self.greenlet.switch(*args, **kwds)

    def errback(self, exception):
        """Deliver *exception*: store it, or throw it into the waiter."""
        if self.greenlet is None:
            self.exception = exception
        else:
            self.greenlet.throw(exception)

    def get(self):
        """Return the delivered result, raising a delivered exception.

        If nothing has been delivered yet, switch to the parent greenlet
        and return whatever the resuming ``switch`` passes back.
        """
        if self._has_value:
            return self.value
        if self.exception is not None:
            raise self.exception
        self.greenlet = getcurrent()
        return self.greenlet.parent.switch()
|
|
|
|
|
|
class DNSError(Exception):
    """Raised when a host name cannot be resolved to an address."""
|
|
|
|
@inlineCallbacks
def _resolve_hostname(host):
    """Resolve *host* for PSYC, preferring its ``_psyc._tcp`` SRV record.

    Produces (through the Deferred created by ``inlineCallbacks``) an
    ``(ip, port)`` tuple.  When an SRV record exists, its target and port
    are used.  The target's address is taken from the additional section
    if present; otherwise it is looked up separately, still keeping the
    SRV port.  Without an SRV record, *host* itself is looked up and the
    fallback port 4404 is used.

    Fails with DNSError when the address lookup raises ``DNSNameError``.
    """
    # Values used when no SRV record is available.
    target, port = host, 4404
    try:
        name = '_psyc._tcp.%s.' % host
        answers, auth, add = yield lookupService(name)

        srv_rr = answers[0]  # IndexError: no SRV answer at all
        assert srv_rr.name.name == name
        target = srv_rr.payload.target.name
        port = srv_rr.payload.port

        # IndexError: no address record for the target was included.
        a_rr = [rr for rr in add if rr.name.name == target][0]
        ip = a_rr.payload.dottedQuad()

    except (DNSNameError, IndexError):
        # Either no SRV record exists (target/port still hold the
        # fallback values) or the SRV target's address has to be looked
        # up separately (target/port come from the SRV record and the
        # SRV port must be preserved).
        try:
            ip = yield getHostByName(target)
        except DNSNameError:
            raise DNSError("Unknown host %s." % target)

    returnValue((ip, port))
|
|
|
|
try:
    from twisted.names.client import lookupService, getHostByName
    from twisted.names.error import DNSNameError
except ImportError:
    log = logging.getLogger(__name__)
    # Logger.warn is a deprecated alias of Logger.warning (removed in 3.13).
    log.warning("twisted names isn't installed -- DNS SRV is disabled")

    def resolve_hostname(host):
        """Return ``(host, 4404)`` unchanged; no DNS lookup is performed."""
        return host, 4404
else:
    def resolve_hostname(host):
        """Resolve *host* to ``(ip, port)`` via ``_resolve_hostname``.

        Waits for the Deferred through a Waiter, so the calling greenlet
        may be switched away from until the lookup completes.  A failure
        is re-raised as the Failure's exception value.
        """
        waiter = Waiter()
        _resolve_hostname(host).addCallbacks(
            waiter.callback, lambda f: waiter.errback(f.value))
        return waiter.get()
|
|
|
|
|
|
class _PSYCConnector(Connector):
    """TCP connector whose ``connect()`` blocks the calling greenlet.

    ``connect()`` starts the connection attempt and waits on a Waiter
    until the protocol built by ``buildProtocol`` calls its ``inited``
    hook (success), or ``connectionFailed`` delivers the failure.
    """

    def __init__(self, *args):
        # Positional args are forwarded to Connector; connect() below passes
        # (host, port, factory, timeout, bindAddress).
        Connector.__init__(self, *args, reactor=reactor)
        self.waiter = Waiter()

    def connect(self):
        """Start connecting and block until the protocol is initialized.

        Returns the protocol instance created by ``buildProtocol``.  If
        ``self.timeout`` is not None and elapses first,
        ``transport.failIfNotConnected`` is called with
        ``error.TimeoutError()``.  Any exception delivered to
        ``connectionFailed`` is raised from here.
        """
        assert self.state == 'disconnected', "can't connect in this state"
        self.state = 'connecting'

        self.transport = transport = self._makeTransport()
        # timeout=None means "no timeout" (as in twisted's BaseConnector);
        # reactor.callLater() does not accept None as a delay.
        if self.timeout is not None:
            self.timeoutID = self.reactor.callLater(
                self.timeout, transport.failIfNotConnected,
                error.TimeoutError())
        self.waiter.get()
        return self.circuit

    def buildProtocol(self, addr):
        """Build the protocol and hook its ``inited`` to wake connect()."""
        self.circuit = Connector.buildProtocol(self, addr)
        self.circuit.inited = self.waiter.callback
        return self.circuit

    def connectionFailed(self, reason):
        """Reset connector state and raise the failure in connect()."""
        self.cancelTimeout()
        self.transport = None
        self.state = 'disconnected'
        self.waiter.errback(reason.value)

    def connectionLost(self, reason):
        """Report the lost protocol to ``factory.connection_lost``."""
        self.state = 'disconnected'
        self.factory.connection_lost(self.circuit, reason.value)
|
|
|
|
def connect(host, port, factory, timeout=30, bindAddress=None):
    """Connect to ``(host, port)`` over TCP and return the protocol.

    Creates a ``_PSYCConnector`` and calls its ``connect()``, which blocks
    the calling greenlet until the protocol built by *factory* calls its
    ``inited`` hook.
    """
    connector = _PSYCConnector(host, port, factory, timeout, bindAddress)
    return connector.connect()
|
|
|
|
|
|
class Event(object):
    """A simple multicast callback list.

    Observers are registered with ``add_observer`` (optionally with extra
    bound positional/keyword arguments) or ``event += observer``, and
    removed with ``event -= observer``.  Calling the event calls every
    observer in registration order.
    """

    def __init__(self):
        # List of (observer, bound_args, bound_kwds) tuples.
        self.observers = []

    def add_observer(self, observer, *args, **kwds):
        """Register *observer*; *args*/*kwds* are added to every call."""
        self.observers.append((observer, args, kwds))

    def __iadd__(self, observer):
        """Register *observer* with no bound arguments (``event += f``)."""
        self.observers.append((observer, (), {}))
        return self

    def __isub__(self, observer):
        """Unregister *observer* (``event -= f``).

        The observer must be registered exactly once (checked by assert).
        """
        observers = [i for i in self.observers if i[0] == observer]
        assert len(observers) == 1, observers
        self.observers.remove(observers[0])
        return self

    def __call__(self, *args, **kwds):
        """Call each observer as ``observer(*(args + bound_args), **merged)``.

        ``merged`` is a fresh copy of *kwds* updated with that observer's
        bound keywords (bound keywords win), so one observer's bound
        keywords never leak into another observer's call.  Dispatch runs
        over a snapshot of the observer list, so an observer may add or
        remove observers (including itself) without causing another
        observer to be skipped in this round.
        """
        for observer, bound_args, bound_kwds in list(self.observers):
            merged = dict(kwds)
            merged.update(bound_kwds)
            observer(*(args + bound_args), **merged)
|
|
|
|
|
|
def key_intersection(a, b):
    """Lazily yield each element of *b* that is also contained in *a*.

    Elements come out in *b*'s iteration order.
    """
    for candidate in b:
        if candidate not in a:
            continue
        yield candidate
|