captive.whump.shanti-portal/tools/client.py

160 lines
4.5 KiB
Python
Raw Normal View History

2017-03-03 00:04:12 +00:00
"""
Handles "clients" in IPtables for captive portal.
"""
import ipaddress
2017-03-06 15:03:57 +00:00
from uuid import uuid4
2017-09-29 15:44:40 +00:00
from datetime import datetime, timedelta
2017-03-03 00:04:12 +00:00
2017-03-07 08:36:56 +00:00
import iptc
2017-03-03 00:04:12 +00:00
2017-09-29 17:19:12 +00:00
from errors import StorageNotFound, IPTCRuleNotFound
2017-03-06 15:03:57 +00:00
2017-03-03 00:04:12 +00:00
class Client(object):
    """
    A captive-portal client: a storage record plus an optional iptables
    rule in the MANGLE table that RETURNs traffic from the client's IP.
    """

    def __init__(self, **kw):
        """
        Load an existing client or initialise a new one.

        Required keyword arguments:
            storage: backend exposing get_client_by_id, get_client,
                write_client and remove_client.
            chain: name of the iptables chain (MANGLE table) that holds
                client rules.

        Optional keyword arguments:
            ip_address: str or ipaddress.IPv4Interface (default '127.0.0.1').
            protocol: protocol name for the rule (default 'tcp').
            client_id: if given, the client must exist in storage.

        Raises:
            StorageNotFound: client_id was given but no such client exists.
            ValueError: ip_address is neither a str nor an IPv4Interface.
        """
        # Required parameters
        self.storage = kw.pop('storage')
        self._chain = kw.pop('chain')

        self.ip_address = kw.pop('ip_address', '127.0.0.1')
        self.protocol = kw.pop('protocol', 'tcp')

        # First try to get an existing client by ID
        self.client_id = kw.pop('client_id', None)
        if self.client_id:
            client_data = self.storage.get_client_by_id(self.client_id)

            # If ID is specified then we raise exception if client isn't
            # found.
            if client_data is None:
                raise StorageNotFound('Client not found')
        else:
            client_data = self.storage.get_client(
                self._ip_address,
                self.protocol
            )

        # Init iptables before load_client(), which looks up this client's
        # rule in self.chain.
        self.table = iptc.Table(iptc.Table.MANGLE)
        self.chain = iptc.Chain(self.table, self._chain)

        if client_data:
            self.load_client(client_data)
        else:
            self.client_id = str(uuid4())
            self.created = datetime.now()
            self.enabled = False
            self.last_packets = 0
            self.last_activity = None
            self.expires = datetime.now() + timedelta(days=1)

    def load_client(self, data):
        """
        Populate this client from a storage mapping and refresh activity.

        If an iptables rule exists for the client, its packet counter is
        compared with the stored one; an increase updates last_activity.
        """
        self.client_id = data.get('client_id')
        self.created = data.get('created')
        self.ip_address = data.get('ip_address')
        self.protocol = data.get('protocol')
        self.enabled = data.get('enabled')
        self.last_packets = data.get('last_packets')
        self.last_activity = data.get('last_activity')
        self.expires = data.get('expires')

        # Try and find a rule for this client and with that rule also packet
        # count. Don't rely on it existing though.
        rule = self.find_rule(self._ip_address, self.protocol)
        if rule:
            packet_count, _ = rule.get_counters()
            # Storage may hold None for last_packets; treat it as 0 rather
            # than comparing None with an int (TypeError on Python 3).
            if packet_count > (self.last_packets or 0):
                self.last_activity = datetime.now()
            # Always track the current counter so that a counter reset
            # (e.g. rules reloaded) doesn't hide future activity.
            self.last_packets = packet_count

    def commit(self):
        """Persist the client, then add or remove its rule per self.enabled."""
        self.commit_client()
        if self.enabled:
            self.commit_rule()
        else:
            self.remove_rule()

    def commit_client(self):
        """Write this client to storage."""
        self.storage.write_client(
            self
        )

    def delete(self):
        """Remove this client's iptables rule and its storage record."""
        self.remove_rule()
        self.storage.remove_client(self)

    def remove_rule(self):
        """Delete this client's rule from the chain, if one exists."""
        rule = self.find_rule(self._ip_address, self.protocol)
        if rule:
            self.chain.delete_rule(rule)

    def find_rule(self, ip_address, protocol):
        """
        Return the rule in self.chain for ip_address/protocol, or None.

        Takes an ipaddress.IPv4Interface object as ip_address argument.

        Raises:
            ValueError: ip_address is not an ipaddress.IPv4Interface.
        """
        if not isinstance(ip_address, ipaddress.IPv4Interface):
            raise ValueError('Invalid argument type')

        wanted_ip = str(ip_address.ip)
        for rule in self.chain.rules:
            # rule.src presumably carries a netmask suffix such as
            # "10.0.0.1/255.255.255.255"; compare the address part exactly
            # so that 10.0.0.1 does not match a rule for 10.0.0.10.
            rule_ip = rule.src.split('/', 1)[0]
            if rule_ip == wanted_ip and rule.protocol == protocol:
                return rule
        return None

    def commit_rule(self):
        """Insert a RETURN rule for this client unless one already exists."""
        rule = self.find_rule(self._ip_address, self.protocol)
        if not rule:
            rule = iptc.Rule()
            rule.src = self.ip_address
            rule.protocol = self.protocol
            rule.target = iptc.Target(rule, 'RETURN')
            self.chain.insert_rule(rule)

    @property
    def ip_address(self):
        """The client's IPv4 address as a string (without prefix)."""
        return str(self._ip_address.ip)

    @ip_address.setter
    def ip_address(self, value):
        # Stored internally as an IPv4Interface; accepts a str or an
        # existing IPv4Interface.
        if isinstance(value, str):
            self._ip_address = ipaddress.IPv4Interface(value)
        elif isinstance(value, ipaddress.IPv4Interface):
            self._ip_address = value
        else:
            raise ValueError('Cannot set invalid value')