diff --git a/portalclientlib/__init__.py b/portalclientlib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/portalclientlib/client.py b/portalclientlib/client.py new file mode 100644 index 0000000..2f1c4d4 --- /dev/null +++ b/portalclientlib/client.py @@ -0,0 +1,122 @@ +""" +Handles "clients" in IPtables for captive portal. +""" + +import ipaddress +from uuid import uuid4 +from datetime import datetime, timedelta + +from portalclientlib import errors +from portalclientlib.helpers import run_ipset + + +class Client(object): + + def __init__(self, **kw): + # Required parameters + self.storage = kw.pop('storage') + self.ipset_name = kw.pop('ipset_name') + self.use_sudo = kw.pop('use_sudo', False) + + self.ip_address = kw.pop('ip_address', '127.0.0.1') + self.protocol = kw.pop('protocol', 'tcp') + + self.new = False + + # First try to get an existing client by ID + self.client_id = kw.pop('client_id', None) + if self.client_id: + client_data = self.storage.get_client_by_id(self.client_id) + + # If ID is specified then we raise exception if client isn't + # found. + if client_data is None: + raise errors.StorageNotFound('Client not found') + else: + client_data = self.storage.get_client( + self._ip_address, + self.protocol + ) + + # Init iptables + #self.table = iptc.Table(iptc.Table.MANGLE) + #self.chain = iptc.Chain(self.table, self._chain) + + if client_data: + self.load_client(client_data) + else: + self.client_id = str(uuid4()) + self.created = datetime.now() + self.enabled = False + self.last_packets = 0 + self.last_activity = None + self.expires = datetime.now() + timedelta(days=1) + self.new = True + + + def load_client(self, data): + self.client_id = data.get('client_id') + self.created = data.get('created') + self.ip_address = data.get('ip_address') + self.protocol = data.get('protocol') + self.enabled = data.get('enabled') + self.last_packets = data.get('last_packets') + self.last_activity = data.get('last_activity') + self.expires = data.get('expires') + + + def commit(self): + self.commit_client() + + if self.enabled: + self.commit_rule() + else: + self.remove_rule() + + + def commit_client(self): + self.storage.write_client( + self + ) + + + def delete(self): + self.remove_rule() + self.storage.remove_client(self) + + + def remove_rule(self): + run_ipset( + 'del', + '-exist', + self.ipset_name, + self.ip_address, + use_sudo=self.use_sudo + ) + + + def commit_rule(self): + run_ipset( + 'add', + '-exist', + self.ipset_name, + self.ip_address, + use_sudo=self.use_sudo + ) + + + @property + def ip_address(self): + return str(self._ip_address.ip) + + + @ip_address.setter + def ip_address(self, value): + if isinstance(value, str): + self._ip_address = ipaddress.IPv4Interface(value) + elif isinstance(value, ipaddress.IPv4Interface): + self._ip_address = value + else: + raise ValueError('Cannot set invalid value') + + diff --git a/portalclientlib/errors.py b/portalclientlib/errors.py new file mode 100644 index 0000000..2609dd7 --- /dev/null +++ b/portalclientlib/errors.py @@ -0,0 +1,8 @@ +class StorageNotFound(Exception): + pass + +class IPTCRuleExists(Exception): + pass + +class IPTCRuleNotFound(Exception): + pass \ No newline at end of file diff --git a/portalclientlib/helpers.py b/portalclientlib/helpers.py new file mode 100644 index 0000000..4f5af3d --- /dev/null +++ b/portalclientlib/helpers.py @@ -0,0 +1,25 @@ + +import subprocess +import shlex + +def run_ipset(command, *args, **kw): + use_sudo = kw.get('use_sudo', True) + timeout = kw.get('timeout', 2) + + if use_sudo: + ipset_cmd = 'sudo ipset' + else: + ipset_cmd = 'ipset' + + full_command = '{ipset} {command} {args}'.format( + ipset=ipset_cmd, + command=command, + args=' '.join(args) + ) + + output = subprocess.check_output( + shlex.split(full_command), + timeout=timeout + ) + + return output \ No newline at end of file diff --git a/portalclientlib/storage.py b/portalclientlib/storage.py new file mode 100644 index 0000000..3132c21 --- /dev/null +++ b/portalclientlib/storage.py @@ -0,0 +1,121 @@ +""" +Database storage backends for client.py. +""" + +import json +from datetime import datetime + +import psycopg2 +from psycopg2.extras import DictCursor, register_ipaddress, Inet +from redis import Redis + +from portalclientlib.client import Client + + +class StoragePostgres(object): + """ + This requires python 3 for inet data type. + """ + + def __init__(self, **kw): + config = kw.pop('config') + + self.conn = psycopg2.connect( + host=config.get('postgres', 'hostname'), + user=config.get('postgres', 'username'), + password=config.get('postgres', 'password'), + dbname=config.get('postgres', 'database'), + port=config.getint('postgres', 'port'), + sslmode='disable', + cursor_factory=DictCursor + ) + self.cur = self.conn.cursor() + register_ipaddress() + + + def client_ids(self): + self.cur.execute( + 'select client_id from authenticated_clients' + ) + return self.cur.fetchall() + + + def get_client_by_id(self, client_id): + self.cur.execute( + 'select * from authenticated_clients where client_id=%s', + (client_id,) + ) + return self.cur.fetchone() + + + def get_client(self, ip_address, protocol): + """ + Expects an ipaddress.IPv4Interface as ip_address argument. + """ + + self.cur.execute( + 'select * from authenticated_clients where ip_address=%s and protocol=%s', + (ip_address, protocol, ) + ) + return self.cur.fetchone() + + + def write_client(self, client): + query = ( + 'insert into authenticated_clients (client_id, created, ip_address, protocol, ' + 'enabled, last_packets, last_activity, expires) values ' + '(%s, %s, %s, %s, %s, %s, %s, %s) on conflict (client_id, ' + 'ip_address, protocol) do update set (enabled, last_packets, ' + 'last_activity, expires) = (EXCLUDED.enabled, EXCLUDED.last_packets, ' + 'EXCLUDED.last_activity, EXCLUDED.expires)' + ) + self.cur.execute( + query, + ( + client.client_id, + client.created, + client.ip_address, + client.protocol, + client.enabled, + client.last_packets, + client.last_activity, + client.expires + ) + ) + self.conn.commit() + + + def remove_client(self, client): + query = 'delete from authenticated_clients where client_id=%s' + self.cur.execute(query, (client.client_id,)) + self.conn.commit() + + +class DateTimeEncoder(json.JSONEncoder): + """ + json.JSONEncoder sub-class that converts all datetime objects to + epoch timestamp integer values. + """ + def default(self, o): + if isinstance(o, datetime): + return int(o.strftime('%s')) + return json.JSONEncoder.default(self, o) + + +class StorageRedis(object): + """ + Note: Abandoned this storage backend for Postgres. + """ + + def __init__(self, **kw): + config = kw.pop('config') + + self.r = Redis( + host=config.get('redis', 'hostname'), + port=config.getint('redis', 'port'), + db=config.getint('redis', 'db') + ) + + + def add_client(self, client_id, **kw): + raise NotImplemented