mirror of
https://github.com/uhIgnacio/EmoteManager.git
synced 2024-08-15 02:23:13 +00:00
add systemd notify support
This commit is contained in:
parent
722c87c75b
commit
eb6e275c8a
4 changed files with 361 additions and 0 deletions
5
bot.py
5
bot.py
|
@ -58,6 +58,11 @@ class Bot(Bot):
|
|||
return super().activity
|
||||
return super().activity or discord.Game(f'@{self.user.name} help')
|
||||
|
||||
def load_extensions(self):
|
||||
super().load_extensions()
|
||||
if self.config.get('systemd'):
|
||||
self.load_extension('cogs.systemd')
|
||||
|
||||
def main():
|
||||
import sys
|
||||
|
||||
|
|
33
cogs/systemd.py
Normal file
33
cogs/systemd.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
|
||||
from discord.ext import commands
|
||||
|
||||
from utils.socket import open_datagram_endpoint
|
||||
|
||||
class SystemdNotifier(commands.Cog):
|
||||
def __init__(self):
|
||||
self.os_sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM)
|
||||
self.connect_task = asyncio.create_task(self.connect())
|
||||
self.addr = os.environ['NOTIFY_SOCKET']
|
||||
|
||||
def send(self, msg):
|
||||
self.sock.send(msg, self.addr)
|
||||
|
||||
async def connect(self):
|
||||
self.sock = await open_datagram_endpoint(sock=self.os_sock)
|
||||
|
||||
def cog_unload(self):
|
||||
self.connect_task.cancel()
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_shard_ready(self, shard_id):
|
||||
self.send(b'STATUS=Ready on shard %d' % shard_id)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
self.send(b'READY=1')
|
||||
|
||||
def setup(bot):
|
||||
bot.add_cog(SystemdNotifier())
|
|
@ -4,6 +4,11 @@
|
|||
'NOTE: Most commands will be unavailable until both you and the bot have the '
|
||||
'"Manage Emojis" permission.',
|
||||
|
||||
# Whether to use systemd notify to inform systemd of the bot being ready.
|
||||
# This can be used for fine-grained dependency management, where one cluster doesn't start until the previous
|
||||
# finishes starting up.
|
||||
'systemd': False,
|
||||
|
||||
# a channel ID to invite people to when they request help with the bot
|
||||
# the bot must have Create Instant Invite permissions for this channel
|
||||
# if set to None, the support command will be disabled
|
||||
|
|
318
utils/socket.py
Normal file
318
utils/socket.py
Normal file
|
@ -0,0 +1,318 @@
|
|||
"""Provide high-level UDP endpoints for asyncio.
|
||||
|
||||
Example:
|
||||
|
||||
async def main():
|
||||
|
||||
# Create a local UDP enpoint
|
||||
local = await open_local_endpoint('localhost', 8888)
|
||||
|
||||
# Create a remote UDP enpoint, pointing to the first one
|
||||
remote = await open_remote_endpoint(*local.address)
|
||||
|
||||
# The remote endpoint sends a datagram
|
||||
remote.send(b'Hey Hey, My My')
|
||||
|
||||
# The local endpoint receives the datagram, along with the address
|
||||
data, address = await local.receive()
|
||||
|
||||
# This prints: Got 'Hey Hey, My My' from 127.0.0.1 port 8888
|
||||
print(f"Got {data!r} from {address[0]} port {address[1]}")
|
||||
"""
|
||||
|
||||
__all__ = ['open_local_endpoint', 'open_remote_endpoint']
|
||||
|
||||
|
||||
# Imports
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
|
||||
# Datagram protocol
|
||||
|
||||
class DatagramEndpointProtocol(asyncio.DatagramProtocol):
|
||||
"""Datagram protocol for the endpoint high-level interface."""
|
||||
|
||||
def __init__(self, endpoint):
|
||||
self._endpoint = endpoint
|
||||
|
||||
# Protocol methods
|
||||
|
||||
def connection_made(self, transport):
|
||||
self._endpoint._transport = transport
|
||||
|
||||
def connection_lost(self, exc):
|
||||
assert exc is None
|
||||
if self._endpoint._write_ready_future is not None:
|
||||
self._endpoint._write_ready_future.set_result(None)
|
||||
self._endpoint.close()
|
||||
|
||||
# Datagram protocol methods
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
self._endpoint.feed_datagram(data, addr)
|
||||
|
||||
def error_received(self, exc):
|
||||
msg = 'Endpoint received an error: {!r}'
|
||||
warnings.warn(msg.format(exc))
|
||||
|
||||
# Workflow control
|
||||
|
||||
def pause_writing(self):
|
||||
assert self._endpoint._write_ready_future is None
|
||||
loop = self._endpoint._transport._loop
|
||||
self._endpoint._write_ready_future = loop.create_future()
|
||||
|
||||
def resume_writing(self):
|
||||
assert self._endpoint._write_ready_future is not None
|
||||
self._endpoint._write_ready_future.set_result(None)
|
||||
self._endpoint._write_ready_future = None
|
||||
|
||||
|
||||
# Enpoint classes
|
||||
|
||||
class Endpoint:
|
||||
"""High-level interface for UDP enpoints.
|
||||
|
||||
Can either be local or remote.
|
||||
It is initialized with an optional queue size for the incoming datagrams.
|
||||
"""
|
||||
|
||||
def __init__(self, queue_size=None):
|
||||
if queue_size is None:
|
||||
queue_size = 0
|
||||
self._queue = asyncio.Queue(queue_size)
|
||||
self._closed = False
|
||||
self._transport = None
|
||||
self._write_ready_future = None
|
||||
|
||||
# Protocol callbacks
|
||||
|
||||
def feed_datagram(self, data, addr):
|
||||
try:
|
||||
self._queue.put_nowait((data, addr))
|
||||
except asyncio.QueueFull:
|
||||
warnings.warn('Endpoint queue is full')
|
||||
|
||||
def close(self):
|
||||
# Manage flag
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
# Wake up
|
||||
if self._queue.empty():
|
||||
self.feed_datagram(None, None)
|
||||
# Close transport
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
|
||||
# User methods
|
||||
|
||||
def send(self, data, addr):
|
||||
"""Send a datagram to the given address."""
|
||||
if self._closed:
|
||||
raise IOError("Enpoint is closed")
|
||||
self._transport.sendto(data, addr)
|
||||
|
||||
async def receive(self):
|
||||
"""Wait for an incoming datagram and return it with
|
||||
the corresponding address.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if self._queue.empty() and self._closed:
|
||||
raise IOError("Enpoint is closed")
|
||||
data, addr = await self._queue.get()
|
||||
if data is None:
|
||||
raise IOError("Enpoint is closed")
|
||||
return data, addr
|
||||
|
||||
def abort(self):
|
||||
"""Close the transport immediately."""
|
||||
if self._closed:
|
||||
raise IOError("Enpoint is closed")
|
||||
self._transport.abort()
|
||||
self.close()
|
||||
|
||||
async def drain(self):
|
||||
"""Drain the transport buffer below the low-water mark."""
|
||||
if self._write_ready_future is not None:
|
||||
await self._write_ready_future
|
||||
|
||||
# Properties
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
"""The endpoint address as a (host, port) tuple."""
|
||||
return self._transport.get_extra_info("socket").getsockname()
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
"""Indicates whether the endpoint is closed or not."""
|
||||
return self._closed
|
||||
|
||||
|
||||
class LocalEndpoint(Endpoint):
|
||||
"""High-level interface for UDP local enpoints.
|
||||
|
||||
It is initialized with an optional queue size for the incoming datagrams.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RemoteEndpoint(Endpoint):
|
||||
"""High-level interface for UDP remote enpoints.
|
||||
|
||||
It is initialized with an optional queue size for the incoming datagrams.
|
||||
"""
|
||||
|
||||
def send(self, data):
|
||||
"""Send a datagram to the remote host."""
|
||||
super().send(data, None)
|
||||
|
||||
async def receive(self):
|
||||
""" Wait for an incoming datagram from the remote host.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
data, addr = await super().receive()
|
||||
return data
|
||||
|
||||
|
||||
# High-level coroutines
|
||||
|
||||
async def open_datagram_endpoint(endpoint_factory=Endpoint, **kwargs):
|
||||
"""Open and return a datagram endpoint.
|
||||
|
||||
The default endpoint factory is the Endpoint class.
|
||||
The endpoint can be made local or remote using the remote argument.
|
||||
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
endpoint = endpoint_factory()
|
||||
kwargs['protocol_factory'] = lambda: DatagramEndpointProtocol(endpoint)
|
||||
await loop.create_datagram_endpoint(**kwargs)
|
||||
return endpoint
|
||||
|
||||
|
||||
async def open_local_endpoint(host='0.0.0.0', port=0, *, queue_size=None, **kwargs):
|
||||
"""Open and return a local datagram endpoint.
|
||||
|
||||
An optional queue size arguement can be provided.
|
||||
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||
"""
|
||||
return await open_datagram_endpoint(
|
||||
local_addr=(host, port),
|
||||
endpoint_factory=lambda: LocalEndpoint(queue_size),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def open_remote_endpoint(host, port, queue_size=None, **kwargs):
|
||||
"""Open and return a remote datagram endpoint.
|
||||
|
||||
An optional queue size arguement can be provided.
|
||||
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||
"""
|
||||
return await open_datagram_endpoint(
|
||||
remote_addr=(host, port),
|
||||
endpoint_factory=lambda: RemoteEndpoint(queue_size),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Testing
|
||||
|
||||
try:
|
||||
import pytest
|
||||
pytestmark = pytest.mark.asyncio
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
async def test_standard_behavior():
|
||||
local = await open_local_endpoint()
|
||||
remote = await open_remote_endpoint(*local.address)
|
||||
|
||||
remote.send(b'Hey Hey')
|
||||
data, address = await local.receive()
|
||||
|
||||
assert data == b'Hey Hey'
|
||||
assert address == remote.address
|
||||
|
||||
local.send(b'My My', address)
|
||||
data = await remote.receive()
|
||||
assert data == b'My My'
|
||||
|
||||
local.abort()
|
||||
assert local.closed
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
await asyncio.sleep(1e-3)
|
||||
remote.send(b'U there?')
|
||||
await asyncio.sleep(1e-3)
|
||||
|
||||
remote.abort()
|
||||
assert remote.closed
|
||||
|
||||
|
||||
async def test_closed_endpoint():
|
||||
local = await open_local_endpoint()
|
||||
future = asyncio.ensure_future(local.receive())
|
||||
local.abort()
|
||||
assert local.closed
|
||||
|
||||
with pytest.raises(IOError):
|
||||
await future
|
||||
|
||||
with pytest.raises(IOError):
|
||||
await local.receive()
|
||||
|
||||
with pytest.raises(IOError):
|
||||
await local.send(b'test', ('localhost', 8888))
|
||||
|
||||
with pytest.raises(IOError):
|
||||
local.abort()
|
||||
|
||||
|
||||
async def test_queue_size():
|
||||
local = await open_local_endpoint(queue_size=1)
|
||||
remote = await open_remote_endpoint(*local.address)
|
||||
|
||||
remote.send(b'1')
|
||||
remote.send(b'2')
|
||||
with pytest.warns(UserWarning):
|
||||
await asyncio.sleep(1e-3)
|
||||
assert await local.receive() == (b'1', remote.address)
|
||||
remote.send(b'3')
|
||||
assert await local.receive() == (b'3', remote.address)
|
||||
|
||||
remote.send(b'4')
|
||||
await asyncio.sleep(1e-3)
|
||||
local.abort()
|
||||
assert local.closed
|
||||
assert await local.receive() == (b'4', remote.address)
|
||||
|
||||
remote.abort()
|
||||
assert remote.closed
|
||||
|
||||
|
||||
async def test_flow_control():
|
||||
m = n = 1024
|
||||
remote = await open_remote_endpoint("8.8.8.8", 12345)
|
||||
|
||||
for _ in range(m):
|
||||
remote.send(b"a" * n)
|
||||
|
||||
await remote.drain()
|
||||
|
||||
for _ in range(m):
|
||||
remote.send(b"a" * n)
|
||||
|
||||
remote.abort()
|
||||
await remote.drain()
|
||||
|
||||
|
||||
if __name__ == '__main__': # pragma: no cover
|
||||
pytest.main([__file__])
|
Loading…
Reference in a new issue