Rearrange address class hierarchy

This commit is contained in:
Michał Sałaban 2018-12-15 03:53:19 +00:00
parent 5ee551fa4a
commit bab0099419
2 changed files with 43 additions and 38 deletions

View file

@ -9,18 +9,8 @@ from . import numbers
_ADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$') _ADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$')
_IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$') _IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$')
class BaseAddress(object):
class Address(object):
"""Monero address.
Address of this class is the master address for a :class:`Wallet <monero.wallet.Wallet>`.
:param address: a Monero address as string-like object
:param label: a label for the address (defaults to `None`)
"""
label = None label = None
_valid_netbytes = (18, 53, 24)
# NOTE: _valid_netbytes order is (mainnet, testnet, stagenet)
def __init__(self, addr, label=None): def __init__(self, addr, label=None):
addr = str(addr) addr = str(addr)
@ -30,16 +20,6 @@ class Address(object):
self._decode(addr) self._decode(addr)
self.label = label or self.label self.label = label or self.label
def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address)))
checksum = self._decoded[-4:]
if checksum != keccak_256(self._decoded[:-4]).digest()[:4]:
raise ValueError("Invalid checksum in address {}".format(address))
if self._decoded[0] not in self._valid_netbytes:
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=self._decoded[0],
allowed=", ".join(map(lambda b: '%02x' % b, self._valid_netbytes))))
def is_mainnet(self): def is_mainnet(self):
"""Returns `True` if the address belongs to mainnet. """Returns `True` if the address belongs to mainnet.
@ -61,6 +41,41 @@ class Address(object):
""" """
return self._decoded[0] == self._valid_netbytes[2] return self._decoded[0] == self._valid_netbytes[2]
def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address)))
checksum = self._decoded[-4:]
if checksum != keccak_256(self._decoded[:-4]).digest()[:4]:
raise ValueError("Invalid checksum in address {}".format(address))
if self._decoded[0] not in self._valid_netbytes:
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=self._decoded[0],
allowed=", ".join(map(lambda b: '%02x' % b, self._valid_netbytes))))
def __repr__(self):
return base58.encode(hexlify(self._decoded))
def __eq__(self, other):
if isinstance(other, BaseAddress):
return str(self) == str(other)
if isinstance(other, str):
return str(self) == other
return super(BaseAddress, self).__eq__(other)
def __hash__(self):
return hash(str(self))
class Address(BaseAddress):
"""Monero address.
Address of this class is the master address for a :class:`Wallet <monero.wallet.Wallet>`.
:param address: a Monero address as string-like object
:param label: a label for the address (defaults to `None`)
"""
_valid_netbytes = (18, 53, 24)
# NOTE: _valid_netbytes order is (mainnet, testnet, stagenet)
def view_key(self): def view_key(self):
"""Returns public view key. """Returns public view key.
@ -92,18 +107,8 @@ class Address(object):
checksum = bytearray(keccak_256(data).digest()[:4]) checksum = bytearray(keccak_256(data).digest()[:4])
return IntegratedAddress(base58.encode(hexlify(data + checksum))) return IntegratedAddress(base58.encode(hexlify(data + checksum)))
def __repr__(self):
return base58.encode(hexlify(self._decoded))
def __eq__(self, other): class SubAddress(BaseAddress):
if isinstance(other, Address):
return str(self) == str(other)
if isinstance(other, str):
return str(self) == other
return super(Address, self).__eq__(other)
class SubAddress(Address):
"""Monero subaddress. """Monero subaddress.
Any type of address which is not the master one for a wallet. Any type of address which is not the master one for a wallet.

View file

@ -8,7 +8,7 @@ except ImportError:
import warnings import warnings
from monero.wallet import Wallet from monero.wallet import Wallet
from monero.address import Address from monero.address import BaseAddress, Address
from monero.seed import Seed from monero.seed import Seed
from monero.transaction import IncomingPayment, OutgoingPayment, Transaction from monero.transaction import IncomingPayment, OutgoingPayment, Transaction
from monero.backends.jsonrpc import JSONRPCWallet from monero.backends.jsonrpc import JSONRPCWallet
@ -258,7 +258,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 9) self.assertEqual(len(list(pay_in)), 9)
for pmt in pay_in: for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment) self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address) self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal) self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction) self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal) self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -409,7 +409,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 11) self.assertEqual(len(list(pay_in)), 11)
for pmt in pay_in: for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment) self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address) self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal) self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction) self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal) self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -452,7 +452,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 2) self.assertEqual(len(list(pay_in)), 2)
for pmt in pay_in: for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment) self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address) self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal) self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction) self.assertIsInstance(pmt.transaction, Transaction)
self.assertIsInstance(pmt.transaction.fee, Decimal) self.assertIsInstance(pmt.transaction.fee, Decimal)
@ -493,7 +493,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 3) self.assertEqual(len(list(pay_in)), 3)
for pmt in pay_in: for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment) self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address) self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal) self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction) self.assertIsInstance(pmt.transaction, Transaction)
# Fee is not returned by this RPC method! # Fee is not returned by this RPC method!
@ -537,7 +537,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(len(list(pay_in)), 4) self.assertEqual(len(list(pay_in)), 4)
for pmt in pay_in: for pmt in pay_in:
self.assertIsInstance(pmt, IncomingPayment) self.assertIsInstance(pmt, IncomingPayment)
self.assertIsInstance(pmt.local_address, Address) self.assertIsInstance(pmt.local_address, BaseAddress)
self.assertIsInstance(pmt.amount, Decimal) self.assertIsInstance(pmt.amount, Decimal)
self.assertIsInstance(pmt.transaction, Transaction) self.assertIsInstance(pmt.transaction, Transaction)
# Fee is not returned by this RPC method! # Fee is not returned by this RPC method!