refactored client lib, so it can be used by the new portalclient module
This commit is contained in:
parent
b4c9c5cdca
commit
e2b160a993
122
tools/client.py
122
tools/client.py
|
@ -1,122 +0,0 @@
|
||||||
"""
|
|
||||||
Handles "clients" in IPtables for captive portal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
from uuid import uuid4
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from errors import StorageNotFound, IPTCRuleNotFound
|
|
||||||
from helpers import run_ipset
|
|
||||||
|
|
||||||
|
|
||||||
class Client(object):
|
|
||||||
|
|
||||||
def __init__(self, **kw):
|
|
||||||
# Required parameters
|
|
||||||
self.storage = kw.pop('storage')
|
|
||||||
self.ipset_name = kw.pop('ipset_name')
|
|
||||||
self.use_sudo = kw.pop('use_sudo', False)
|
|
||||||
|
|
||||||
self.ip_address = kw.pop('ip_address', '127.0.0.1')
|
|
||||||
self.protocol = kw.pop('protocol', 'tcp')
|
|
||||||
|
|
||||||
self.new = False
|
|
||||||
|
|
||||||
# First try to get an existing client by ID
|
|
||||||
self.client_id = kw.pop('client_id', None)
|
|
||||||
if self.client_id:
|
|
||||||
client_data = self.storage.get_client_by_id(self.client_id)
|
|
||||||
|
|
||||||
# If ID is specified then we raise exception if client isn't
|
|
||||||
# found.
|
|
||||||
if client_data is None:
|
|
||||||
raise StorageNotFound('Client not found')
|
|
||||||
else:
|
|
||||||
client_data = self.storage.get_client(
|
|
||||||
self._ip_address,
|
|
||||||
self.protocol
|
|
||||||
)
|
|
||||||
|
|
||||||
# Init iptables
|
|
||||||
#self.table = iptc.Table(iptc.Table.MANGLE)
|
|
||||||
#self.chain = iptc.Chain(self.table, self._chain)
|
|
||||||
|
|
||||||
if client_data:
|
|
||||||
self.load_client(client_data)
|
|
||||||
else:
|
|
||||||
self.client_id = str(uuid4())
|
|
||||||
self.created = datetime.now()
|
|
||||||
self.enabled = False
|
|
||||||
self.last_packets = 0
|
|
||||||
self.last_activity = None
|
|
||||||
self.expires = datetime.now() + timedelta(days=1)
|
|
||||||
self.new = True
|
|
||||||
|
|
||||||
|
|
||||||
def load_client(self, data):
|
|
||||||
self.client_id = data.get('client_id')
|
|
||||||
self.created = data.get('created')
|
|
||||||
self.ip_address = data.get('ip_address')
|
|
||||||
self.protocol = data.get('protocol')
|
|
||||||
self.enabled = data.get('enabled')
|
|
||||||
self.last_packets = data.get('last_packets')
|
|
||||||
self.last_activity = data.get('last_activity')
|
|
||||||
self.expires = data.get('expires')
|
|
||||||
|
|
||||||
|
|
||||||
def commit(self):
|
|
||||||
self.commit_client()
|
|
||||||
|
|
||||||
if self.enabled:
|
|
||||||
self.commit_rule()
|
|
||||||
else:
|
|
||||||
self.remove_rule()
|
|
||||||
|
|
||||||
|
|
||||||
def commit_client(self):
|
|
||||||
self.storage.write_client(
|
|
||||||
self
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def delete(self):
|
|
||||||
self.remove_rule()
|
|
||||||
self.storage.remove_client(self)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_rule(self):
|
|
||||||
run_ipset(
|
|
||||||
'del',
|
|
||||||
'-exist',
|
|
||||||
self.ipset_name,
|
|
||||||
self.ip_address,
|
|
||||||
use_sudo=self.use_sudo
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def commit_rule(self):
|
|
||||||
run_ipset(
|
|
||||||
'add',
|
|
||||||
'-exist',
|
|
||||||
self.ipset_name,
|
|
||||||
self.ip_address,
|
|
||||||
use_sudo=self.use_sudo
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ip_address(self):
|
|
||||||
return str(self._ip_address.ip)
|
|
||||||
|
|
||||||
|
|
||||||
@ip_address.setter
|
|
||||||
def ip_address(self, value):
|
|
||||||
if isinstance(value, str):
|
|
||||||
self._ip_address = ipaddress.IPv4Interface(value)
|
|
||||||
elif isinstance(value, ipaddress.IPv4Interface):
|
|
||||||
self._ip_address = value
|
|
||||||
else:
|
|
||||||
raise ValueError('Cannot set invalid value')
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
class StorageNotFound(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class IPTCRuleExists(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class IPTCRuleNotFound(Exception):
|
|
||||||
pass
|
|
|
@ -1,25 +0,0 @@
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import shlex
|
|
||||||
|
|
||||||
def run_ipset(command, *args, **kw):
|
|
||||||
use_sudo = kw.get('use_sudo', True)
|
|
||||||
timeout = kw.get('timeout', 2)
|
|
||||||
|
|
||||||
if use_sudo:
|
|
||||||
ipset_cmd = 'sudo ipset'
|
|
||||||
else:
|
|
||||||
ipset_cmd = 'ipset'
|
|
||||||
|
|
||||||
full_command = '{ipset} {command} {args}'.format(
|
|
||||||
ipset=ipset_cmd,
|
|
||||||
command=command,
|
|
||||||
args=' '.join(args)
|
|
||||||
)
|
|
||||||
|
|
||||||
output = subprocess.check_output(
|
|
||||||
shlex.split(full_command),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
|
@ -10,10 +10,9 @@ from configparser import RawConfigParser
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import errors
|
from portalclientlib.client import Client
|
||||||
from helpers import run_ipset
|
from portalclientlib.storage import StoragePostgres
|
||||||
from storage import StoragePostgres
|
from portalclientlib.helpers import run_ipset
|
||||||
from client import Client
|
|
||||||
|
|
||||||
|
|
||||||
# Custom defined argparse types for dates
|
# Custom defined argparse types for dates
|
||||||
|
@ -158,7 +157,7 @@ if args.refresh:
|
||||||
if int(packets_val) != client.last_packets:
|
if int(packets_val) != client.last_packets:
|
||||||
client.last_activity = current_date
|
client.last_activity = current_date
|
||||||
client.last_packets = int(packets_val)
|
client.last_packets = int(packets_val)
|
||||||
if args.verbose:
|
if args.verbose > 1:
|
||||||
print('Updating activity for client:{ip}'.format(
|
print('Updating activity for client:{ip}'.format(
|
||||||
ip=client.ip_address
|
ip=client.ip_address
|
||||||
))
|
))
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# Python helper tool to purge expired clients from DB and iptables. Requires
|
|
||||||
# root privileges for iptc to work.
|
|
||||||
|
|
||||||
from sys import exit
|
|
||||||
from argparse import ArgumentParser, FileType
|
|
||||||
from pprint import pprint as pp
|
|
||||||
from configparser import RawConfigParser
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import errors
|
|
||||||
from storage import StoragePostgres
|
|
||||||
from client import Client
|
|
||||||
|
|
||||||
|
|
||||||
parser = ArgumentParser((
|
|
||||||
'Purge expired clients by disabling them.'
|
|
||||||
))
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--config',
|
|
||||||
type=FileType('r'),
|
|
||||||
required=True,
|
|
||||||
help='Configuration file'
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = RawConfigParser()
|
|
||||||
config.readfp(args.config)
|
|
||||||
|
|
||||||
sr = StoragePostgres(config=config)
|
|
||||||
|
|
||||||
for client_id in sr.client_ids():
|
|
||||||
client = Client(
|
|
||||||
storage=sr,
|
|
||||||
chain=config.get('iptables', 'chain'),
|
|
||||||
client_id=client_id[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if datetime.now() > client.expires:
|
|
||||||
client.enabled = False
|
|
||||||
client.commit()
|
|
||||||
else:
|
|
||||||
# Simply commit whatever was loaded during Client.__init__(), like
|
|
||||||
# up-to-date packet count stats for example.
|
|
||||||
client.commit()
|
|
121
tools/storage.py
121
tools/storage.py
|
@ -1,121 +0,0 @@
|
||||||
"""
|
|
||||||
Database storage backends for client.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import psycopg2
|
|
||||||
from psycopg2.extras import DictCursor, register_ipaddress, Inet
|
|
||||||
from redis import Redis
|
|
||||||
|
|
||||||
from client import Client
|
|
||||||
|
|
||||||
|
|
||||||
class StoragePostgres(object):
|
|
||||||
"""
|
|
||||||
This requires python 3 for inet data type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kw):
|
|
||||||
config = kw.pop('config')
|
|
||||||
|
|
||||||
self.conn = psycopg2.connect(
|
|
||||||
host=config.get('postgres', 'hostname'),
|
|
||||||
user=config.get('postgres', 'username'),
|
|
||||||
password=config.get('postgres', 'password'),
|
|
||||||
dbname=config.get('postgres', 'database'),
|
|
||||||
port=config.getint('postgres', 'port'),
|
|
||||||
sslmode='disable',
|
|
||||||
cursor_factory=DictCursor
|
|
||||||
)
|
|
||||||
self.cur = self.conn.cursor()
|
|
||||||
register_ipaddress()
|
|
||||||
|
|
||||||
|
|
||||||
def client_ids(self):
|
|
||||||
self.cur.execute(
|
|
||||||
'select client_id from authenticated_clients'
|
|
||||||
)
|
|
||||||
return self.cur.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_by_id(self, client_id):
|
|
||||||
self.cur.execute(
|
|
||||||
'select * from authenticated_clients where client_id=%s',
|
|
||||||
(client_id,)
|
|
||||||
)
|
|
||||||
return self.cur.fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
def get_client(self, ip_address, protocol):
|
|
||||||
"""
|
|
||||||
Expects an ipaddress.IPv4Interface as ip_address argument.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.cur.execute(
|
|
||||||
'select * from authenticated_clients where ip_address=%s and protocol=%s',
|
|
||||||
(ip_address, protocol, )
|
|
||||||
)
|
|
||||||
return self.cur.fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
def write_client(self, client):
|
|
||||||
query = (
|
|
||||||
'insert into authenticated_clients (client_id, created, ip_address, protocol, '
|
|
||||||
'enabled, last_packets, last_activity, expires) values '
|
|
||||||
'(%s, %s, %s, %s, %s, %s, %s, %s) on conflict (client_id, '
|
|
||||||
'ip_address, protocol) do update set (enabled, last_packets, '
|
|
||||||
'last_activity, expires) = (EXCLUDED.enabled, EXCLUDED.last_packets, '
|
|
||||||
'EXCLUDED.last_activity, EXCLUDED.expires)'
|
|
||||||
)
|
|
||||||
self.cur.execute(
|
|
||||||
query,
|
|
||||||
(
|
|
||||||
client.client_id,
|
|
||||||
client.created,
|
|
||||||
client.ip_address,
|
|
||||||
client.protocol,
|
|
||||||
client.enabled,
|
|
||||||
client.last_packets,
|
|
||||||
client.last_activity,
|
|
||||||
client.expires
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def remove_client(self, client):
|
|
||||||
query = 'delete from authenticated_clients where client_id=%s'
|
|
||||||
self.cur.execute(query, (client.client_id,))
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
class DateTimeEncoder(json.JSONEncoder):
|
|
||||||
"""
|
|
||||||
json.JSONEncoder sub-class that converts all datetime objects to
|
|
||||||
epoch timestamp integer values.
|
|
||||||
"""
|
|
||||||
def default(self, o):
|
|
||||||
if isinstance(o, datetime):
|
|
||||||
return int(o.strftime('%s'))
|
|
||||||
return json.JSONEncoder.default(self, o)
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRedis(object):
|
|
||||||
"""
|
|
||||||
Note: Abandoned this storage backend for Postgres.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kw):
|
|
||||||
config = kw.pop('config')
|
|
||||||
|
|
||||||
self.r = Redis(
|
|
||||||
host=config.get('redis', 'hostname'),
|
|
||||||
port=config.getint('redis', 'port'),
|
|
||||||
db=config.getint('redis', 'db')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_client(self, client_id, **kw):
|
|
||||||
raise NotImplemented
|
|
Loading…
Reference in New Issue