mirror of
https://github.com/uhIgnacio/EmoteManager.git
synced 2024-08-15 02:23:13 +00:00
318 lines
8.3 KiB
Python
318 lines
8.3 KiB
Python
"""Provide high-level UDP endpoints for asyncio.

Example:

    async def main():

        # Create a local UDP endpoint
        local = await open_local_endpoint('localhost', 8888)

        # Create a remote UDP endpoint, pointing to the first one
        remote = await open_remote_endpoint(*local.address)

        # The remote endpoint sends a datagram
        remote.send(b'Hey Hey, My My')

        # The local endpoint receives the datagram, along with the sender's address
        data, address = await local.receive()

        # This prints something like:
        #     Got b'Hey Hey, My My' from 127.0.0.1 port 54321
        # where the port is the remote endpoint's (ephemeral) port, not 8888
        print(f"Got {data!r} from {address[0]} port {address[1]}")
"""
|
|
|
|
# Public API of this module: the two high-level opening coroutines.
__all__ = [
    "open_local_endpoint",
    "open_remote_endpoint",
]
|
|
|
|
|
|
# Imports
|
|
|
|
import asyncio
|
|
import warnings
|
|
|
|
|
|
# Datagram protocol
|
|
|
|
class DatagramEndpointProtocol(asyncio.DatagramProtocol):
    """Datagram protocol for the endpoint high-level interface.

    Forwards every transport event to the owning endpoint:

    * the transport is stored on ``endpoint._transport``;
    * incoming datagrams go to ``endpoint.feed_datagram``;
    * write flow control is exposed through ``endpoint._write_ready_future``,
      a pending future while writing is paused and ``None`` otherwise.
    """

    def __init__(self, endpoint):
        self._endpoint = endpoint

    # Protocol methods

    def connection_made(self, transport):
        self._endpoint._transport = transport

    def connection_lost(self, exc):
        # exc is None on a regular close. Otherwise the transport failed, which
        # is reported as a warning (like error_received) rather than an
        # assertion. In both cases the endpoint must still be closed, so that
        # pending receive() and drain() calls are released instead of hanging.
        if exc is not None:
            warnings.warn('Endpoint lost the connection: {!r}'.format(exc))
        self._release_writers()
        self._endpoint.close()

    # Datagram protocol methods

    def datagram_received(self, data, addr):
        self._endpoint.feed_datagram(data, addr)

    def error_received(self, exc):
        msg = 'Endpoint received an error: {!r}'
        warnings.warn(msg.format(exc))

    # Workflow control

    def pause_writing(self):
        assert self._endpoint._write_ready_future is None
        # Called synchronously from transport.sendto(), i.e. while the event
        # loop is running, so this returns the transport's loop without
        # touching the private ``transport._loop`` attribute.
        loop = asyncio.get_event_loop()
        self._endpoint._write_ready_future = loop.create_future()

    def resume_writing(self):
        assert self._endpoint._write_ready_future is not None
        self._release_writers()

    def _release_writers(self):
        """Resolve and clear the endpoint's write-ready future, if any."""
        future = self._endpoint._write_ready_future
        # Clear first so that a later pause_writing() starts from a clean state.
        self._endpoint._write_ready_future = None
        # The future may already be done: awaiting it in drain() and then
        # cancelling that task cancels the future itself. set_result() would
        # then raise InvalidStateError.
        if future is not None and not future.done():
            future.set_result(None)
