2017-03-03 00:04:12 +00:00
|
|
|
"""
|
|
|
|
Handles captive portal "clients" and their rules in iptables.
|
|
|
|
"""
|
|
|
|
|
2017-03-06 15:03:57 +00:00
|
|
|
from uuid import uuid4
|
2017-03-03 00:04:12 +00:00
|
|
|
from datetime import datetime
|
|
|
|
|
2017-03-07 08:36:56 +00:00
|
|
|
import iptc
|
2017-03-03 00:04:12 +00:00
|
|
|
|
2017-03-06 15:03:57 +00:00
|
|
|
import errors
|
|
|
|
|
2017-03-03 00:04:12 +00:00
|
|
|
|
|
|
|
class Client(object):
    """
    A captive portal client: a stored record plus an iptables rule.

    Client records are read from and written to a storage backend. Enabled
    clients get a RETURN rule in a chain of the iptables mangle table.
    """

    def __init__(self, **kw):
        """
        Load an existing client, or initialise a new (unsaved) one.

        Keyword arguments:
        storage -- backend providing get_client_by_id(), get_client(),
                   write_client() and remove_client().
        chain -- name of the chain in the iptables mangle table.
        client_id -- optional; if given, the client must exist in storage.
        ip_address -- required key; used to look up the client when it was
                      not found by client_id (see find_rule for the
                      expected type).
        protocol -- required key; used together with ip_address.

        Raises errors.StorageNotFound if client_id is given but no such
        client exists, and KeyError if a required key is missing.
        """
        # Required parameters
        self.storage = kw.pop('storage')
        self._chain = kw.pop('chain')

        # Stays None unless one of the lookups below finds a stored client.
        client_data = None

        # First try to get an existing client by ID
        self.client_id = kw.pop('client_id', None)
        if self.client_id:
            client_data = self.storage.get_client_by_id(self.client_id)

            # If ID is specified then we raise exception if client isn't
            # found. NOTE(review): StorageNotFound is assumed to be defined
            # in the errors module imported by this file.
            if client_data is None:
                raise errors.StorageNotFound('Client not found')

        # Next try to get an existing client by IP and protocol, but only if
        # the ID lookup found nothing: a failed IP lookup must not discard a
        # client that was already found by ID.
        self.ip_address = kw.pop('ip_address')
        self.protocol = kw.pop('protocol')

        if not client_data and self.ip_address and self.protocol:
            client_data = self.storage.get_client(
                self.ip_address,
                self.protocol
            )

        if client_data:
            self.load_client(client_data)
        else:
            # Nothing stored yet: start a fresh, disabled client.
            self.client_id = str(uuid4())
            self.created = datetime.now()
            self.enabled = False
            self.last_packets = 0
            self.last_activity = None

        # Init iptables
        self.table = iptc.Table(iptc.Table.MANGLE)
        self.chain = iptc.Chain(self.table, self._chain)

    def load_client(self, data):
        """
        Populate this client's attributes from a stored record.

        data is a mapping; keys missing from it become None.
        """
        self.client_id = data.get('client_id')
        self.created = data.get('created')
        self.ip_address = data.get('ip_address')
        self.protocol = data.get('protocol')
        self.enabled = data.get('enabled')
        self.last_packets = data.get('last_packets')
        self.last_activity = data.get('last_activity')

    def commit(self):
        """
        Write the client to storage, then sync its iptables rule.

        Enabled clients get their rule inserted if it is missing. Disabled
        clients have any existing rule removed, so disabling takes effect.
        """
        self.commit_client()

        if self.enabled:
            self.commit_rule()
        else:
            self.remove_rule()

    def commit_client(self):
        """Write this client to the storage backend."""
        self.storage.write_client(
            self
        )

    def delete(self):
        """Remove the client's iptables rule and its storage record."""
        self.remove_rule()
        self.storage.remove_client(self)

    def remove_rule(self):
        """Delete this client's rule from the chain, if one exists."""
        rule = self.find_rule(self.ip_address, self.protocol)
        if rule is not None:
            self.chain.delete_rule(rule)

    def find_rule(self, ip_address, protocol):
        """
        Return the first rule in the chain for ip_address and protocol.

        Takes an ipaddress.IPv4Interface object (anything with an .ip
        attribute) as the ip_address argument. Returns None when no rule
        matches, or when ip_address has no .ip attribute.
        """
        try:
            wanted = str(ip_address.ip)
        except AttributeError:
            # If we can't understand the argument just return None
            return None

        for rule in self.chain.rules:
            # rule.src looks like 'address/netmask'. Compare the address
            # part exactly: a prefix test would let 10.0.0.1 match a rule
            # for 10.0.0.10.
            src_ip = rule.src.split('/', 1)[0]
            if src_ip == wanted and rule.protocol == protocol:
                return rule

        return None

    def commit_rule(self):
        """
        Insert a RETURN rule for this client's source address and protocol
        at the head of the chain, unless a matching rule already exists.
        """
        rule = self.find_rule(self.ip_address, self.protocol)
        if rule is None:
            rule = iptc.Rule()
            # Give the rule a string source: str() leaves a str unchanged
            # and renders an ipaddress interface as 'address/prefixlen'.
            rule.src = str(self.ip_address)
            rule.protocol = self.protocol
            rule.target = iptc.Target(rule, 'RETURN')
            self.chain.insert_rule(rule)