import logging
import json
from typing import List, Dict, Any

from .consts import ADAPTERS
from .worker import ServiceWorker
from .blueprints.streaming import OP

log = logging.getLogger(__name__)

# Maps a column name listed in an adapter's ``spec['db']`` to the SQL
# column definition used when creating that service's table.
_COLUMNS = {
    'timestamp': 'timestamp bigint',
    'status': 'status bool',
    'latency': 'latency bigint',
}

# Static schema for incident bookkeeping; every statement is idempotent.
# NOTE(review): ``PRIMARY KEY (parent_id)`` on incident_stages permits only
# one stage row per incident -- confirm this is intended before relying on
# multiple stages per incident.
_INCIDENT_SCHEMA = """
CREATE TABLE IF NOT EXISTS incidents (
    id bigint PRIMARY KEY,
    incident_type text,
    title text,
    content text,
    ongoing bool,
    start_timestamp bigint,
    end_timestamp bigint
);

CREATE TABLE IF NOT EXISTS incident_stages (
    parent_id bigint REFERENCES incidents (id) NOT NULL,
    timestamp bigint,
    title text,
    content text,
    PRIMARY KEY (parent_id)
);
"""


class ServiceManager:
    """Own the service workers and the websocket publish/subscribe state.

    On construction, one ``ServiceWorker`` is created per entry of
    ``app.cfg.SERVICES``, the matching database tables are created, and the
    ``status:<name>`` / ``latency:<name>`` channels are registered for the
    columns each worker's adapter declares.
    """

    def __init__(self, app):
        """Bind to *app* (reads ``cfg``, ``conn`` and ``loop``) and start."""
        self.app = app
        self.cfg = app.cfg
        self.conn = app.conn
        self.loop = app.loop

        # service name -> ServiceWorker
        self.workers = {}
        # service name -> last known state (``None`` until something sets it)
        self.state = {}
        # channel name -> list of subscribed websocket client ids
        self.subscribers = {}
        # websocket client id -> websocket object
        self._websockets = {}

        self._start()

    def _make_db_table(self, name: str, service: dict):
        """Create the table for service *name* and the incident tables.

        The table's columns are the ``_COLUMNS`` definitions for the names
        listed in the adapter's ``spec['db']``.

        Raises:
            KeyError: if ``service['adapter']`` is not in ``ADAPTERS``.
        """
        adapter = ADAPTERS[service['adapter']]
        columnstr = ',\n'.join(map(_COLUMNS.get, adapter.spec['db']))

        log.info('Making table for %s', name)

        # *name* is interpolated verbatim: it comes from the configured
        # service names and must therefore be a valid SQL identifier.
        self.conn.executescript(f"""
        CREATE TABLE IF NOT EXISTS {name} (
            {columnstr}
        );
        """)

        self.conn.executescript(_INCIDENT_SCHEMA)

    def _check(self, columns: tuple, field: str, worker_name: str):
        """Register channel ``<field>:<worker_name>`` if *field* is a column.

        An already-registered channel is left untouched so existing
        subscribers are kept.
        """
        chan_name = f'{field}:{worker_name}'

        if field in columns and chan_name not in self.subscribers:
            self.subscribers[chan_name] = []
            log.info('Created channel %s', chan_name)

    def _create_channels(self, worker):
        """Register the status and latency channels *worker* supports."""
        columns = worker.adapter.spec['db']

        for field in ('status', 'latency'):
            self._check(columns, field, worker.name)

    def _start(self):
        """Create tables, workers and channels for every configured service."""
        for name, service in self.cfg.SERVICES.items():
            self._make_db_table(name, service)

            # spawn a service worker
            serv_worker = ServiceWorker(self, name, service)

            self.workers[name] = serv_worker
            self.state[name] = None

            self._create_channels(serv_worker)

    def close(self):
        """Stop every worker."""
        for worker in self.workers.values():
            worker.stop()

    def subscribe(self, channels: List[str], websocket) -> List[str]:
        """Subscribe *websocket* to a list of channels.

        Unknown channels are skipped. Subscribing is idempotent: a client
        that is already subscribed to a channel is not added a second time,
        so it never receives duplicate messages from ``publish``.

        Returns:
            The requested channels that exist (and to which the websocket
            is therefore subscribed).
        """
        wid = websocket.client_id
        self._websockets[wid] = websocket

        subscribed = []
        for chan in channels:
            subs = self.subscribers.get(chan)
            if subs is None:
                # Unknown channel: nothing to subscribe to.
                continue

            if wid not in subs:
                subs.append(wid)
                log.info('Subscribed %s to %s', wid, chan)

            subscribed.append(chan)

        return subscribed

    def unsubscribe(self, channels: List[str], websocket) -> List[str]:
        """Unsubscribe *websocket* from a list of channels.

        Unknown channels, and channels the websocket was not subscribed to,
        are skipped.

        Returns:
            The channels the websocket was actually removed from.
        """
        wid = websocket.client_id
        unsub = []

        for chan in channels:
            try:
                self.subscribers[chan].remove(wid)
            except (KeyError, ValueError):
                # Unknown channel, or not subscribed to it.
                continue

            unsub.append(chan)
            log.info('Unsubscribed %s from %s', wid, chan)

        return unsub

    def unsub_all(self, websocket):
        """Unsubscribe a websocket from all known channels.

        Also forgets the websocket object itself.

        Returns:
            The list of channels the websocket was removed from.
        """
        wid = websocket.client_id
        unsub = []

        for chan, subs in self.subscribers.items():
            try:
                subs.remove(wid)
            except ValueError:
                continue
            unsub.append(chan)

        log.info('unsubscribed %s from %s', wid, unsub)

        self._websockets.pop(wid, None)
        return unsub

    def _raw_send(self, websocket, channel: str, data: Any):
        """Schedule sending *data* for *channel* to one websocket.

        The payload is a JSON object with keys ``op`` (``OP.DATA``), ``c``
        (the channel) and ``d`` (the data).

        Returns:
            The task created on the app's loop, or ``None`` when
            *websocket* is ``None``.
        """
        if websocket is None:
            return None

        payload = json.dumps({
            'op': OP.DATA,
            'c': channel,
            'd': data,
        })

        return self.app.loop.create_task(websocket.send(payload))

    def publish(self, channel: str, data: Any):
        """Send *data* to every websocket subscribed to *channel*.

        Subscriber ids without a registered websocket are skipped, so the
        returned list contains only real tasks.

        Returns:
            The list of scheduled send tasks.

        Raises:
            KeyError: if *channel* is not a known channel.
        """
        tasks = []

        for wid in self.subscribers[channel]:
            task = self._raw_send(self._websockets.get(wid), channel, data)
            if task is not None:
                tasks.append(task)

        return tasks