|
|
|
|
|
|
# Endpoint classes
|
|
|
|
class Endpoint:
    """High-level interface for UDP endpoints.

    Can either be local or remote.
    It is initialized with an optional queue size for the incoming datagrams;
    ``None`` (the default) or ``0`` means the queue is unbounded.

    Closing the endpoint enqueues a ``(None, None)`` sentinel when the queue is
    empty. Every coroutine blocked in :meth:`receive` is then woken up and
    raises ``IOError``. Datagrams already queued are still returned first.
    """

    def __init__(self, queue_size=None):
        if queue_size is None:
            queue_size = 0
        self._queue = asyncio.Queue(queue_size)
        self._closed = False
        # Set by the protocol's connection_made().
        self._transport = None
        # Managed by the protocol's flow-control callbacks: a pending future
        # while writing is paused, None otherwise. Awaited by drain().
        self._write_ready_future = None

    # Protocol callbacks

    def feed_datagram(self, data, addr):
        """Enqueue ``(data, addr)``; warn and drop it if the queue is full."""
        try:
            self._queue.put_nowait((data, addr))
        except asyncio.QueueFull:
            warnings.warn('Endpoint queue is full')

    def close(self):
        """Mark the endpoint closed, wake up receivers and close the transport.

        Calling it again has no effect.
        """
        if self._closed:
            return
        self._closed = True
        # An empty queue means receivers may be blocked in get(); wake them up
        # with the sentinel. A non-empty queue is drained by receive() first.
        if self._queue.empty():
            self.feed_datagram(None, None)
        if self._transport:
            self._transport.close()

    # User methods

    def send(self, data, addr):
        """Send a datagram to the given address.

        Raises:
            IOError: if the endpoint is closed.
        """
        if self._closed:
            raise IOError("Endpoint is closed")
        self._transport.sendto(data, addr)

    async def receive(self):
        """Wait for an incoming datagram and return it with
        the corresponding address.

        This method is a coroutine.

        Raises:
            IOError: if the endpoint is closed and no datagram is left.
        """
        if self._queue.empty() and self._closed:
            raise IOError("Endpoint is closed")
        data, addr = await self._queue.get()
        if data is None:
            # Close sentinel: put it back so that every other coroutine blocked
            # in receive() is woken up too. A single sentinel would only release
            # one of them and leave the others waiting forever. A slot is free
            # because the sentinel was just taken out.
            self._queue.put_nowait((None, None))
            raise IOError("Endpoint is closed")
        return data, addr

    def abort(self):
        """Close the transport immediately.

        Raises:
            IOError: if the endpoint is already closed.
        """
        if self._closed:
            raise IOError("Endpoint is closed")
        self._transport.abort()
        self.close()

    async def drain(self):
        """Drain the transport buffer below the low-water mark."""
        if self._write_ready_future is not None:
            await self._write_ready_future

    # Properties

    @property
    def address(self):
        """The endpoint address as a (host, port) tuple."""
        return self._transport.get_extra_info("socket").getsockname()

    @property
    def closed(self):
        """Indicates whether the endpoint is closed or not."""
        return self._closed
|
|
|
|
|
|
class LocalEndpoint(Endpoint):
    """High-level interface for UDP local endpoints.

    Adds nothing to the base endpoint. Datagrams are sent to an explicit
    address and received together with the sender's address.
    It is initialized with an optional queue size for the incoming datagrams.
    """
|
|
|
|
|
|
class RemoteEndpoint(Endpoint):
    """High-level interface for UDP remote endpoints.

    The peer is fixed when the endpoint is opened, so ``send`` takes no
    address and ``receive`` yields only the payload.
    It is initialized with an optional queue size for the incoming datagrams.
    """

    def send(self, data):
        """Send a datagram to the remote host."""
        # With no explicit address, the transport targets the peer it was
        # created for.
        super().send(data, addr=None)

    async def receive(self):
        """Wait for an incoming datagram from the remote host.

        This method is a coroutine; the sender address is discarded.
        """
        payload, _sender = await super().receive()
        return payload
|
|
|
|
|
|
# High-level coroutines
|
|
|
|
async def open_datagram_endpoint(endpoint_factory=Endpoint, **kwargs):
    """Open and return a datagram endpoint.

    ``endpoint_factory`` is a zero-argument callable returning the endpoint to
    wire up; it defaults to the ``Endpoint`` class. The endpoint is local or
    remote depending on the factory and on the address arguments passed
    through ``kwargs`` (``local_addr`` / ``remote_addr``); there is no
    separate ``remote`` flag.

    Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
    Any ``protocol_factory`` among them is replaced by one bound to the new
    endpoint.
    """
    loop = asyncio.get_event_loop()
    endpoint = endpoint_factory()
    kwargs['protocol_factory'] = lambda: DatagramEndpointProtocol(endpoint)
    # The protocol's connection_made() stores the transport on the endpoint,
    # so the returned (transport, protocol) pair is not needed here.
    await loop.create_datagram_endpoint(**kwargs)
    return endpoint
