mirror of
https://codeberg.org/prof_x_pvt_ltd/captive.whump.shanti-portal
synced 2024-08-14 22:46:42 +00:00
client management tools
This commit is contained in:
parent
dbbdb29601
commit
da34f9017c
6 changed files with 183 additions and 39 deletions
|
@ -2,11 +2,13 @@
|
|||
# Python helper tool to add IPtables rule using the iptc library. This must
|
||||
# of course run as root for iptc to work.
|
||||
|
||||
from sys import exit
|
||||
from argparse import ArgumentParser, FileType
|
||||
from pprint import pprint as pp
|
||||
from configparser import RawConfigParser
|
||||
|
||||
from storage import StorageRedis
|
||||
import errors
|
||||
from storage import StoragePostgres
|
||||
from client import Client
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
@ -35,10 +37,16 @@ args = parser.parse_args()
|
|||
config = RawConfigParser()
|
||||
config.readfp(args.config)
|
||||
|
||||
sr = StorageRedis(config=config)
|
||||
client = Client(
|
||||
sr = StoragePostgres(config=config)
|
||||
try:
|
||||
client = Client(
|
||||
storage=sr,
|
||||
client_id=args.src_ip,
|
||||
ip_address=args.src_ip,
|
||||
protocol=args.protocol,
|
||||
chain=config.get('iptables', 'chain')
|
||||
)
|
||||
)
|
||||
except errors.StorageNotFound:
|
||||
print('Could not find client')
|
||||
exit(1)
|
||||
|
||||
client.commit()
|
||||
|
|
|
@ -4,10 +4,11 @@ create table if not exists client (
|
|||
client_id uuid NOT NULL primary key unique,
|
||||
created timestamp NOT NULL,
|
||||
ip_address inet NOT NULL,
|
||||
"protocol" inet_protocol NOT NULL,
|
||||
protocol inet_protocol NOT NULL,
|
||||
enabled boolean NOT NULL,
|
||||
last_packets bigint default 0,
|
||||
last_activity timestamp
|
||||
last_activity timestamp,
|
||||
primary key (client_id, ip_address, protocol)
|
||||
);
|
||||
|
||||
create index if not exists client_ip_address_index on client (ip_address);
|
||||
|
|
|
@ -2,36 +2,59 @@
|
|||
Handles "clients" in IPtables for captive portal.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
|
||||
#import iptc
|
||||
|
||||
import errors
|
||||
|
||||
|
||||
class Client(object):
|
||||
|
||||
def __init__(self, **kw):
|
||||
# Required parameters
|
||||
self.storage = kw.pop('storage')
|
||||
self.client_id = kw.pop('client_id')
|
||||
self.protocol = kw.pop('protocol')
|
||||
self.chain = kw.pop('chain')
|
||||
|
||||
# Default values for client data
|
||||
self.data = {
|
||||
'client_id': self.client_id,
|
||||
'protocol': self.protocol,
|
||||
'created': datetime.now(),
|
||||
'bytes': 0,
|
||||
'packets': 0,
|
||||
'last_activity': None
|
||||
}
|
||||
# First try to get an existing client by ID
|
||||
self.client_id = kw.pop('client_id', None)
|
||||
if self.client_id:
|
||||
client_data = self.storage.get_client_by_id(self.client_id)
|
||||
|
||||
self.client_exists = False
|
||||
# If ID is specified then we raise exception if client isn't
|
||||
# found.
|
||||
if client_data is None:
|
||||
raise StorageNotFound('Client not found')
|
||||
|
||||
# Next try to get an existing client by IP and protocol
|
||||
self.ip_address = kw.pop('ip_address')
|
||||
self.protocol = kw.pop('protocol')
|
||||
|
||||
if self.ip_address and self.protocol:
|
||||
client_data = self.storage.get_client(
|
||||
self.ip_address,
|
||||
self.protocol
|
||||
)
|
||||
|
||||
# Attempt to fetch client from storage
|
||||
client_data = self.storage.get_client(self.client_id)
|
||||
if client_data:
|
||||
self.data = client_data
|
||||
self.exists = True
|
||||
self.load_client(client_data)
|
||||
else:
|
||||
self.client_id = str(uuid4())
|
||||
self.created = datetime.now()
|
||||
self.enabled = False
|
||||
self.last_packets = 0
|
||||
self.last_activity = None
|
||||
|
||||
|
||||
def load_client(self, data):
|
||||
self.client_id = data.get('client_id')
|
||||
self.created = data.get('created')
|
||||
self.ip_address = data.get('ip_address')
|
||||
self.protocol = data.get('protocol')
|
||||
self.enabled = data.get('enabled')
|
||||
self.last_packets = data.get('last_packets')
|
||||
self.last_activity = data.get('last_activity')
|
||||
|
||||
|
||||
def commit(self):
|
||||
|
@ -40,17 +63,20 @@ class Client(object):
|
|||
|
||||
|
||||
def commit_client(self):
|
||||
if self.exists:
|
||||
self.storage.update_client(
|
||||
self.client_id,
|
||||
**self.data
|
||||
)
|
||||
else:
|
||||
self.storage.add_client(
|
||||
self.client_id,
|
||||
**self.data
|
||||
self.storage.write_client(
|
||||
self
|
||||
)
|
||||
|
||||
|
||||
def delete(self):
|
||||
#self.remove_rule()
|
||||
self.storage.remove_client(self)
|
||||
|
||||
|
||||
def find_rule(self):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
def commit_rule(self):
|
||||
table = iptc.Table(iptc.Table.MANGLE)
|
||||
chain = iptc.Chain(table, self.chain)
|
||||
|
@ -58,12 +84,13 @@ class Client(object):
|
|||
# Check if rule exists
|
||||
for rule in chain.rules:
|
||||
src_ip = rule.src
|
||||
if src_ip.startswith(self.client_id) and rule.protocol == self.protocol:
|
||||
if src_ip.startswith(self.ip_address) and rule.protocol == self.protocol:
|
||||
print('Rule exists')
|
||||
break
|
||||
else:
|
||||
rule = iptc.Rule()
|
||||
rule.src = self.client_id
|
||||
rule.src = self.ip_address
|
||||
rule.protocol = self.protocol
|
||||
rule.target = iptc.Target(rule, 'RETURN')
|
||||
chain.insert_rule(rule)
|
||||
|
||||
|
|
2
tools/errors.py
Normal file
2
tools/errors.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
class StorageNotFound(Exception):
|
||||
pass
|
50
tools/remove_client.py
Normal file
50
tools/remove_client.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from sys import exit
|
||||
from argparse import ArgumentParser, FileType
|
||||
from pprint import pprint as pp
|
||||
from configparser import RawConfigParser
|
||||
|
||||
import errors
|
||||
from storage import StoragePostgres
|
||||
from client import Client
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
'--protocol',
|
||||
required=True,
|
||||
choices=['tcp', 'udp'],
|
||||
help='Protocol for client'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config',
|
||||
type=FileType('r'),
|
||||
required=True,
|
||||
help='Configuration file'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'src_ip',
|
||||
help='Client source IP to add'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = RawConfigParser()
|
||||
config.readfp(args.config)
|
||||
|
||||
sr = StoragePostgres(config=config)
|
||||
try:
|
||||
client = Client(
|
||||
storage=sr,
|
||||
ip_address=args.src_ip,
|
||||
protocol=args.protocol,
|
||||
chain=config.get('iptables', 'chain')
|
||||
)
|
||||
except errors.StorageNotFound:
|
||||
print('Could not find client')
|
||||
exit(1)
|
||||
|
||||
client.delete()
|
|
@ -6,10 +6,16 @@ import json
|
|||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor, register_ipaddress, Inet
|
||||
from redis import Redis
|
||||
|
||||
from client import Client
|
||||
|
||||
|
||||
class StoragePostgres(object):
|
||||
"""
|
||||
This requires python 3 for inet data type.
|
||||
"""
|
||||
|
||||
def __init__(self, **kw):
|
||||
config = kw.pop('config')
|
||||
|
@ -19,8 +25,58 @@ class StoragePostgres(object):
|
|||
user=config.get('postgres', 'username'),
|
||||
password=config.get('postgres', 'password'),
|
||||
dbname=config.get('postgres', 'database'),
|
||||
port=config.getint('postgres', 'port')
|
||||
port=config.getint('postgres', 'port'),
|
||||
sslmode='disable',
|
||||
cursor_factory=DictCursor
|
||||
)
|
||||
self.cur = self.conn.cursor()
|
||||
register_ipaddress()
|
||||
|
||||
|
||||
def get_client_by_id(self, client_id):
|
||||
self.cur.execute(
|
||||
'select * from client where client_id=%s',
|
||||
(client_id,)
|
||||
)
|
||||
return self.cur.fetchone()
|
||||
|
||||
|
||||
def get_client(self, ip_address, protocol):
|
||||
self.cur.execute(
|
||||
'select * from client where ip_address=%s and protocol=%s',
|
||||
(Inet(ip_address), protocol, )
|
||||
)
|
||||
return self.cur.fetchone()
|
||||
|
||||
|
||||
def write_client(self, client):
|
||||
query = (
|
||||
'insert into client (client_id, created, ip_address, protocol, '
|
||||
'enabled, last_packets, last_activity) values (%s, %s, %s, %s, '
|
||||
'%s, %s, %s) on conflict (client_id, ip_address, protocol) do '
|
||||
'update set (enabled, last_packets, last_activity) = '
|
||||
'(EXCLUDED.enabled, EXCLUDED.last_packets, '
|
||||
'EXCLUDED.last_activity)'
|
||||
)
|
||||
self.cur.execute(
|
||||
query,
|
||||
(
|
||||
client.client_id,
|
||||
client.created,
|
||||
client.ip_address,
|
||||
client.protocol,
|
||||
client.enabled,
|
||||
client.last_packets,
|
||||
client.last_activity
|
||||
)
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
|
||||
def remove_client(self, client):
|
||||
query = 'delete from client where client_id=%s'
|
||||
self.cur.execute(query, (client.client_id,))
|
||||
self.conn.commit()
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
|
|
Loading…
Reference in a new issue