From 666aed038f40590f35579959f7d00fbc4ba380b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sa=C5=82aban?= Date: Sat, 6 Jan 2018 23:12:42 +0100 Subject: [PATCH] Refactor Payment ID as separate type --- monero/account.py | 6 ++++-- monero/address.py | 11 +++++----- monero/backends/jsonrpc.py | 15 ++++++------- monero/numbers.py | 44 ++++++++++++++++++++++++++++++++------ tests/numbers.py | 24 ++++++++++++++++++--- utils/walletdump.py | 4 ++-- 6 files changed, 75 insertions(+), 29 deletions(-) diff --git a/monero/account.py b/monero/account.py index d48d7c5..0f3ffdd 100644 --- a/monero/account.py +++ b/monero/account.py @@ -33,15 +33,16 @@ class Account(object): def get_transactions_out(self): return self._backend.get_transactions_out(account=self.index) - def transfer(self, address, amount, priority=prio.NORMAL, mixin=5, unlock_time=0): + def transfer(self, address, amount, priority=prio.NORMAL, mixin=5, payment_id=0, unlock_time=0): return self._backend.transfer( [(address, amount)], priority, mixin, + payment_id, unlock_time, account=self.index) - def transfer_multiple(self, destinations, priority=prio.NORMAL, mixin=5, unlock_time=0): + def transfer_multiple(self, destinations, priority=prio.NORMAL, mixin=5, payment_id=0, unlock_time=0): """ destinations = [(address, amount), ...] """ @@ -49,5 +50,6 @@ class Account(object): destinations, priority, mixin, + payment_id, unlock_time, account=self.index) diff --git a/monero/address.py b/monero/address.py index ec4c931..df78b2c 100644 --- a/monero/address.py +++ b/monero/address.py @@ -41,12 +41,11 @@ class Address(object): return hexlify(self._decoded[1:33]).decode() def with_payment_id(self, payment_id=0): - payment_id = numbers.payment_id_as_int(payment_id) - if payment_id.bit_length() > 64: - raise TypeError("Integrated payment_id cannot have more than 64 bits, " - "has %d" % payment_id.bit_length()) + payment_id = numbers.PaymentID(payment_id) + if not payment_id.is_short(): + raise TypeError("Integrated payment ID {0} has more than 64 bits".format(payment_id)) prefix = 54 if self.is_testnet() else 19 - data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', payment_id) + data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', int(payment_id)) checksum = bytearray(keccak_256(data).digest()[:4]) return IntegratedAddress(base58.encode(hexlify(data + checksum))) @@ -79,7 +78,7 @@ class IntegratedAddress(Address): self._decode(address) def get_payment_id(self): - return hexlify(self._decoded[65:-4]).decode() + return numbers.PaymentID(hexlify(self._decoded[65:-4]).decode()) def get_base_address(self): prefix = 53 if self.is_testnet() else 18 diff --git a/monero/backends/jsonrpc.py b/monero/backends/jsonrpc.py index 3bc4255..9c7368c 100644 --- a/monero/backends/jsonrpc.py +++ b/monero/backends/jsonrpc.py @@ -8,7 +8,7 @@ import requests from .. import exceptions from ..account import Account from ..address import address, Address -from ..numbers import from_atomic, to_atomic, payment_id_as_int +from ..numbers import from_atomic, to_atomic, PaymentID from ..transaction import Transaction, Payment, Transfer _log = logging.getLogger(__name__) @@ -62,16 +62,12 @@ class JSONRPCWallet(object): return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance'])) def get_payments(self, account=0, payment_id=0): - payment_id = payment_id_as_int(payment_id) + payment_id = PaymentID(payment_id) _log.debug("Getting payments for account {acc}, payment_id {pid}".format( acc=account, pid=payment_id)) - if payment_id.bit_length() > 64: - _pid = '{:064x}'.format(payment_id) - else: - _pid = '{:016x}'.format(payment_id) _payments = self.raw_request('get_payments', { 'account_index': account, - 'payment_id': _pid}) + 'payment_id': str(payment_id)}) pmts = [] for tx in _payments['payments']: data = self._tx2dict(tx) @@ -97,7 +93,7 @@ class JSONRPCWallet(object): 'amount': from_atomic(tx['amount']), 'fee': from_atomic(tx['fee']) if 'fee' in tx else None, 'height': tx.get('height', tx.get('block_height')), - 'payment_id': tx['payment_id'], + 'payment_id': PaymentID(tx.get('payment_id', 0)), 'note': tx.get('note'), # NOTE: address will be resolved only after PR#3010 has been merged to Monero 'local_address': address(tx['address']) if 'address' in tx else None, @@ -105,7 +101,7 @@ class JSONRPCWallet(object): 'blob': tx.get('blob', None), } - def transfer(self, destinations, priority, mixin, unlock_time, account=0): + def transfer(self, destinations, priority, mixin, payment_id, unlock_time, account=0): data = { 'account_index': account, 'destinations': list(map( @@ -114,6 +110,7 @@ class JSONRPCWallet(object): 'mixin': mixin, 'priority': priority, 'unlock_time': 0, + 'payment_id': payment_id, 'get_tx_keys': True, 'get_tx_hex': True, 'new_algorithm': True, diff --git a/monero/numbers.py b/monero/numbers.py index 6a66d4b..f7b7bc2 100644 --- a/monero/numbers.py +++ b/monero/numbers.py @@ -5,8 +5,10 @@ PICONERO = Decimal('0.000000000001') if sys.version_info < (3,): _integer_types = (int, long,) + _str_types = (str, bytes, unicode) else: _integer_types = (int,) + _str_types = (str, bytes) def to_atomic(amount): @@ -21,10 +23,38 @@ def as_monero(amount): """Return the amount rounded to maximal Monero precision.""" return Decimal(amount).quantize(PICONERO) -def payment_id_as_int(payment_id): - if isinstance(payment_id, (bytes, str)): - payment_id = int(payment_id, 16) - elif not isinstance(payment_id, _integer_types): - raise TypeError("payment_id must be either int or hexadecimal str or bytes, " - "is %r" % payment_id) - return payment_id + +class PaymentID(object): + _payment_id = None + + def __init__(self, payment_id): + if isinstance(payment_id, PaymentID): + payment_id = int(payment_id) + if isinstance(payment_id, _str_types): + payment_id = int(payment_id, 16) + elif not isinstance(payment_id, _integer_types): + raise TypeError("payment_id must be either int or hexadecimal str or bytes, " + "is %r" % payment_id) + self._payment_id = payment_id + + def is_short(self): + """Returns True if payment ID is short enough to be included + in Integrated Address.""" + return self._payment_id.bit_length() <= 64 + + def __repr__(self): + if self.is_short(): + return "{:016x}".format(self._payment_id) + return "{:064x}".format(self._payment_id) + + def __int__(self): + return self._payment_id + + def __eq__(self, other): + if isinstance(other, PaymentID): + return int(self) == int(other) + elif isinstance(other, _integer_types): + return int(self) == other + elif isinstance(other, _str_types): + return str(self) == other + return super() diff --git a/tests/numbers.py b/tests/numbers.py index 409111c..8505e60 100644 --- a/tests/numbers.py +++ b/tests/numbers.py @@ -1,7 +1,7 @@ from decimal import Decimal import unittest -from monero.numbers import to_atomic, from_atomic, payment_id_as_int +from monero.numbers import to_atomic, from_atomic, PaymentID class NumbersTestCase(unittest.TestCase): def test_simple_numbers(self): @@ -16,5 +16,23 @@ class NumbersTestCase(unittest.TestCase): self.assertEqual(to_atomic(Decimal('1.0000000000004')), 1000000000000) def test_payment_id(self): - self.assertEqual(payment_id_as_int('0'), 0) - self.assertEqual(payment_id_as_int('abcdef'), 0xabcdef) + pid = PaymentID('0') + self.assertTrue(pid.is_short()) + self.assertEqual(pid, 0) + self.assertEqual(pid, '0000000000000000') + self.assertEqual(PaymentID(pid), pid) + pid = PaymentID('abcdef') + self.assertTrue(pid.is_short()) + self.assertEqual(pid, 0xabcdef) + self.assertEqual(pid, '0000000000abcdef') + self.assertEqual(PaymentID(pid), pid) + pid = PaymentID('1234567812345678') + self.assertTrue(pid.is_short()) + self.assertEqual(pid, 0x1234567812345678) + self.assertEqual(pid, '1234567812345678') + self.assertEqual(PaymentID(pid), pid) + pid = PaymentID('a1234567812345678') + self.assertFalse(pid.is_short()) + self.assertEqual(pid, 0xa1234567812345678) + self.assertEqual(pid, '00000000000000000000000000000000000000000000000a1234567812345678') + self.assertEqual(PaymentID(pid), pid) diff --git a/utils/walletdump.py b/utils/walletdump.py index e761628..3e3cf21 100644 --- a/utils/walletdump.py +++ b/utils/walletdump.py @@ -31,10 +31,10 @@ def get_wallet(): return Wallet(JSONRPCWallet(**args.daemon_url)) _TXHDR = "timestamp height id/hash " \ - " amount fee payment_id {dir}" + " amount fee {dir:95s} payment_id" def tx2str(tx): - return "{time} {height} {hash} {amount:17.12f} {fee:13.12f} {payment_id} {addr}".format( + return "{time} {height} {hash} {amount:17.12f} {fee:13.12f} {addr} {payment_id}".format( time=tx.timestamp.strftime("%d-%m-%y %H:%M:%S") if getattr(tx, 'timestamp', None) else None, height=tx.height, hash=tx.hash,