diff --git a/memed.py b/memed.py index b64eed4..18d3256 100644 --- a/memed.py +++ b/memed.py @@ -17,6 +17,8 @@ from bot import schedule_bot logging.basicConfig(level=logging.DEBUG) log = logging.getLogger(__name__) + +# globals are bad, but who cares db = None bot = None @@ -71,28 +73,31 @@ class MemeClient: self.loop = asyncio.get_event_loop() async def read_msg(self) -> str: + """Read one message from the socket.""" header = await self.reader.read(8) log.debug('[recv] %r', header) - length, op = struct.unpack('Ii', header) + length, opcode = struct.unpack('Ii', header) data = await self.reader.read(length) data = data.decode() - log.debug('[recv] %d %d %s', length, op, data) - return op, data + log.debug('[recv] %d %d %s', length, opcode, data) + return opcode, data async def read_payload(self) -> dict: """Read a payload from the socket.""" - op, message = await self.read_msg() - if op > 10: - return op, json.loads(message) - else: - return op, message + opcode, message = await self.read_msg() - async def send_msg(self, op: int, data: str) -> 'None': + # NOTE: this is kinda unused + if opcode > 10: + return opcode, json.loads(message) + + return opcode, message + + async def send_msg(self, op: int, data: str): """Send a message. - + This does not wait for the receiving end - to properly finish their buffers. + to properly flush their buffers. Arguments --------- @@ -101,57 +106,85 @@ class MemeClient: data: str Message to be sent with the op code. """ + # create header, pack message, yadda yadda header = struct.pack('Ii', len(data), op).decode() msg = f'{header}{data}'.encode() log.debug('[send] %d, %s -> %r', op, data, msg) self.writer.write(msg) - # Utils can close this early + # clients can close this early # and make writer.drain kill itself # so we wrap on a task which is isolated asyncio.get_event_loop().create_task(wrap(self.writer.drain)) async def process(self, op: int, message: str) -> 'None': """Process a message given through the socket""" - if op == 1: - uid, cwd, command = parse_logstr(message) - log.info('[process] Logging command uid=%d cwd=%r cmd=%r', - uid, cwd, command) + handler = getattr(self, f'handle_{op}', None) - await db.execute(""" - INSERT INTO logs (uid, cwd, cmd) VALUES ($1, $2, $3) - """, uid, cwd, command) - elif op == 2: - # Handle rsudo without waiting - if not bot: - return await self.send_msg(1, 'no bot up') + if not handler: + # Ignore unknown OP codes. + return - rsudo = bot.get_cog('Rsudo') - if not rsudo: - return await self.send_msg(1, 'no rsudo cog') + await handler(message) - log.info('[process] got rsudo! %r', message) + async def handle_1(self, message: str): + uid, cwd, command = parse_logstr(message) + log.info('[process] Logging command ' + f'uid={uid} cwd={cwd} cmd={command}') - # this doesnt wait for the thing - self.loop.create_task(rsudo.request(message)) - return await self.send_msg(1, "true") - elif op == 3: - # handle rsudo, steroid version - if not bot: - return await self.send_msg(1, 'no bot') + await db.execute(""" + INSERT INTO logs (uid, cwd, cmd) VALUES ($1, $2, $3) + """, uid, cwd, command) - rsudo = bot.get_cog('Rsudo') - if not rsudo: - return await self.send_msg(1, 'no rsudo cog') + async def handle_2(self, message: str): + """Handle an OP 2 packet. - log.info('[process - wait] %r', message) - # this does - ok = await rsudo.request(message, True) - ok_str = 'true' if ok else 'false' - return await self.send_msg(1, ok_str) + This is the RSudo handling, but without waiting for + the approval of a mod (the 'nowait' behavior). + """ + if not bot: + return await self.send_msg(1, 'no bot up') + + rsudo = bot.get_cog('Rsudo') + if not rsudo: + return await self.send_msg(1, 'no rsudo cog') + + log.info('[rsudo:nowait] %r', message) + + # this doesnt wait for the thing + self.loop.create_task(rsudo.request(message)) + return await self.send_msg(1, "true") + + async def handle_3(self, message: str): + """Handle an OP 3 packet. + + This follows the same logic as OP 2, however + the client gets a response back, a string ('true' or 'false'), + depending on the request. + + NO COMMANDS WILL BE EXECUTED SERVER-SIDE FROM THIS OP. + """ + if not bot: + return await self.send_msg(1, 'no bot') + + rsudo = bot.get_cog('Rsudo') + if not rsudo: + return await self.send_msg(1, 'no rsudo cog') + + log.info('[rsudo:wait] %r', message) + success = await rsudo.request(message, True) + + # the original idea was sending an int back + # but we had no idea how to do that since + # we were already working with strings all the time + # so yeah, this is *another* string + ok_str = 'true' if success else 'false' + + return await self.send_msg(1, ok_str) async def client_loop(self): + """Enter a loop waiting for messages from the client.""" try: while True: op, message = await self.read_msg() @@ -164,34 +197,36 @@ class MemeClient: async def handle_client(reader, writer): - """Handle clients""" + """Handle clients coming in the socket, spawn a loop for them.""" client = MemeClient(reader, writer) await client.send_msg(0, 'hello') await client.client_loop() -if __name__ == '__main__': +def main(): loop = asyncio.get_event_loop() coro = asyncio.start_unix_server(handle_client, sys.argv[1], loop=loop) - db = loop.create_task(asyncpg.create_pool(**config.db)) + pool = loop.create_task(asyncpg.create_pool(**config.db)) server = loop.run_until_complete(coro) if config.bot_token: - bot = schedule_bot(loop, config, db) + bot = schedule_bot(loop, config, pool) if bot: loop.create_task(bot.start(config.bot_token)) - log.info(f'Serving on {server.sockets[0].getsockname()}') + log.info(f'memed serving at {server.sockets[0].getsockname()}') try: loop.run_forever() - except KeyboardInterrupt: - pass + finally: + log.info('Closing server') + server.close() + loop.run_until_complete(server.wait_closed()) + loop.close() - log.info('Closing server') - server.close() - loop.run_until_complete(server.wait_closed()) - loop.close() + +if __name__ == '__main__': + main()