pypsyc/mjacob2/pypsyc/util.py

158 lines
4.4 KiB
Python

"""
pypsyc.util
~~~~~~~~~~~
:copyright: 2010 by Manuel Jacob
:license: MIT
"""
import logging
from greenlet import greenlet, getcurrent
from twisted.internet import reactor, error
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.tcp import Connector
# The greenlet that is current when this module is first imported.  schedule()
# makes it the parent of every greenlet it creates, so a greenlet started that
# way which blocks in Waiter.get() switches control back to this one.
scheduler = getcurrent()
def schedule(f, *args, **kwds):
    """Run ``f(*args, **kwds)`` in a new greenlet parented to ``scheduler``.

    If the caller is the scheduler greenlet itself, the new greenlet is
    started right away.  From any other greenlet, starting it is handed to
    the reactor with a zero-delay ``reactor.callLater``.
    """
    child = greenlet(f, scheduler)
    if getcurrent() is not scheduler:
        # Not running on the scheduler: let the reactor start the greenlet.
        reactor.callLater(0, child.switch, *args, **kwds)
        return
    child.switch(*args, **kwds)
class Waiter(object):
    """One-shot hand-off of a result (or exception) to a waiting greenlet.

    A producer delivers a result with callback() or a failure with
    errback(); a consumer retrieves it with get().  If the outcome arrives
    before get() is called it is stored.  Otherwise get() suspends the
    calling greenlet by switching to its parent, and the later
    callback()/errback() resumes it.
    """

    def __init__(self):
        self.greenlet = None     # greenlet suspended in get(), if any
        self.value = None        # result delivered before get() was called
        # Explicit flag so that a stored result of None is not mistaken for
        # "nothing delivered yet" (which would make get() block forever).
        self.has_value = False
        self.exception = None    # exception delivered before get() was called

    def callback(self, *args, **kwds):
        """Deliver a result.

        Before get() has been called, exactly one positional argument is
        expected and stored (keyword arguments are ignored).  Once a greenlet
        is waiting, all arguments are passed to its switch(), becoming the
        return value of its get() call.
        """
        if self.greenlet is None:
            self.value, = args
            self.has_value = True
        else:
            assert getcurrent() is not self.greenlet
            self.greenlet.switch(*args, **kwds)

    def errback(self, exception):
        """Deliver *exception*; it will be raised from get()."""
        if self.greenlet is None:
            self.exception = exception
        else:
            self.greenlet.throw(exception)

    def get(self):
        """Return the delivered result or raise the delivered exception.

        If nothing has been delivered yet, the current greenlet is recorded
        and suspended by switching to its parent until callback()/errback()
        resumes it.
        """
        if self.has_value:
            return self.value
        if self.exception is not None:
            raise self.exception
        self.greenlet = getcurrent()
        return self.greenlet.parent.switch()
class DNSError(Exception):
    """Raised when a host name cannot be resolved to an address."""
@inlineCallbacks
def _resolve_hostname(host):
    """Resolve *host* to an ``(ip, port)`` pair, preferring a PSYC SRV record.

    Looks up the SRV record ``_psyc._tcp.<host>.``.  If an answer exists,
    its target and port are used.  The target's address is taken from a
    matching record in the additional section when one is present;
    otherwise it is looked up with getHostByName().  In both cases the SRV
    port is kept.  Without an SRV answer, *host* itself is resolved and the
    default port 4404 is used.

    Returns a Deferred (via inlineCallbacks) that fires with ``(ip, port)``.
    It fails with DNSError when the address lookup raises DNSNameError.
    """
    name = '_psyc._tcp.%s.' % host
    try:
        answers, _auth, add = yield lookupService(name)
        srv_rr = answers[0]
    except (DNSNameError, IndexError):
        # No SRV record: connect to the host itself on the default port.
        target, port, add = host, 4404, []
    else:
        assert srv_rr.name.name == name
        target = srv_rr.payload.target.name
        port = srv_rr.payload.port

    glue = [rr for rr in add if rr.name.name == target]
    if glue:
        ip = glue[0].payload.dottedQuad()
    else:
        # No address record to hand (or no SRV at all): resolve the target
        # explicitly, keeping whichever port was chosen above.
        try:
            ip = yield getHostByName(target)
        except DNSNameError:
            raise DNSError("Unknown host %s." % target)
    returnValue((ip, port))
try:
    from twisted.names.client import lookupService, getHostByName
    from twisted.names.error import DNSNameError
except ImportError:
    log = logging.getLogger(__name__)
    # Logger.warn is a deprecated alias of Logger.warning.
    log.warning("twisted names isn't installed -- DNS SRV is disabled")

    def resolve_hostname(host):
        """Return ``(host, 4404)`` without any DNS lookup.

        twisted.names is unavailable, so no SRV lookup is possible.  *host*
        is returned as-is together with the default port 4404.
        """
        return host, 4404
