Rearrange address class hierarchy

This commit is contained in:
Michał Sałaban 2018-12-15 03:53:19 +00:00
parent 5ee551fa4a
commit bab0099419
2 changed files with 43 additions and 38 deletions

View File

@ -9,18 +9,8 @@ from . import numbers
_ADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$')
_IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$')
class Address(object):
"""Monero address.
Address of this class is the master address for a :class:`Wallet <monero.wallet.Wallet>`.
:param address: a Monero address as string-like object
:param label: a label for the address (defaults to `None`)
"""
class BaseAddress(object):
label = None
_valid_netbytes = (18, 53, 24)
# NOTE: _valid_netbytes order is (mainnet, testnet, stagenet)
def __init__(self, addr, label=None):
addr = str(addr)
@ -30,16 +20,6 @@ class Address(object):
self._decode(addr)
self.label = label or self.label
def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address)))
checksum = self._decoded[-4:]
if checksum != keccak_256(self._decoded[:-4]).digest()[:4]:
raise ValueError("Invalid checksum in address {}".format(address))
if self._decoded[0] not in self._valid_netbytes:
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=self._decoded[0],
allowed=", ".join(map(lambda b: '%02x' % b, self._valid_netbytes))))
def is_mainnet(self):
"""Returns `True` if the address belongs to mainnet.
@ -61,6 +41,41 @@ class Address(object):
"""
return self._decoded[0] == self._valid_netbytes[2]
def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address)))
checksum = self._decoded[-4:]
if checksum != keccak_256(self._decoded[:-4]).digest()[:4]:
raise ValueError("Invalid checksum in address {}".format(address))
if self._decoded[0] not in self._valid_netbytes:
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=self._decoded[0],
allowed=", ".join(map(lambda b: '%02x' % b, self._valid_netbytes))))
def __repr__(self):
return base58.encode(hexlify(self._decoded))
def __eq__(self, other):
if isinstance(other, BaseAddress):
return str(self) == str(other)
if isinstance(other, str):
return str(self) == other
return super(BaseAddress, self).__eq__(other)
def __hash__(self):
return hash(str(self))
class Address(BaseAddress):
"""Monero address.
Address of this class is the master address for a :class:`Wallet <monero.wallet.Wallet>`.
:param address: a Monero address as string-like object
:param label: a label for the address (defaults to `None`)
"""
_valid_netbytes = (18, 53, 24)
# NOTE: _valid_netbytes order is (mainnet, testnet, stagenet)
def view_key(self):
"""Returns public view key.
@ -92,18 +107,8 @@ class Address(object):
checksum = bytearray(keccak_256(data).digest()[:4])
return IntegratedAddress(base58.encode(hexlify(data + checksum)))
def __repr__(self):
return base58.encode(hexlify(self._decoded))
def __eq__(self, other):
if isinstance(other, Address):
return str(self) == str(other)
if isinstance(other, str):
return str(self) == other
return super(Address, self).__eq__(other)
class SubAddress(Address):
class SubAddress(BaseAddress):
"""Monero subaddress.
Any type of address which is not the master one for a wallet.

View File

@ -8,7 +8,7 @@ except ImportError:
import warnings
from monero.wallet import Wallet
from monero.address import Address
from monero.address import BaseAddress, Address
from monero.seed import Seed
from monero.transaction import IncomingPayment, OutgoingPayment, Transaction
from monero.backends.jsonrpc import JSONRPCWallet
@ -258,7 +258,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 9)
for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address)
self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -409,7 +409,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 11)
for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address)
self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -452,7 +452,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 2)
for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address)
self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -493,7 +493,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 3)
for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address)
self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction)
# Fee is not returned by this RPC method!
@ -537,7 +537,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 4)
for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address)
self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction)
# Fee is not returned by this RPC method!