From a184c3bea23a60a7993c31fee8205ef1a282df86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sa=C5=82aban?= Date: Mon, 7 Oct 2019 15:10:54 +0200 Subject: [PATCH] Make Address accept bytes/str too --- monero/address.py | 6 +++--- tests/test_address.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/monero/address.py b/monero/address.py index 51521aa..e277b1a 100644 --- a/monero/address.py +++ b/monero/address.py @@ -20,7 +20,7 @@ class BaseAddress(object): label = None def __init__(self, addr, label=None): - addr = str(addr) + addr = addr.decode() if isinstance(addr, bytes) else str(addr) if not _ADDR_REGEX.match(addr): raise ValueError("Address must be 95 characters long base58-encoded string, " "is {addr} ({len} chars length)".format(addr=addr, len=len(addr))) @@ -155,7 +155,7 @@ class IntegratedAddress(Address): # NOTE: _valid_netbytes order is (mainnet, testnet, stagenet) def __init__(self, address): - address = str(address) + address = address.decode() if isinstance(address, bytes) else str(address) if not _IADDR_REGEX.match(address): raise ValueError("Integrated address must be 106 characters long base58-encoded string, " "is {addr} ({len} chars length)".format(addr=address, len=len(address))) @@ -186,7 +186,7 @@ def address(addr, label=None): :rtype: :class:`Address`, :class:`SubAddress` or :class:`IntegratedAddress` """ - addr = str(addr) + addr = addr.decode() if isinstance(addr, bytes) else str(addr) if _ADDR_REGEX.match(addr): netbyte = bytearray(unhexlify(base58.decode(addr)))[0] if netbyte in Address._valid_netbytes: diff --git a/tests/test_address.py b/tests/test_address.py index 12d0cc3..75ae76f 100644 --- a/tests/test_address.py +++ b/tests/test_address.py @@ -18,6 +18,10 @@ class Tests(object): self.assertEqual(a.spend_key(), self.psk) self.assertEqual(a.view_key(), self.pvk) self.assertEqual(hash(a), hash(self.addr)) + ba = Address(self.addr.encode()) + self.assertEqual(ba, a) + ba = address(self.addr.encode()) + self.assertEqual(ba, a) ia = IntegratedAddress(self.iaddr) self.assertEqual(ia.payment_id(), self.pid) @@ -26,10 +30,18 @@ class Tests(object): self.assertEqual(ia.spend_key(), self.psk) self.assertEqual(ia.view_key(), self.pvk) self.assertEqual(ia.base_address(), a) + ba = IntegratedAddress(self.iaddr.encode()) + self.assertEqual(ba, ia) + ba = address(self.iaddr.encode()) + self.assertEqual(ba, ia) sa = SubAddress(self.subaddr) self.assertEqual(str(sa), self.subaddr) self.assertEqual("{:s}".format(sa), self.subaddr) + ba = SubAddress(self.subaddr.encode()) + self.assertEqual(ba, sa) + ba = address(self.subaddr.encode()) + self.assertEqual(ba, sa) def test_payment_id(self): a = Address(self.addr)