client management tools

This commit is contained in:
Stefan Midjich 2017-03-06 16:03:57 +01:00
parent dbbdb29601
commit da34f9017c
6 changed files with 183 additions and 39 deletions

View file

@ -2,11 +2,13 @@
# Python helper tool to add IPtables rule using the iptc library. This must # Python helper tool to add IPtables rule using the iptc library. This must
# of course run as root for iptc to work. # of course run as root for iptc to work.
from sys import exit
from argparse import ArgumentParser, FileType from argparse import ArgumentParser, FileType
from pprint import pprint as pp from pprint import pprint as pp
from configparser import RawConfigParser from configparser import RawConfigParser
from storage import StorageRedis import errors
from storage import StoragePostgres
from client import Client from client import Client
parser = ArgumentParser() parser = ArgumentParser()
@ -35,10 +37,16 @@ args = parser.parse_args()
config = RawConfigParser() config = RawConfigParser()
config.readfp(args.config) config.readfp(args.config)
sr = StorageRedis(config=config) sr = StoragePostgres(config=config)
client = Client( try:
storage=sr, client = Client(
client_id=args.src_ip, storage=sr,
protocol=args.protocol, ip_address=args.src_ip,
chain=config.get('iptables', 'chain') protocol=args.protocol,
) chain=config.get('iptables', 'chain')
)
except errors.StorageNotFound:
print('Could not find client')
exit(1)
client.commit()

View file

@ -4,10 +4,11 @@ create table if not exists client (
client_id uuid NOT NULL primary key unique, client_id uuid NOT NULL primary key unique,
created timestamp NOT NULL, created timestamp NOT NULL,
ip_address inet NOT NULL, ip_address inet NOT NULL,
"protocol" inet_protocol NOT NULL, protocol inet_protocol NOT NULL,
enabled boolean NOT NULL, enabled boolean NOT NULL,
last_packets bigint default 0, last_packets bigint default 0,
last_activity timestamp last_activity timestamp,
primary key (client_id, ip_address, protocol)
); );
create index if not exists client_ip_address_index on client (ip_address); create index if not exists client_ip_address_index on client (ip_address);

View file

@ -2,36 +2,59 @@
Handles "clients" in IPtables for captive portal. Handles "clients" in IPtables for captive portal.
""" """
from uuid import uuid4
from datetime import datetime from datetime import datetime
#import iptc #import iptc
import errors
class Client(object): class Client(object):
def __init__(self, **kw): def __init__(self, **kw):
# Required parameters
self.storage = kw.pop('storage') self.storage = kw.pop('storage')
self.client_id = kw.pop('client_id')
self.protocol = kw.pop('protocol')
self.chain = kw.pop('chain') self.chain = kw.pop('chain')
# Default values for client data # First try to get an existing client by ID
self.data = { self.client_id = kw.pop('client_id', None)
'client_id': self.client_id, if self.client_id:
'protocol': self.protocol, client_data = self.storage.get_client_by_id(self.client_id)
'created': datetime.now(),
'bytes': 0,
'packets': 0,
'last_activity': None
}
self.client_exists = False # If ID is specified then we raise exception if client isn't
# found.
if client_data is None:
raise StorageNotFound('Client not found')
# Next try to get an existing client by IP and protocol
self.ip_address = kw.pop('ip_address')
self.protocol = kw.pop('protocol')
if self.ip_address and self.protocol:
client_data = self.storage.get_client(
self.ip_address,
self.protocol
)
# Attempt to fetch client from storage
client_data = self.storage.get_client(self.client_id)
if client_data: if client_data:
self.data = client_data self.load_client(client_data)
self.exists = True else:
self.client_id = str(uuid4())
self.created = datetime.now()
self.enabled = False
self.last_packets = 0
self.last_activity = None
def load_client(self, data):
self.client_id = data.get('client_id')
self.created = data.get('created')
self.ip_address = data.get('ip_address')
self.protocol = data.get('protocol')
self.enabled = data.get('enabled')
self.last_packets = data.get('last_packets')
self.last_activity = data.get('last_activity')
def commit(self): def commit(self):
@ -40,16 +63,19 @@ class Client(object):
def commit_client(self): def commit_client(self):
if self.exists: self.storage.write_client(
self.storage.update_client( self
self.client_id, )
**self.data
)
else: def delete(self):
self.storage.add_client( #self.remove_rule()
self.client_id, self.storage.remove_client(self)
**self.data
)
def find_rule(self):
raise NotImplemented
def commit_rule(self): def commit_rule(self):
table = iptc.Table(iptc.Table.MANGLE) table = iptc.Table(iptc.Table.MANGLE)
@ -58,12 +84,13 @@ class Client(object):
# Check if rule exists # Check if rule exists
for rule in chain.rules: for rule in chain.rules:
src_ip = rule.src src_ip = rule.src
if src_ip.startswith(self.client_id) and rule.protocol == self.protocol: if src_ip.startswith(self.ip_address) and rule.protocol == self.protocol:
print('Rule exists') print('Rule exists')
break break
else: else:
rule = iptc.Rule() rule = iptc.Rule()
rule.src = self.client_id rule.src = self.ip_address
rule.protocol = self.protocol rule.protocol = self.protocol
rule.target = iptc.Target(rule, 'RETURN') rule.target = iptc.Target(rule, 'RETURN')
chain.insert_rule(rule) chain.insert_rule(rule)

