mirror of
https://github.com/uhIgnacio/EmoteManager.git
synced 2024-08-15 02:23:13 +00:00
add systemd notify support
This commit is contained in:
parent
722c87c75b
commit
eb6e275c8a
4 changed files with 361 additions and 0 deletions
5
bot.py
5
bot.py
|
@ -58,6 +58,11 @@ class Bot(Bot):
|
||||||
return super().activity
|
return super().activity
|
||||||
return super().activity or discord.Game(f'@{self.user.name} help')
|
return super().activity or discord.Game(f'@{self.user.name} help')
|
||||||
|
|
||||||
|
def load_extensions(self):
|
||||||
|
super().load_extensions()
|
||||||
|
if self.config.get('systemd'):
|
||||||
|
self.load_extension('cogs.systemd')
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
33
cogs/systemd.py
Normal file
33
cogs/systemd.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from utils.socket import open_datagram_endpoint
|
||||||
|
|
||||||
|
class SystemdNotifier(commands.Cog):
|
||||||
|
def __init__(self):
|
||||||
|
self.os_sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM)
|
||||||
|
self.connect_task = asyncio.create_task(self.connect())
|
||||||
|
self.addr = os.environ['NOTIFY_SOCKET']
|
||||||
|
|
||||||
|
def send(self, msg):
|
||||||
|
self.sock.send(msg, self.addr)
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
self.sock = await open_datagram_endpoint(sock=self.os_sock)
|
||||||
|
|
||||||
|
def cog_unload(self):
|
||||||
|
self.connect_task.cancel()
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_shard_ready(self, shard_id):
|
||||||
|
self.send(b'STATUS=Ready on shard %d' % shard_id)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_ready(self):
|
||||||
|
self.send(b'READY=1')
|
||||||
|
|
||||||
|
def setup(bot):
|
||||||
|
bot.add_cog(SystemdNotifier())
|
|
@ -4,6 +4,11 @@
|
||||||
'NOTE: Most commands will be unavailable until both you and the bot have the '
|
'NOTE: Most commands will be unavailable until both you and the bot have the '
|
||||||
'"Manage Emojis" permission.',
|
'"Manage Emojis" permission.',
|
||||||
|
|
||||||
|
# Whether to use systemd notify to inform systemd of the bot being ready.
|
||||||
|
# This can be used for fine-grained dependency management, where one cluster doesn't start until the previous
|
||||||
|
# finishes starting up.
|
||||||
|
'systemd': False,
|
||||||
|
|
||||||
# a channel ID to invite people to when they request help with the bot
|
# a channel ID to invite people to when they request help with the bot
|
||||||
# the bot must have Create Instant Invite permissions for this channel
|
# the bot must have Create Instant Invite permissions for this channel
|
||||||
# if set to None, the support command will be disabled
|
# if set to None, the support command will be disabled
|
||||||
|
|
318
utils/socket.py
Normal file
318
utils/socket.py
Normal file
|
@ -0,0 +1,318 @@
|
||||||
|
"""Provide high-level UDP endpoints for asyncio.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
|
||||||
|
# Create a local UDP enpoint
|
||||||
|
local = await open_local_endpoint('localhost', 8888)
|
||||||
|
|
||||||
|
# Create a remote UDP enpoint, pointing to the first one
|
||||||
|
remote = await open_remote_endpoint(*local.address)
|
||||||
|
|
||||||
|
# The remote endpoint sends a datagram
|
||||||
|
remote.send(b'Hey Hey, My My')
|
||||||
|
|
||||||
|
# The local endpoint receives the datagram, along with the address
|
||||||
|
data, address = await local.receive()
|
||||||
|
|
||||||
|
# This prints: Got 'Hey Hey, My My' from 127.0.0.1 port 8888
|
||||||
|
print(f"Got {data!r} from {address[0]} port {address[1]}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = ['open_local_endpoint', 'open_remote_endpoint']
|
||||||
|
|
||||||
|
|
||||||
|
# Imports
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
# Datagram protocol
|
||||||
|
|
||||||
|
class DatagramEndpointProtocol(asyncio.DatagramProtocol):
|
||||||
|
"""Datagram protocol for the endpoint high-level interface."""
|
||||||
|
|
||||||
|
def __init__(self, endpoint):
|
||||||
|
self._endpoint = endpoint
|
||||||
|
|
||||||
|
# Protocol methods
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
self._endpoint._transport = transport
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
assert exc is None
|
||||||
|
if self._endpoint._write_ready_future is not None:
|
||||||
|
self._endpoint._write_ready_future.set_result(None)
|
||||||
|
self._endpoint.close()
|
||||||
|
|
||||||
|
# Datagram protocol methods
|
||||||
|
|
||||||
|
def datagram_received(self, data, addr):
|
||||||
|
self._endpoint.feed_datagram(data, addr)
|
||||||
|
|
||||||
|
def error_received(self, exc):
|
||||||
|
msg = 'Endpoint received an error: {!r}'
|
||||||
|
warnings.warn(msg.format(exc))
|
||||||
|
|
||||||
|
# Workflow control
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
assert self._endpoint._write_ready_future is None
|
||||||
|
loop = self._endpoint._transport._loop
|
||||||
|
self._endpoint._write_ready_future = loop.create_future()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
assert self._endpoint._write_ready_future is not None
|
||||||
|
self._endpoint._write_ready_future.set_result(None)
|
||||||
|
self._endpoint._write_ready_future = None
|
||||||
|
|
||||||
|
|
||||||
|
# Enpoint classes
|
||||||
|
|
||||||
|
class Endpoint:
|
||||||
|
"""High-level interface for UDP enpoints.
|
||||||
|
|
||||||
|
Can either be local or remote.
|
||||||
|
It is initialized with an optional queue size for the incoming datagrams.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, queue_size=None):
|
||||||
|
if queue_size is None:
|
||||||
|
queue_size = 0
|
||||||
|
self._queue = asyncio.Queue(queue_size)
|
||||||
|
self._closed = False
|
||||||
|
self._transport = None
|
||||||
|
self._write_ready_future = None
|
||||||
|
|
||||||
|
# Protocol callbacks
|
||||||
|
|
||||||
|
def feed_datagram(self, data, addr):
|
||||||
|
try:
|
||||||
|
self._queue.put_nowait((data, addr))
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
warnings.warn('Endpoint queue is full')
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
# Manage flag
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
|
# Wake up
|
||||||
|
if self._queue.empty():
|
||||||
|
self.feed_datagram(None, None)
|
||||||
|
# Close transport
|
||||||
|
if self._transport:
|
||||||
|
self._transport.close()
|
||||||
|
|
||||||
|
# User methods
|
||||||
|
|
||||||
|
def send(self, data, addr):
|
||||||
|
"""Send a datagram to the given address."""
|
||||||
|
if self._closed:
|
||||||
|
raise IOError("Enpoint is closed")
|
||||||
|
self._transport.sendto(data, addr)
|
||||||
|
|
||||||
|
async def receive(self):
|
||||||
|
"""Wait for an incoming datagram and return it with
|
||||||
|
the corresponding address.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
|
if self._queue.empty() and self._closed:
|
||||||
|
raise IOError("Enpoint is closed")
|
||||||
|
data, addr = await self._queue.get()
|
||||||
|
if data is None:
|
||||||
|
raise IOError("Enpoint is closed")
|
||||||
|
return data, addr
|
||||||
|
|
||||||
|
def abort(self):
|
||||||
|
"""Close the transport immediately."""
|
||||||
|
if self._closed:
|
||||||
|
raise IOError("Enpoint is closed")
|
||||||
|
self._transport.abort()
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
async def drain(self):
|
||||||
|
"""Drain the transport buffer below the low-water mark."""
|
||||||
|
if self._write_ready_future is not None:
|
||||||
|
await self._write_ready_future
|
||||||
|
|
||||||
|
# Properties
|
||||||
|
|
||||||
|
@property
|
||||||
|
def address(self):
|
||||||
|
"""The endpoint address as a (host, port) tuple."""
|
||||||
|
return self._transport.get_extra_info("socket").getsockname()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self):
|
||||||
|
"""Indicates whether the endpoint is closed or not."""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEndpoint(Endpoint):
|
||||||
|
"""High-level interface for UDP local enpoints.
|
||||||
|
|
||||||
|
It is initialized with an optional queue size for the incoming datagrams.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteEndpoint(Endpoint):
|
||||||
|
"""High-level interface for UDP remote enpoints.
|
||||||
|
|
||||||
|
It is initialized with an optional queue size for the incoming datagrams.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
"""Send a datagram to the remote host."""
|
||||||
|
super().send(data, None)
|
||||||
|
|
||||||
|
async def receive(self):
|
||||||
|
""" Wait for an incoming datagram from the remote host.
|
||||||
|
|
||||||
|
This method is a coroutine.
|
||||||
|
"""
|
||||||
|
data, addr = await super().receive()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# High-level coroutines
|
||||||
|
|
||||||
|
async def open_datagram_endpoint(endpoint_factory=Endpoint, **kwargs):
|
||||||
|
"""Open and return a datagram endpoint.
|
||||||
|
|
||||||
|
The default endpoint factory is the Endpoint class.
|
||||||
|
The endpoint can be made local or remote using the remote argument.
|
||||||
|
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
endpoint = endpoint_factory()
|
||||||
|
kwargs['protocol_factory'] = lambda: DatagramEndpointProtocol(endpoint)
|
||||||
|
await loop.create_datagram_endpoint(**kwargs)
|
||||||
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
|
async def open_local_endpoint(host='0.0.0.0', port=0, *, queue_size=None, **kwargs):
|
||||||
|
"""Open and return a local datagram endpoint.
|
||||||
|
|
||||||
|
An optional queue size arguement can be provided.
|
||||||
|
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||||
|
"""
|
||||||
|
return await open_datagram_endpoint(
|
||||||
|
local_addr=(host, port),
|
||||||
|
endpoint_factory=lambda: LocalEndpoint(queue_size),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def open_remote_endpoint(host, port, queue_size=None, **kwargs):
|
||||||
|
"""Open and return a remote datagram endpoint.
|
||||||
|
|
||||||
|
An optional queue size arguement can be provided.
|
||||||
|
Extra keyword arguments are forwarded to `loop.create_datagram_endpoint`.
|
||||||
|
"""
|
||||||
|
return await open_datagram_endpoint(
|
||||||
|
remote_addr=(host, port),
|
||||||
|
endpoint_factory=lambda: RemoteEndpoint(queue_size),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pytest
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_standard_behavior():
|
||||||
|
local = await open_local_endpoint()
|
||||||
|
remote = await open_remote_endpoint(*local.address)
|
||||||
|
|
||||||
|
remote.send(b'Hey Hey')
|
||||||
|
data, address = await local.receive()
|
||||||
|
|
||||||
|
assert data == b'Hey Hey'
|
||||||
|
assert address == remote.address
|
||||||
|
|
||||||
|
local.send(b'My My', address)
|
||||||
|
data = await remote.receive()
|
||||||
|
assert data == b'My My'
|
||||||
|
|
||||||
|
local.abort()
|
||||||
|
assert local.closed
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
await asyncio.sleep(1e-3)
|
||||||
|
remote.send(b'U there?')
|
||||||
|
await asyncio.sleep(1e-3)
|
||||||
|
|
||||||
|
remote.abort()
|
||||||
|
assert remote.closed
|
||||||
|
|
||||||
|
|
||||||
|
async def test_closed_endpoint():
|
||||||
|
local = await open_local_endpoint()
|
||||||
|
future = asyncio.ensure_future(local.receive())
|
||||||
|
local.abort()
|
||||||
|
assert local.closed
|
||||||
|
|
||||||
|
with pytest.raises(IOError):
|
||||||
|
await future
|
||||||
|
|
||||||
|
with pytest.raises(IOError):
|
||||||
|
await local.receive()
|
||||||
|
|
||||||
|
with pytest.raises(IOError):
|
||||||
|
await local.send(b'test', ('localhost', 8888))
|
||||||
|
|
||||||
|
with pytest.raises(IOError):
|
||||||
|
local.abort()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_queue_size():
|
||||||
|
local = await open_local_endpoint(queue_size=1)
|
||||||
|
remote = await open_remote_endpoint(*local.address)
|
||||||
|
|
||||||
|
remote.send(b'1')
|
||||||
|
remote.send(b'2')
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
await asyncio.sleep(1e-3)
|
||||||
|
assert await local.receive() == (b'1', remote.address)
|
||||||
|
remote.send(b'3')
|
||||||
|
assert await local.receive() == (b'3', remote.address)
|
||||||
|
|
||||||
|
remote.send(b'4')
|
||||||
|
await asyncio.sleep(1e-3)
|
||||||
|
local.abort()
|
||||||
|
assert local.closed
|
||||||
|
assert await local.receive() == (b'4', remote.address)
|
||||||
|
|
||||||
|
remote.abort()
|
||||||
|
assert remote.closed
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flow_control():
|
||||||
|
m = n = 1024
|
||||||
|
remote = await open_remote_endpoint("8.8.8.8", 12345)
|
||||||
|
|
||||||
|
for _ in range(m):
|
||||||
|
remote.send(b"a" * n)
|
||||||
|
|
||||||
|
await remote.drain()
|
||||||
|
|
||||||
|
for _ in range(m):
|
||||||
|
remote.send(b"a" * n)
|
||||||
|
|
||||||
|
remote.abort()
|
||||||
|
await remote.drain()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': # pragma: no cover
|
||||||
|
pytest.main([__file__])
|
Loading…
Reference in a new issue