Add methods for account and address creation; Retrieve labels

This commit is contained in:
Michał Sałaban 2018-01-11 23:17:34 +01:00
parent 088b7aac9e
commit a1849cab8d
6 changed files with 39 additions and 8 deletions

View file

@ -24,6 +24,9 @@ class Account(object):
def get_addresses(self):
return self._backend.get_addresses(account=self.index)
def new_address(self, label=None):
return self._backend.new_address(account=self.index, label=label)
def get_payments(self, payment_id=None):
return self._backend.get_payments(account=self.index, payment_id=payment_id)

View file

@ -11,15 +11,17 @@ _IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqr
class Address(object):
label = None
_valid_netbytes = (18, 53)
# NOTE: _valid_netbytes order is (real, testnet)
def __init__(self, address):
def __init__(self, address, label=None):
address = str(address)
if not _ADDR_REGEX.match(address):
raise ValueError("Address must be 95 characters long base58-encoded string, "
"is {addr} ({len} chars length)".format(addr=address, len=len(address)))
self._decode(address)
self.label = label or self.label
def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address)))
@ -87,14 +89,14 @@ class IntegratedAddress(Address):
return Address(base58.encode(hexlify(data + checksum)))
def address(addr):
def address(addr, label=None):
addr = str(addr)
if _ADDR_REGEX.match(addr):
netbyte = bytearray(unhexlify(base58.decode(addr)))[0]
if netbyte in Address._valid_netbytes:
return Address(addr)
return Address(addr, label=label)
elif netbyte in SubAddress._valid_netbytes:
return SubAddress(addr)
return SubAddress(addr, label=label)
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=hexlify(chr(netbyte)),
allowed=", ".join(map(

View file

@ -7,7 +7,7 @@ import requests
from .. import exceptions
from ..account import Account
from ..address import address, Address
from ..address import address, Address, SubAddress
from ..numbers import from_atomic, to_atomic, PaymentID
from ..transaction import Transaction, Payment, Transfer
@ -46,6 +46,10 @@ class JSONRPCWallet(object):
idx += 1
return accounts
def new_account(self, label=None):
_account = self.raw_request('create_account', {'label': label})
return Account(self, _account['account_index']), SubAddress(_account['address'])
def get_addresses(self, account=0):
_addresses = self.raw_request('getaddress', {'account_index': account})
if 'addresses' not in _addresses:
@ -54,9 +58,16 @@ class JSONRPCWallet(object):
return [Address(_addresses['address'])]
addresses = [None] * (max(map(operator.itemgetter('address_index'), _addresses['addresses'])) + 1)
for _addr in _addresses['addresses']:
addresses[_addr['address_index']] = address(_addr['address'])
addresses[_addr['address_index']] = address(
_addr['address'],
label=_addr.get('label', None))
return addresses
def new_address(self, account=0, label=None):
_address = self.raw_request(
'create_address', {'account_index': account, 'label': label})
return SubAddress(_address['address'])
def get_balances(self, account=0):
_balance = self.raw_request('getbalance', {'account_index': account})
return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance']))

View file

@ -21,6 +21,12 @@ class Wallet(object):
self.accounts.append(_acc)
idx += 1
def new_account(self, label=None):
acc, addr = self._backend.new_account(label=label)
assert acc.index == len(self.accounts)
self.accounts.append(acc)
return acc
# Following methods operate on default account (index=0)
def get_balances(self):
return self.accounts[0].get_balances()
@ -31,6 +37,9 @@ class Wallet(object):
def get_address(self, index=0):
return self.accounts[0].get_addresses()[0]
def new_address(self, label=None):
return self.accounts[0].new_address(label=label)
def get_payments(self, payment_id=None):
return self.accounts[0].get_payments(payment_id=payment_id)

View file

@ -116,6 +116,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual(
waddr,
'9vgV48wWAPTWik5QSUSoGYicdvvsbSNHrT9Arsx1XBTz6VrWPSgfmnUKSPZDMyX4Ms8R9TkhB4uFqK9s5LUBbV6YQN2Q9ag')
self.assertEqual(a0addr.label, 'Primary account')
self.assertEqual(len(self.wallet.accounts[0].get_addresses()), 8)
@patch('monero.backends.jsonrpc.requests.post')

View file

@ -43,11 +43,16 @@ def tx2str(tx):
payment_id=tx.payment_id,
addr=getattr(tx, 'local_address', None) or '')
def a2str(a):
return "{addr} {label}".format(
addr=a,
label=a.label or "")
w = get_wallet()
print(
"Master address: {addr}\n" \
"Balance: {total:16.12f} ({unlocked:16.12f} unlocked)".format(
addr=w.get_address(),
addr=a2str(w.get_address()),
total=w.get_balance(),
unlocked=w.get_balance(unlocked=True)))
@ -60,7 +65,7 @@ if len(w.accounts) > 1:
unlocked=acc.get_balance(unlocked=True)))
addresses = acc.get_addresses()
print("{num:2d} address(es):".format(num=len(addresses)))
print("\n".join(map(str, addresses)))
print("\n".join(map(a2str, addresses)))
ins = acc.get_transactions_in()
if ins:
print("\nIncoming transactions:")