Add more sanity tests on addresses
This commit is contained in:
parent
864b8f237a
commit
3131edf1a4
|
@ -1,17 +1,29 @@
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
|
import re
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
from sha3 import keccak_256
|
from sha3 import keccak_256
|
||||||
|
|
||||||
from . import base58
|
from . import base58
|
||||||
|
|
||||||
|
_ADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$')
|
||||||
|
_IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$')
|
||||||
|
|
||||||
|
if sys.version_info < (3,):
|
||||||
|
_integer_types = (int, long,)
|
||||||
|
else:
|
||||||
|
_integer_types = (int,)
|
||||||
|
|
||||||
|
|
||||||
class Address(object):
|
class Address(object):
|
||||||
_valid_netbytes = (18, 53)
|
_valid_netbytes = (18, 53)
|
||||||
# NOTE: _valid_netbytes order is (real, testnet)
|
# NOTE: _valid_netbytes order is (real, testnet)
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
address = str(address)
|
address = str(address)
|
||||||
if len(address) != 95:
|
if not _ADDR_REGEX.match(address):
|
||||||
raise ValueError("Address must be 95 characters long, is %d" % len(address))
|
raise ValueError("Address must be 95 characters long base58-encoded string, "
|
||||||
|
"is {addr} ({len} chars length)".format(addr=address, len=len(address)))
|
||||||
self._decode(address)
|
self._decode(address)
|
||||||
|
|
||||||
def _decode(self, address):
|
def _decode(self, address):
|
||||||
|
@ -36,8 +48,12 @@ class Address(object):
|
||||||
def with_payment_id(self, payment_id=0):
|
def with_payment_id(self, payment_id=0):
|
||||||
if isinstance(payment_id, (bytes, str)):
|
if isinstance(payment_id, (bytes, str)):
|
||||||
payment_id = int(payment_id, 16)
|
payment_id = int(payment_id, 16)
|
||||||
elif not isinstance(payment_id, int):
|
elif not isinstance(payment_id, _integer_types):
|
||||||
raise TypeError("payment_id must be either int or hexadecimal str or bytes")
|
raise TypeError("payment_id must be either int or hexadecimal str or bytes, "
|
||||||
|
"is %r" % payment_id)
|
||||||
|
if payment_id.bit_length() > 64:
|
||||||
|
raise TypeError("Integrated payment_id cannot have more than 64 bits, "
|
||||||
|
"has %d" % payment_id.bit_length())
|
||||||
prefix = 54 if self.is_testnet() else 19
|
prefix = 54 if self.is_testnet() else 19
|
||||||
data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', payment_id)
|
data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', payment_id)
|
||||||
checksum = bytearray(keccak_256(data).digest()[:4])
|
checksum = bytearray(keccak_256(data).digest()[:4])
|
||||||
|
@ -58,7 +74,7 @@ class SubAddress(Address):
|
||||||
_valid_netbytes = (42, 63)
|
_valid_netbytes = (42, 63)
|
||||||
|
|
||||||
def with_payment_id(self):
|
def with_payment_id(self):
|
||||||
raise TypeError("SubAddress cannot be merged with payment ID into IntegratedAddress")
|
raise TypeError("SubAddress cannot be integrated with payment ID")
|
||||||
|
|
||||||
|
|
||||||
class IntegratedAddress(Address):
|
class IntegratedAddress(Address):
|
||||||
|
@ -66,8 +82,9 @@ class IntegratedAddress(Address):
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
address = str(address)
|
address = str(address)
|
||||||
if len(address) != 106:
|
if not _IADDR_REGEX.match(address):
|
||||||
raise ValueError("Integrated address must be 106 characters long, is %d" % len(address))
|
raise ValueError("Integrated address must be 106 characters long base58-encoded string, "
|
||||||
|
"is {addr} ({len} chars length)".format(addr=address, len=len(address)))
|
||||||
self._decode(address)
|
self._decode(address)
|
||||||
|
|
||||||
def get_payment_id(self):
|
def get_payment_id(self):
|
||||||
|
@ -82,7 +99,7 @@ class IntegratedAddress(Address):
|
||||||
|
|
||||||
def address(addr):
|
def address(addr):
|
||||||
addr = str(addr)
|
addr = str(addr)
|
||||||
if len(addr) == 95:
|
if _ADDR_REGEX.match(addr):
|
||||||
netbyte = bytearray(unhexlify(base58.decode(addr)))[0]
|
netbyte = bytearray(unhexlify(base58.decode(addr)))[0]
|
||||||
if netbyte in Address._valid_netbytes:
|
if netbyte in Address._valid_netbytes:
|
||||||
return Address(addr)
|
return Address(addr)
|
||||||
|
@ -93,6 +110,7 @@ def address(addr):
|
||||||
allowed=", ".join(map(
|
allowed=", ".join(map(
|
||||||
lambda b: '%02x' % b,
|
lambda b: '%02x' % b,
|
||||||
sorted(Address._valid_netbytes + SubAddress._valid_netbytes)))))
|
sorted(Address._valid_netbytes + SubAddress._valid_netbytes)))))
|
||||||
elif len(addr) == 106:
|
elif _IADDR_REGEX.match(addr):
|
||||||
return IntegratedAddress(addr)
|
return IntegratedAddress(addr)
|
||||||
raise ValueError("Address must be either 95 or 106 characters long")
|
raise ValueError("Address must be either 95 or 106 characters long base58-encoded string, "
|
||||||
|
"is {addr} ({len} chars length)".format(addr=address, len=len(address)))
|
||||||
|
|
|
@ -75,6 +75,9 @@ class Tests(object):
|
||||||
def test_invalid(self):
|
def test_invalid(self):
|
||||||
self.assertRaises(ValueError, Address, self.addr_invalid)
|
self.assertRaises(ValueError, Address, self.addr_invalid)
|
||||||
self.assertRaises(ValueError, Address, self.iaddr_invalid)
|
self.assertRaises(ValueError, Address, self.iaddr_invalid)
|
||||||
|
a = Address(self.addr)
|
||||||
|
self.assertRaises(TypeError, a.with_payment_id, 2**64+1)
|
||||||
|
self.assertRaises(TypeError, a.with_payment_id, "%x" % (2**64+1))
|
||||||
|
|
||||||
def test_type_mismatch(self):
|
def test_type_mismatch(self):
|
||||||
self.assertRaises(ValueError, Address, self.iaddr)
|
self.assertRaises(ValueError, Address, self.iaddr)
|
||||||
|
|
Loading…
Reference in New Issue