Add methods for account and address creation; Retrieve labels

This commit is contained in:
Michał Sałaban 2018-01-11 23:17:34 +01:00
parent 088b7aac9e
commit a1849cab8d
6 changed files with 39 additions and 8 deletions

View File

@ -24,6 +24,9 @@ class Account(object):
def get_addresses(self): def get_addresses(self):
return self._backend.get_addresses(account=self.index) return self._backend.get_addresses(account=self.index)
def new_address(self, label=None):
return self._backend.new_address(account=self.index, label=label)
def get_payments(self, payment_id=None): def get_payments(self, payment_id=None):
return self._backend.get_payments(account=self.index, payment_id=payment_id) return self._backend.get_payments(account=self.index, payment_id=payment_id)

View File

@ -11,15 +11,17 @@ _IADDR_REGEX = re.compile(r'^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqr
class Address(object): class Address(object):
label = None
_valid_netbytes = (18, 53) _valid_netbytes = (18, 53)
# NOTE: _valid_netbytes order is (real, testnet) # NOTE: _valid_netbytes order is (real, testnet)
def __init__(self, address): def __init__(self, address, label=None):
address = str(address) address = str(address)
if not _ADDR_REGEX.match(address): if not _ADDR_REGEX.match(address):
raise ValueError("Address must be 95 characters long base58-encoded string, " raise ValueError("Address must be 95 characters long base58-encoded string, "
"is {addr} ({len} chars length)".format(addr=address, len=len(address))) "is {addr} ({len} chars length)".format(addr=address, len=len(address)))
self._decode(address) self._decode(address)
self.label = label or self.label
def _decode(self, address): def _decode(self, address):
self._decoded = bytearray(unhexlify(base58.decode(address))) self._decoded = bytearray(unhexlify(base58.decode(address)))
@ -87,14 +89,14 @@ class IntegratedAddress(Address):
return Address(base58.encode(hexlify(data + checksum))) return Address(base58.encode(hexlify(data + checksum)))
def address(addr): def address(addr, label=None):
addr = str(addr) addr = str(addr)
if _ADDR_REGEX.match(addr): if _ADDR_REGEX.match(addr):
netbyte = bytearray(unhexlify(base58.decode(addr)))[0] netbyte = bytearray(unhexlify(base58.decode(addr)))[0]
if netbyte in Address._valid_netbytes: if netbyte in Address._valid_netbytes:
return Address(addr) return Address(addr, label=label)
elif netbyte in SubAddress._valid_netbytes: elif netbyte in SubAddress._valid_netbytes:
return SubAddress(addr) return SubAddress(addr, label=label)
raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format( raise ValueError("Invalid address netbyte {nb}. Allowed values are: {allowed}".format(
nb=hexlify(chr(netbyte)), nb=hexlify(chr(netbyte)),
allowed=", ".join(map( allowed=", ".join(map(

View File

@ -7,7 +7,7 @@ import requests
from .. import exceptions from .. import exceptions
from ..account import Account from ..account import Account
from ..address import address, Address from ..address import address, Address, SubAddress
from ..numbers import from_atomic, to_atomic, PaymentID from ..numbers import from_atomic, to_atomic, PaymentID
from ..transaction import Transaction, Payment, Transfer from ..transaction import Transaction, Payment, Transfer
@ -46,6 +46,10 @@ class JSONRPCWallet(object):
idx += 1 idx += 1
return accounts return accounts
def new_account(self, label=None):
_account = self.raw_request('create_account', {'label': label})
return Account(self, _account['account_index']), SubAddress(_account['address'])
def get_addresses(self, account=0): def get_addresses(self, account=0):
_addresses = self.raw_request('getaddress', {'account_index': account}) _addresses = self.raw_request('getaddress', {'account_index': account})
if 'addresses' not in _addresses: if 'addresses' not in _addresses:
@ -54,9 +58,16 @@ class JSONRPCWallet(object):
return [Address(_addresses['address'])] return [Address(_addresses['address'])]
addresses = [None] * (max(map(operator.itemgetter('address_index'), _addresses['addresses'])) + 1) addresses = [None] * (max(map(operator.itemgetter('address_index'), _addresses['addresses'])) + 1)
for _addr in _addresses['addresses']: for _addr in _addresses['addresses']:
addresses[_addr['address_index']] = address(_addr['address']) addresses[_addr['address_index']] = address(
_addr['address'],
label=_addr.get('label', None))
return addresses return addresses
def new_address(self, account=0, label=None):
_address = self.raw_request(
'create_address', {'account_index': account, 'label': label})
return SubAddress(_address['address'])
def get_balances(self, account=0): def get_balances(self, account=0):
_balance = self.raw_request('getbalance', {'account_index': account}) _balance = self.raw_request('getbalance', {'account_index': account})
return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance'])) return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance']))

View File

@ -21,6 +21,12 @@ class Wallet(object):
self.accounts.append(_acc) self.accounts.append(_acc)
idx += 1 idx += 1
def new_account(self, label=None):
acc, addr = self._backend.new_account(label=label)
assert acc.index == len(self.accounts)
self.accounts.append(acc)
return acc
# Following methods operate on default account (index=0) # Following methods operate on default account (index=0)
def get_balances(self): def get_balances(self):
return self.accounts[0].get_balances() return self.accounts[0].get_balances()
@ -31,6 +37,9 @@ class Wallet(object):
def get_address(self, index=0): def get_address(self, index=0):
return self.accounts[0].get_addresses()[0] return self.accounts[0].get_addresses()[0]
def new_address(self, label=None):
return self.accounts[0].new_address(label=label)
def get_payments(self, payment_id=None): def get_payments(self, payment_id=None):
return self.accounts[0].get_payments(payment_id=payment_id) return self.accounts[0].get_payments(payment_id=payment_id)

View File

@ -116,6 +116,7 @@ class SubaddrWalletTestCase(unittest.TestCase):
self.assertEqual( self.assertEqual(
waddr, waddr,
'9vgV48wWAPTWik5QSUSoGYicdvvsbSNHrT9Arsx1XBTz6VrWPSgfmnUKSPZDMyX4Ms8R9TkhB4uFqK9s5LUBbV6YQN2Q9ag') '9vgV48wWAPTWik5QSUSoGYicdvvsbSNHrT9Arsx1XBTz6VrWPSgfmnUKSPZDMyX4Ms8R9TkhB4uFqK9s5LUBbV6YQN2Q9ag')
self.assertEqual(a0addr.label, 'Primary account')
self.assertEqual(len(self.wallet.accounts[0].get_addresses()), 8) self.assertEqual(len(self.wallet.accounts[0].get_addresses()), 8)
@patch('monero.backends.jsonrpc.requests.post') @patch('monero.backends.jsonrpc.requests.post')

View File

@ -43,11 +43,16 @@ def tx2str(tx):
payment_id=tx.payment_id, payment_id=tx.payment_id,
addr=getattr(tx, 'local_address', None) or '') addr=getattr(tx, 'local_address', None) or '')
def a2str(a):
return "{addr} {label}".format(
addr=a,
label=a.label or "")
w = get_wallet() w = get_wallet()
print( print(
"Master address: {addr}\n" \ "Master address: {addr}\n" \
"Balance: {total:16.12f} ({unlocked:16.12f} unlocked)".format( "Balance: {total:16.12f} ({unlocked:16.12f} unlocked)".format(
addr=w.get_address(), addr=a2str(w.get_address()),
total=w.get_balance(), total=w.get_balance(),
unlocked=w.get_balance(unlocked=True))) unlocked=w.get_balance(unlocked=True)))
@ -60,7 +65,7 @@ if len(w.accounts) > 1:
unlocked=acc.get_balance(unlocked=True))) unlocked=acc.get_balance(unlocked=True)))
addresses = acc.get_addresses() addresses = acc.get_addresses()
print("{num:2d} address(es):".format(num=len(addresses))) print("{num:2d} address(es):".format(num=len(addresses)))
print("\n".join(map(str, addresses))) print("\n".join(map(a2str, addresses)))
ins = acc.get_transactions_in() ins = acc.get_transactions_in()
if ins: if ins:
print("\nIncoming transactions:") print("\nIncoming transactions:")