diff --git a/tools/add_client.py b/tools/add_client.py index 4f9620c..131d54a 100644 --- a/tools/add_client.py +++ b/tools/add_client.py @@ -2,11 +2,13 @@ # Python helper tool to add IPtables rule using the iptc library. This must # of course run as root for iptc to work. +from sys import exit from argparse import ArgumentParser, FileType from pprint import pprint as pp from configparser import RawConfigParser -from storage import StorageRedis +import errors +from storage import StoragePostgres from client import Client parser = ArgumentParser() @@ -35,10 +37,16 @@ args = parser.parse_args() config = RawConfigParser() config.readfp(args.config) -sr = StorageRedis(config=config) -client = Client( - storage=sr, - client_id=args.src_ip, - protocol=args.protocol, - chain=config.get('iptables', 'chain') -) +sr = StoragePostgres(config=config) +try: + client = Client( + storage=sr, + ip_address=args.src_ip, + protocol=args.protocol, + chain=config.get('iptables', 'chain') + ) +except errors.StorageNotFound: + print('Could not find client') + exit(1) + +client.commit() diff --git a/tools/captiveportal.pgsql b/tools/captiveportal.pgsql index ec1710a..9cc4714 100644 --- a/tools/captiveportal.pgsql +++ b/tools/captiveportal.pgsql @@ -4,10 +4,11 @@ create table if not exists client ( client_id uuid NOT NULL primary key unique, created timestamp NOT NULL, ip_address inet NOT NULL, - "protocol" inet_protocol NOT NULL, + protocol inet_protocol NOT NULL, enabled boolean NOT NULL, last_packets bigint default 0, - last_activity timestamp + last_activity timestamp, + primary key (client_id, ip_address, protocol) ); create index if not exists client_ip_address_index on client (ip_address); diff --git a/tools/client.py b/tools/client.py index 1073144..86077ee 100644 --- a/tools/client.py +++ b/tools/client.py @@ -2,36 +2,59 @@ Handles "clients" in IPtables for captive portal. """ +from uuid import uuid4 from datetime import datetime #import iptc +import errors + class Client(object): def __init__(self, **kw): + # Required parameters self.storage = kw.pop('storage') - self.client_id = kw.pop('client_id') - self.protocol = kw.pop('protocol') self.chain = kw.pop('chain') - # Default values for client data - self.data = { - 'client_id': self.client_id, - 'protocol': self.protocol, - 'created': datetime.now(), - 'bytes': 0, - 'packets': 0, - 'last_activity': None - } + # First try to get an existing client by ID + self.client_id = kw.pop('client_id', None) + if self.client_id: + client_data = self.storage.get_client_by_id(self.client_id) - self.client_exists = False + # If ID is specified then we raise exception if client isn't + # found. + if client_data is None: + raise StorageNotFound('Client not found') + + # Next try to get an existing client by IP and protocol + self.ip_address = kw.pop('ip_address') + self.protocol = kw.pop('protocol') + + if self.ip_address and self.protocol: + client_data = self.storage.get_client( + self.ip_address, + self.protocol + ) - # Attempt to fetch client from storage - client_data = self.storage.get_client(self.client_id) if client_data: - self.data = client_data - self.exists = True + self.load_client(client_data) + else: + self.client_id = str(uuid4()) + self.created = datetime.now() + self.enabled = False + self.last_packets = 0 + self.last_activity = None + + + def load_client(self, data): + self.client_id = data.get('client_id') + self.created = data.get('created') + self.ip_address = data.get('ip_address') + self.protocol = data.get('protocol') + self.enabled = data.get('enabled') + self.last_packets = data.get('last_packets') + self.last_activity = data.get('last_activity') def commit(self): @@ -40,16 +63,19 @@ class Client(object): def commit_client(self): - if self.exists: - self.storage.update_client( - self.client_id, - **self.data - ) - else: - self.storage.add_client( - self.client_id, - **self.data - ) + self.storage.write_client( + self + ) + + + def delete(self): + #self.remove_rule() + self.storage.remove_client(self) + + + def find_rule(self): + raise NotImplemented + def commit_rule(self): table = iptc.Table(iptc.Table.MANGLE) @@ -58,12 +84,13 @@ class Client(object): # Check if rule exists for rule in chain.rules: src_ip = rule.src - if src_ip.startswith(self.client_id) and rule.protocol == self.protocol: + if src_ip.startswith(self.ip_address) and rule.protocol == self.protocol: print('Rule exists') break else: rule = iptc.Rule() - rule.src = self.client_id + rule.src = self.ip_address rule.protocol = self.protocol rule.target = iptc.Target(rule, 'RETURN') chain.insert_rule(rule) + diff --git a/tools/errors.py b/tools/errors.py new file mode 100644 index 0000000..5b0d732 --- /dev/null +++ b/tools/errors.py @@ -0,0 +1,2 @@ +class StorageNotFound(Exception): + pass diff --git a/tools/remove_client.py b/tools/remove_client.py new file mode 100644 index 0000000..9f1e815 --- /dev/null +++ b/tools/remove_client.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +from sys import exit +from argparse import ArgumentParser, FileType +from pprint import pprint as pp +from configparser import RawConfigParser + +import errors +from storage import StoragePostgres +from client import Client + +parser = ArgumentParser() + +parser.add_argument( + '--protocol', + required=True, + choices=['tcp', 'udp'], + help='Protocol for client' +) + +parser.add_argument( + '--config', + type=FileType('r'), + required=True, + help='Configuration file' +) + +parser.add_argument( + 'src_ip', + help='Client source IP to add' +) + +args = parser.parse_args() + +config = RawConfigParser() +config.readfp(args.config) + +sr = StoragePostgres(config=config) +try: + client = Client( + storage=sr, + ip_address=args.src_ip, + protocol=args.protocol, + chain=config.get('iptables', 'chain') + ) +except errors.StorageNotFound: + print('Could not find client') + exit(1) + +client.delete() diff --git a/tools/storage.py b/tools/storage.py index 1acfa0c..7cb975f 100644 --- a/tools/storage.py +++ b/tools/storage.py @@ -6,10 +6,16 @@ import json from datetime import datetime import psycopg2 +from psycopg2.extras import DictCursor, register_ipaddress, Inet from redis import Redis +from client import Client + class StoragePostgres(object): + """ + This requires python 3 for inet data type. + """ def __init__(self, **kw): config = kw.pop('config') @@ -19,8 +25,58 @@ class StoragePostgres(object): user=config.get('postgres', 'username'), password=config.get('postgres', 'password'), dbname=config.get('postgres', 'database'), - port=config.getint('postgres', 'port') + port=config.getint('postgres', 'port'), + sslmode='disable', + cursor_factory=DictCursor ) + self.cur = self.conn.cursor() + register_ipaddress() + + + def get_client_by_id(self, client_id): + self.cur.execute( + 'select * from client where client_id=%s', + (client_id,) + ) + return self.cur.fetchone() + + + def get_client(self, ip_address, protocol): + self.cur.execute( + 'select * from client where ip_address=%s and protocol=%s', + (Inet(ip_address), protocol, ) + ) + return self.cur.fetchone() + + + def write_client(self, client): + query = ( + 'insert into client (client_id, created, ip_address, protocol, ' + 'enabled, last_packets, last_activity) values (%s, %s, %s, %s, ' + '%s, %s, %s) on conflict (client_id, ip_address, protocol) do ' + 'update set (enabled, last_packets, last_activity) = ' + '(EXCLUDED.enabled, EXCLUDED.last_packets, ' + 'EXCLUDED.last_activity)' + ) + self.cur.execute( + query, + ( + client.client_id, + client.created, + client.ip_address, + client.protocol, + client.enabled, + client.last_packets, + client.last_activity + ) + ) + self.conn.commit() + + + def remove_client(self, client): + query = 'delete from client where client_id=%s' + self.cur.execute(query, (client.client_id,)) + self.conn.commit() class DateTimeEncoder(json.JSONEncoder):