From 3131edf1a458ac7cfae154699f5c78157cff79d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sa=C5=82aban?= Date: Sun, 10 Dec 2017 02:54:09 +0100 Subject: [PATCH] Add more sanity tests on addresses --- monero/address.py | 38 ++++++++++++++++++++++++++++---------- tests/address.py | 3 +++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/monero/address.py b/monero/address.py index 68894d7..b04a9a3 100644 --- a/monero/address.py +++ b/monero/address.py @@ -1,17 +1,29 @@ from binascii import hexlify, unhexlify +import re import struct +import sys from sha3 import keccak_256 from . import base58 +_ADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$') +_IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$') + +if sys.version_info < (3,): + _integer_types = (int, long,) +else: + _integer_types = (int,) + + class Address(object): _valid_netbytes = (18, 53) # NOTE: _valid_netbytes order is (real, testnet) def __init__(self, address): address = str(address) - if len(address) != 95: - raise ValueError("Address must be 95 characters long, is %d" % len(address)) + if not _ADDR_REGEX.match(address): + raise ValueError("Address must be 95 characters long base58-encoded string, " + "is {addr} ({len} chars length)".format(addr=address, len=len(address))) self._decode(address) def _decode(self, address): @@ -36,8 +48,12 @@ class Address(object): def with_payment_id(self, payment_id=0): if isinstance(payment_id, (bytes, str)): payment_id = int(payment_id, 16) - elif not isinstance(payment_id, int): - raise TypeError("payment_id must be either int or hexadecimal str or bytes") + elif not isinstance(payment_id, _integer_types): + raise TypeError("payment_id must be either int or hexadecimal str or bytes, " + "is %r" % payment_id) + if payment_id.bit_length() > 64: + raise TypeError("Integrated payment_id cannot have more than 64 bits, " + "has %d" % payment_id.bit_length()) prefix = 54 if self.is_testnet() else 19 data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', payment_id) checksum = bytearray(keccak_256(data).digest()[:4]) @@ -58,7 +74,7 @@ class SubAddress(Address): _valid_netbytes = (42, 63) def with_payment_id(self): - raise TypeError("SubAddress cannot be merged with payment ID into IntegratedAddress") + raise TypeError("SubAddress cannot be integrated with payment ID") class IntegratedAddress(Address): @@ -66,8 +82,9 @@ class IntegratedAddress(Address): def __init__(self, address): address = str(address) - if len(address) != 106: - raise ValueError("Integrated address must be 106 characters long, is %d" % len(address)) + if not _IADDR_REGEX.match(address): + raise ValueError("Integrated address must be 106 characters long base58-encoded string, " + "is {addr} ({len} chars length)".format(addr=address, len=len(address))) self._decode(address) def get_payment_id(self): @@ -82,7 +99,7 @@ class IntegratedAddress(Address): def address(addr): addr = str(addr) - if len(addr) == 95: + if _ADDR_REGEX.match(addr): netbyte = bytearray(unhexlify(base58.decode(addr)))[0] if netbyte in Address._valid_netbytes: return Address(addr) @@ -93,6 +110,7 @@ def address(addr): allowed=", ".join(map( lambda b: '%02x' % b, sorted(Address._valid_netbytes + SubAddress._valid_netbytes))))) - elif len(addr) == 106: + elif _IADDR_REGEX.match(addr): return IntegratedAddress(addr) - raise ValueError("Address must be either 95 or 106 characters long") + raise ValueError("Address must be either 95 or 106 characters long base58-encoded string, " + "is {addr} ({len} chars length)".format(addr=address, len=len(address))) diff --git a/tests/address.py b/tests/address.py index 4103685..4604b30 100644 --- a/tests/address.py +++ b/tests/address.py @@ -75,6 +75,9 @@ class Tests(object): def test_invalid(self): self.assertRaises(ValueError, Address, self.addr_invalid) self.assertRaises(ValueError, Address, self.iaddr_invalid) + a = Address(self.addr) + self.assertRaises(TypeError, a.with_payment_id, 2**64+1) + self.assertRaises(TypeError, a.with_payment_id, "%x" % (2**64+1)) def test_type_mismatch(self): self.assertRaises(ValueError, Address, self.iaddr)