else:
    def resolve_hostname(host):
        """Resolve *host* to ``(ip, port)`` using DNS SRV via _resolve_hostname.

        Blocks the calling greenlet (through a Waiter) until the lookup
        finishes.  Re-raises the exception the lookup failed with, e.g.
        DNSError.
        """
        waiter = Waiter()
        _resolve_hostname(host).addCallbacks(
            waiter.callback,
            lambda failure: waiter.errback(failure.value))
        return waiter.get()
class _PSYCConnector(Connector):
    """TCP connector whose connect() blocks the calling greenlet.

    connect() starts the TCP connection and then waits on a Waiter.  It
    returns the protocol instance once that protocol calls its ``inited``
    hook, or raises if the connection attempt fails.
    """

    def __init__(self, *args):
        # The module-level connect() passes
        # (host, port, factory, timeout, bindAddress) here.
        Connector.__init__(self, *args, reactor=reactor)
        self.waiter = Waiter()

    def connect(self):
        """Start connecting and block until the connection is initialised.

        Returns the protocol built for the connection (``self.circuit``).
        Raises the exception delivered through connectionFailed(); the
        timeout, if any, passes error.TimeoutError to
        transport.failIfNotConnected.
        """
        assert self.state == 'disconnected', "can't connect in this state"
        self.state = 'connecting'
        self.transport = transport = self._makeTransport()
        # A timeout of None means "no timeout"; callLater() cannot take a
        # None delay, so only arm the timer when a timeout is configured.
        if self.timeout is not None:
            self.timeoutID = self.reactor.callLater(
                self.timeout,
                transport.failIfNotConnected,
                error.TimeoutError())
        self.waiter.get()
        return self.circuit

    def buildProtocol(self, addr):
        """Build the protocol and hook its ``inited`` to wake connect()."""
        self.circuit = Connector.buildProtocol(self, addr)
        self.circuit.inited = self.waiter.callback
        return self.circuit

    def connectionFailed(self, reason):
        """Reset connector state and raise the failure in connect()."""
        self.cancelTimeout()
        self.transport = None
        self.state = 'disconnected'
        self.waiter.errback(reason.value)

    def connectionLost(self, reason):
        """Mark the connector disconnected and notify the factory.

        NOTE(review): the waiter is not touched here, so if the connection
        drops before ``inited`` fires, connect() appears to stay blocked.
        Confirm whether that can happen.
        """
        self.state = 'disconnected'
        self.factory.connection_lost(self.circuit, reason.value)
def connect(host, port, factory, timeout=30, bindAddress=None):
    """Open a connection to ``host:port`` and block until it is initialised.

    Builds a _PSYCConnector for *factory* and returns the protocol instance
    that its connect() yields once the protocol's ``inited`` hook fires.
    Failures, including the *timeout*, are raised from that call.
    """
    connector = _PSYCConnector(host, port, factory, timeout, bindAddress)
    circuit = connector.connect()
    return circuit
class Event(object):
    """A simple multicast callback list.

    Observers are registered with add_observer() or ``+=`` and removed with
    ``-=``.  Calling the event invokes every observer in registration order.
    Positional arguments bound at registration are appended to the call's
    arguments.  Bound keyword arguments are merged into the call's keywords,
    and the bound ones win on conflict.
    """

    def __init__(self):
        # (observer, bound_args, bound_kwds) triples, in registration order.
        self.observers = []

    def add_observer(self, observer, *args, **kwds):
        """Register *observer*, binding extra *args*/*kwds* to every call."""
        self.observers.append((observer, args, kwds))

    def __iadd__(self, observer):
        """Register *observer* with no bound arguments (``event += f``)."""
        self.observers.append((observer, (), {}))
        return self

    def __isub__(self, observer):
        """Unregister *observer* (``event -= f``).

        Exactly one registration comparing equal to *observer* must exist
        (checked with an assertion).
        """
        observers = [i for i in self.observers if i[0] == observer]
        assert len(observers) == 1, observers
        self.observers.remove(observers[0])
        return self

    def __call__(self, *args, **kwds):
        """Invoke every registered observer with the call's arguments."""
        # Iterate over a snapshot so an observer may add or remove observers
        # (including itself) without others being skipped in this dispatch.
        for observer, args2, kwds2 in list(self.observers):
            # Build a fresh dict per observer so that one observer's bound
            # keywords never leak into a later observer's call.
            merged = dict(kwds)
            merged.update(kwds2)
            observer(*(args + args2), **merged)
def key_intersection(a, b):
    """Lazily yield the elements of *b* that are also members of *a*.

    Elements come out in *b*'s iteration order.  An element repeated in *b*
    is yielded each time it occurs.
    """
    for candidate in b:
        if candidate not in a:
            continue
        yield candidate