2
tools/errors.py Normal file
View file

@ -0,0 +1,2 @@
class StorageNotFound(Exception):
pass

50
tools/remove_client.py Normal file
View file

@ -0,0 +1,50 @@
#!/usr/bin/env python
from sys import exit
from argparse import ArgumentParser, FileType
from pprint import pprint as pp
from configparser import RawConfigParser
import errors
from storage import StoragePostgres
from client import Client
parser = ArgumentParser()
parser.add_argument(
'--protocol',
required=True,
choices=['tcp', 'udp'],
help='Protocol for client'
)
parser.add_argument(
'--config',
type=FileType('r'),
required=True,
help='Configuration file'
)
parser.add_argument(
'src_ip',
help='Client source IP to add'
)
args = parser.parse_args()
config = RawConfigParser()
config.readfp(args.config)
sr = StoragePostgres(config=config)
try:
client = Client(
storage=sr,
ip_address=args.src_ip,
protocol=args.protocol,
chain=config.get('iptables', 'chain')
)
except errors.StorageNotFound:
print('Could not find client')
exit(1)
client.delete()

View file

@ -6,10 +6,16 @@ import json
from datetime import datetime from datetime import datetime
import psycopg2 import psycopg2
from psycopg2.extras import DictCursor, register_ipaddress, Inet
from redis import Redis from redis import Redis
from client import Client
class StoragePostgres(object): class StoragePostgres(object):
"""
This requires python 3 for inet data type.
"""
def __init__(self, **kw): def __init__(self, **kw):
config = kw.pop('config') config = kw.pop('config')
@ -19,8 +25,58 @@ class StoragePostgres(object):
user=config.get('postgres', 'username'), user=config.get('postgres', 'username'),
password=config.get('postgres', 'password'), password=config.get('postgres', 'password'),
dbname=config.get('postgres', 'database'), dbname=config.get('postgres', 'database'),
port=config.getint('postgres', 'port') port=config.getint('postgres', 'port'),
sslmode='disable',
cursor_factory=DictCursor
) )
self.cur = self.conn.cursor()
register_ipaddress()
def get_client_by_id(self, client_id):
self.cur.execute(
'select * from client where client_id=%s',
(client_id,)
)
return self.cur.fetchone()
def get_client(self, ip_address, protocol):
self.cur.execute(
'select * from client where ip_address=%s and protocol=%s',
(Inet(ip_address), protocol, )
)
return self.cur.fetchone()
def write_client(self, client):
query = (
'insert into client (client_id, created, ip_address, protocol, '
'enabled, last_packets, last_activity) values (%s, %s, %s, %s, '
'%s, %s, %s) on conflict (client_id, ip_address, protocol) do '
'update set (enabled, last_packets, last_activity) = '
'(EXCLUDED.enabled, EXCLUDED.last_packets, '
'EXCLUDED.last_activity)'
)
self.cur.execute(
query,
(
client.client_id,
client.created,
client.ip_address,
client.protocol,
client.enabled,
client.last_packets,
client.last_activity
)
)
self.conn.commit()
def remove_client(self, client):
query = 'delete from client where client_id=%s'
self.cur.execute(query, (client.client_id,))
self.conn.commit()
class DateTimeEncoder(json.JSONEncoder): class DateTimeEncoder(json.JSONEncoder):