better handling of IPv4Interface across the board.

This commit is contained in:
Stefan Midjich 2017-03-07 17:36:15 +01:00
parent c1ee2b0813
commit be6b2bef77
2 changed files with 28 additions and 5 deletions

View file

@ -2,6 +2,7 @@
Handles "clients" in IPtables for captive portal. Handles "clients" in IPtables for captive portal.
""" """
import ipaddress
from uuid import uuid4 from uuid import uuid4
from datetime import datetime from datetime import datetime
@ -33,7 +34,7 @@ class Client(object):
if self.ip_address and self.protocol: if self.ip_address and self.protocol:
client_data = self.storage.get_client( client_data = self.storage.get_client(
self.ip_address, self._ip_address,
self.protocol self.protocol
) )
@ -82,7 +83,7 @@ class Client(object):
def remove_rule(self): def remove_rule(self):
rule = self.find_rule(self.ip_address, self.protocol) rule = self.find_rule(self._ip_address, self.protocol)
if rule: if rule:
self.chain.delete_rule(rule) self.chain.delete_rule(rule)
@ -92,6 +93,9 @@ class Client(object):
Takes an ipaddress.IPv4Interface object as ip_address argument. Takes an ipaddress.IPv4Interface object as ip_address argument.
""" """
if not isinstance(ip_address, ipaddress.IPv4Interface):
raise ValueError('Invalid argument type')
for rule in self.chain.rules: for rule in self.chain.rules:
src_ip = rule.src src_ip = rule.src
@ -108,7 +112,7 @@ class Client(object):
def commit_rule(self): def commit_rule(self):
rule = self.find_rule(self.ip_address, self.protocol) rule = self.find_rule(self._ip_address, self.protocol)
if not rule: if not rule:
rule = iptc.Rule() rule = iptc.Rule()
rule.src = self.ip_address rule.src = self.ip_address
@ -116,3 +120,18 @@ class Client(object):
rule.target = iptc.Target(rule, 'RETURN') rule.target = iptc.Target(rule, 'RETURN')
self.chain.insert_rule(rule) self.chain.insert_rule(rule)
@property
def ip_address(self):
return str(self._ip_address.ip)
@ip_address.setter
def ip_address(self, value):
if isinstance(value, str):
self._ip_address = ipaddress.IPv4Interface(value)
elif isinstance(value, ipaddress.IPv4Interface):
self._ip_address = value
else:
raise ValueError('Cannot set invalid value')

View file

@ -42,9 +42,13 @@ class StoragePostgres(object):
def get_client(self, ip_address, protocol): def get_client(self, ip_address, protocol):
"""
Expects an ipaddress.IPv4Interface as ip_address argument.
"""
self.cur.execute( self.cur.execute(
'select * from client where ip_address=%s and protocol=%s', 'select * from client where ip_address=%s and protocol=%s',
(Inet(ip_address), protocol, ) (ip_address, protocol, )
) )
return self.cur.fetchone() return self.cur.fetchone()
@ -63,7 +67,7 @@ class StoragePostgres(object):
( (
client.client_id, client.client_id,
client.created, client.created,
client.ip_address, client._ip_address,
client.protocol, client.protocol,
client.enabled, client.enabled,
client.last_packets, client.last_packets,