diff --git a/bot.py b/bot.py index 7d002d5..6d9fa5f 100755 --- a/bot.py +++ b/bot.py @@ -35,6 +35,7 @@ class Bot(Bot): 'cogs.stats', 'bot_bin.debug', 'bot_bin.misc', + 'bot_bin.systemd', 'jishaku', ) @@ -57,11 +58,6 @@ class Bot(Bot): utils.SUCCESS_EMOJIS = utils.misc.SUCCESS_EMOJIS = ( self.config.get('response_emojis', {}).get('success', default)) - def load_extensions(self): - super().load_extensions() - if self.config.get('systemd'): - self.load_extension('cogs.systemd') - def main(): import sys diff --git a/cogs/systemd.py b/cogs/systemd.py deleted file mode 100644 index 5030e72..0000000 --- a/cogs/systemd.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio -import os -import socket - -from discord.ext import commands - -from utils.socket import open_datagram_endpoint - -class SystemdNotifier(commands.Cog): - def __init__(self): - self.os_sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM) - self.connect_task = asyncio.create_task(self.connect()) - self.addr = os.environ['NOTIFY_SOCKET'] - - def send(self, msg): - self.sock.send(msg, self.addr) - - async def connect(self): - self.sock = await open_datagram_endpoint(sock=self.os_sock) - - def cog_unload(self): - self.connect_task.cancel() - - @commands.Cog.listener() - async def on_shard_ready(self, shard_id): - self.send(b'STATUS=Ready on shard %d' % shard_id) - - @commands.Cog.listener() - async def on_ready(self): - self.send(b'READY=1') - -def setup(bot): - bot.add_cog(SystemdNotifier()) diff --git a/data/config.example.py b/data/config.example.py index e3c3e75..3d41a97 100644 --- a/data/config.example.py +++ b/data/config.example.py @@ -4,11 +4,6 @@ 'NOTE: Most commands will be unavailable until both you and the bot have the ' '"Manage Emojis" permission.', - # Whether to use systemd notify to inform systemd of the bot being ready. - # This can be used for fine-grained dependency management, where one cluster doesn't start until the previous - # finishes starting up. - 'systemd': False, - # a channel ID to invite people to when they request help with the bot # the bot must have Create Instant Invite permissions for this channel # if set to None, the support command will be disabled diff --git a/requirements.txt b/requirements.txt index e944e41..9a1faae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aioec>=0.6.0 aiohttp_socks -bot_bin>=1.0.0,<2.0.0 +bot_bin>=1.5.0,<2.0.0 discord.py>=1.0.1,<2.0.0 jishaku wand diff --git a/utils/socket.py b/utils/socket.py deleted file mode 100644 index e26ebf2..0000000 --- a/utils/socket.py +++ /dev/null @@ -1,318 +0,0 @@ -"""Provide high-level UDP endpoints for asyncio. - -Example: - -async def main(): - - # Create a local UDP enpoint - local = await open_local_endpoint('localhost', 8888) - - # Create a remote UDP enpoint, pointing to the first one - remote = await open_remote_endpoint(*local.address) - - # The remote endpoint sends a datagram - remote.send(b'Hey Hey, My My') - - # The local endpoint receives the datagram, along with the address - data, address = await local.receive() - - # This prints: Got 'Hey Hey, My My' from 127.0.0.1 port 8888 - print(f"Got {data!r} from {address[0]} port {address[1]}") -""" - -__all__ = ['open_local_endpoint', 'open_remote_endpoint'] - - -# Imports - -import asyncio -import warnings - - -# Datagram protocol - -class DatagramEndpointProtocol(asyncio.DatagramProtocol): - """Datagram protocol for the endpoint high-level interface.""" - - def __init__(self, endpoint): - self._endpoint = endpoint - - # Protocol methods - - def connection_made(self, transport): - self._endpoint._transport = transport - - def connection_lost(self, exc): - assert exc is None - if self._endpoint._write_ready_future is not None: - self._endpoint._write_ready_future.set_result(None) - self._endpoint.close() - - # Datagram protocol methods - - def datagram_received(self, data, addr): - self._endpoint.feed_datagram(data, addr) - - def error_received(self, exc): - msg = 'Endpoint received an error: {!r}' - warnings.warn(msg.format(exc)) - - # Workflow control - - def pause_writing(self): - assert self._endpoint._write_ready_future is None - loop = self._endpoint._transport._loop - self._endpoint._write_ready_future = loop.create_future() - - def resume_writing(self): - assert self._endpoint._write_ready_future is not None - self._endpoint._write_ready_future.set_result(None) - self._endpoint._write_ready_future = None - - -# Enpoint classes - -class Endpoint: - """High-level interface for UDP enpoints. - - Can either be local or remote. - It is initialized with an optional queue size for the incoming datagrams. - """ - - def __init__(self, queue_size=None): - if queue_size is None: - queue_size = 0 - self._queue = asyncio.Queue(queue_size) - self._closed = False - self._transport = None - self._write_ready_future = None - - # Protocol callbacks - - def feed_datagram(self, data, addr): - try: - self._queue.put_nowait((data, addr)) - except asyncio.QueueFull: - warnings.warn('Endpoint queue is full') - - def close(self): - # Manage flag - if self._closed: - return - self._closed = True - # Wake up - if self._queue.empty(): - self.feed_datagram(None, None) - # Close transport - if self._transport: - self._transport.close() - - # User methods - - def send(self, data, addr): - """Send a datagram to the given address.""" - if self._closed: - raise IOError("Enpoint is closed") - self._transport.sendto(data, addr) - - async def receive(self): - """Wait for an incoming datagram and return it with - the corresponding address. - - This method is a coroutine. - """ - if self._queue.empty() and self._closed: - raise IOError("Enpoint is closed") - data, addr = await self._queue.get() - if data is None: - raise IOError("Enpoint is closed") - return data, addr - - def abort(self): - """Close the transport immediately.""" - if self._closed: - raise IOError("Enpoint is closed") - self._transport.abort() - self.close() - - async def drain(self): - """Drain the transport buffer below the low-water mark.""" - if self._write_ready_future is not None: - await self._write_ready_future - - # Properties - - @property - def address(self): - """The endpoint address as a (host, port) tuple.""" - return self._transport.get_extra_info("socket").getsockname() - - @property - def closed(self): - """Indicates whether the endpoint is closed or not.""" - return self._closed - - -class LocalEndpoint(Endpoint): - """High-level interface for UDP local enpoints. - - It is initialized with an optional queue size for the incoming datagrams. - """ - pass - - -class RemoteEndpoint(Endpoint): - """High-level interface for UDP remote enpoints. - - It is initialized with an optional queue size for the incoming datagrams. - """ - - def send(self, data): - """Send a datagram to the remote host.""" - super().send(data, None) - - async def receive(self): - """ Wait for an incoming datagram from the remote host. - - This method is a coroutine. - """ - data, addr = await super().receive() - return data - - -# High-level coroutines - -async def open_datagram_endpoint(endpoint_factory=Endpoint, **kwargs): - """Open and return a datagram endpoint. - - The default endpoint factory is the Endpoint class. - The endpoint can be made local or remote using the remote argument. - Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`. - """ - loop = asyncio.get_event_loop() - endpoint = endpoint_factory() - kwargs['protocol_factory'] = lambda: DatagramEndpointProtocol(endpoint) - await loop.create_datagram_endpoint(**kwargs) - return endpoint - - -async def open_local_endpoint(host='0.0.0.0', port=0, *, queue_size=None, **kwargs): - """Open and return a local datagram endpoint. - - An optional queue size arguement can be provided. - Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`. - """ - return await open_datagram_endpoint( - local_addr=(host, port), - endpoint_factory=lambda: LocalEndpoint(queue_size), - **kwargs, - ) - - -async def open_remote_endpoint(host, port, queue_size=None, **kwargs): - """Open and return a remote datagram endpoint. - - An optional queue size arguement can be provided. - Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`. - """ - return await open_datagram_endpoint( - remote_addr=(host, port), - endpoint_factory=lambda: RemoteEndpoint(queue_size), - **kwargs, - ) - - -# Testing - -try: - import pytest - pytestmark = pytest.mark.asyncio -except ImportError: # pragma: no cover - pass - - -async def test_standard_behavior(): - local = await open_local_endpoint() - remote = await open_remote_endpoint(*local.address) - - remote.send(b'Hey Hey') - data, address = await local.receive() - - assert data == b'Hey Hey' - assert address == remote.address - - local.send(b'My My', address) - data = await remote.receive() - assert data == b'My My' - - local.abort() - assert local.closed - - with pytest.warns(UserWarning): - await asyncio.sleep(1e-3) - remote.send(b'U there?') - await asyncio.sleep(1e-3) - - remote.abort() - assert remote.closed - - -async def test_closed_endpoint(): - local = await open_local_endpoint() - future = asyncio.ensure_future(local.receive()) - local.abort() - assert local.closed - - with pytest.raises(IOError): - await future - - with pytest.raises(IOError): - await local.receive() - - with pytest.raises(IOError): - await local.send(b'test', ('localhost', 8888)) - - with pytest.raises(IOError): - local.abort() - - -async def test_queue_size(): - local = await open_local_endpoint(queue_size=1) - remote = await open_remote_endpoint(*local.address) - - remote.send(b'1') - remote.send(b'2') - with pytest.warns(UserWarning): - await asyncio.sleep(1e-3) - assert await local.receive() == (b'1', remote.address) - remote.send(b'3') - assert await local.receive() == (b'3', remote.address) - - remote.send(b'4') - await asyncio.sleep(1e-3) - local.abort() - assert local.closed - assert await local.receive() == (b'4', remote.address) - - remote.abort() - assert remote.closed - - -async def test_flow_control(): - m = n = 1024 - remote = await open_remote_endpoint("8.8.8.8", 12345) - - for _ in range(m): - remote.send(b"a" * n) - - await remote.drain() - - for _ in range(m): - remote.send(b"a" * n) - - remote.abort() - await remote.drain() - - -if __name__ == '__main__': # pragma: no cover - pytest.main([__file__])