|
|
|
|
|
|
async def open_local_endpoint(host='0.0.0.0', port=0, *, queue_size=None, **kwargs):
    """Open and return a local datagram endpoint bound to ``(host, port)``.

    ``queue_size`` optionally bounds the queue of incoming datagrams.
    Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
    """
    def make_endpoint():
        return LocalEndpoint(queue_size)

    bind_address = (host, port)
    endpoint = await open_datagram_endpoint(
        local_addr=bind_address,
        endpoint_factory=make_endpoint,
        **kwargs,
    )
    return endpoint
|
|
|
|
|
|
async def open_remote_endpoint(host, port, queue_size=None, **kwargs):
    """Open and return a remote datagram endpoint targeting ``(host, port)``.

    ``queue_size`` optionally bounds the queue of incoming datagrams.
    Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
    """
    def make_endpoint():
        return RemoteEndpoint(queue_size)

    peer_address = (host, port)
    endpoint = await open_datagram_endpoint(
        remote_addr=peer_address,
        endpoint_factory=make_endpoint,
        **kwargs,
    )
    return endpoint
|
|
|
|
|
|
# Testing
|
|
|
|
try:
    import pytest
except ImportError:  # pragma: no cover
    pass
else:
    # Mark every test coroutine below with the ``asyncio`` marker (presumably
    # handled by the pytest-asyncio plugin).
    pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
async def test_standard_behavior():
    server = await open_local_endpoint()
    client = await open_remote_endpoint(*server.address)

    # Client -> server: the server also learns the client's address.
    client.send(b'Hey Hey')
    payload, sender = await server.receive()
    assert payload == b'Hey Hey'
    assert sender == client.address

    # Server -> client: reply to the address just learned.
    server.send(b'My My', sender)
    reply = await client.receive()
    assert reply == b'My My'

    server.abort()
    assert server.closed

    # Sending to the closed port makes the client's transport report an error
    # (typically ICMP port unreachable), which surfaces as a UserWarning.
    with pytest.warns(UserWarning):
        await asyncio.sleep(1e-3)
        client.send(b'U there?')
        await asyncio.sleep(1e-3)

    client.abort()
    assert client.closed
|
|
|
|
|
|
async def test_closed_endpoint():
    local = await open_local_endpoint()
    future = asyncio.ensure_future(local.receive())
    local.abort()
    assert local.closed

    # A receive() already pending when the endpoint closes is woken up.
    with pytest.raises(IOError):
        await future

    with pytest.raises(IOError):
        await local.receive()

    # send() is a plain method, not a coroutine: it must raise directly.
    # Awaiting it would turn a regression into a TypeError on `await None`.
    with pytest.raises(IOError):
        local.send(b'test', ('localhost', 8888))

    with pytest.raises(IOError):
        local.abort()
|
|
|
|
|
|
async def test_queue_size():
    server = await open_local_endpoint(queue_size=1)
    client = await open_remote_endpoint(*server.address)

    # With room for a single datagram, the second one is dropped with a warning.
    client.send(b'1')
    client.send(b'2')
    with pytest.warns(UserWarning):
        await asyncio.sleep(1e-3)
    assert await server.receive() == (b'1', client.address)
    client.send(b'3')
    assert await server.receive() == (b'3', client.address)

    # A datagram queued before closing is still delivered afterwards.
    client.send(b'4')
    await asyncio.sleep(1e-3)
    server.abort()
    assert server.closed
    assert await server.receive() == (b'4', client.address)

    client.abort()
    assert client.closed
|
|
|
|
|
|
async def test_flow_control():
    count = size = 1024
    client = await open_remote_endpoint("8.8.8.8", 12345)
    chunk = b"a" * size

    # Flood the transport, then wait for its buffer to drain.
    for _ in range(count):
        client.send(chunk)
    await client.drain()

    # Flood again, then abort: drain() must return rather than block.
    for _ in range(count):
        client.send(chunk)
    client.abort()
    await client.drain()
|
|
|
|
|
|
if __name__ == '__main__':  # pragma: no cover
    # Propagate pytest's exit status so that failures are visible to the shell.
    raise SystemExit(pytest.main([__file__]))
|