better handling of IPv4Interface across the board.

This commit is contained in:
Stefan Midjich 2017-03-07 17:36:15 +01:00
parent c1ee2b0813
commit be6b2bef77
2 changed files with 28 additions and 5 deletions

View file

@ -2,6 +2,7 @@
Handles "clients" in IPtables for captive portal.
"""
import ipaddress
from uuid import uuid4
from datetime import datetime
@ -33,7 +34,7 @@ class Client(object):
if self.ip_address and self.protocol:
client_data = self.storage.get_client(
self.ip_address,
self._ip_address,
self.protocol
)
@ -82,7 +83,7 @@ class Client(object):
def remove_rule(self):
rule = self.find_rule(self.ip_address, self.protocol)
rule = self.find_rule(self._ip_address, self.protocol)
if rule:
self.chain.delete_rule(rule)
@ -92,6 +93,9 @@ class Client(object):
Takes an ipaddress.IPv4Interface object as ip_address argument.
"""
if not isinstance(ip_address, ipaddress.IPv4Interface):
raise ValueError('Invalid argument type')
for rule in self.chain.rules:
src_ip = rule.src
@ -108,7 +112,7 @@ class Client(object):
def commit_rule(self):
rule = self.find_rule(self.ip_address, self.protocol)
rule = self.find_rule(self._ip_address, self.protocol)
if not rule:
rule = iptc.Rule()
rule.src = self.ip_address
@ -116,3 +120,18 @@ class Client(object):
rule.target = iptc.Target(rule, 'RETURN')
self.chain.insert_rule(rule)
@property
def ip_address(self):
return str(self._ip_address.ip)
@ip_address.setter
def ip_address(self, value):
if isinstance(value, str):
self._ip_address = ipaddress.IPv4Interface(value)
elif isinstance(value, ipaddress.IPv4Interface):
self._ip_address = value
else:
raise ValueError('Cannot set invalid value')

View file

@ -42,9 +42,13 @@ class StoragePostgres(object):
def get_client(self, ip_address, protocol):
"""
Expects an ipaddress.IPv4Interface as ip_address argument.
"""
self.cur.execute(
'select * from client where ip_address=%s and protocol=%s',
(Inet(ip_address), protocol, )
(ip_address, protocol, )
)
return self.cur.fetchone()
@ -63,7 +67,7 @@ class StoragePostgres(object):
(
client.client_id,
client.created,
client.ip_address,
client._ip_address,
client.protocol,
client.enabled,
client.last_packets,