Refactor Payment ID as separate type

This commit is contained in:
Michał Sałaban 2018-01-06 23:12:42 +01:00
parent 1a4e41df4d
commit 666aed038f
6 changed files with 75 additions and 29 deletions

View File

@ -33,15 +33,16 @@ class Account(object):
def get_transactions_out(self): def get_transactions_out(self):
return self._backend.get_transactions_out(account=self.index) return self._backend.get_transactions_out(account=self.index)
def transfer(self, address, amount, priority=prio.NORMAL, mixin=5, unlock_time=0): def transfer(self, address, amount, priority=prio.NORMAL, mixin=5, payment_id=0, unlock_time=0):
return self._backend.transfer( return self._backend.transfer(
[(address, amount)], [(address, amount)],
priority, priority,
mixin, mixin,
payment_id,
unlock_time, unlock_time,
account=self.index) account=self.index)
def transfer_multiple(self, destinations, priority=prio.NORMAL, mixin=5, unlock_time=0): def transfer_multiple(self, destinations, priority=prio.NORMAL, mixin=5, payment_id=0, unlock_time=0):
""" """
destinations = [(address, amount), ...] destinations = [(address, amount), ...]
""" """
@ -49,5 +50,6 @@ class Account(object):
destinations, destinations,
priority, priority,
mixin, mixin,
payment_id,
unlock_time, unlock_time,
account=self.index) account=self.index)

View File

@ -41,12 +41,11 @@ class Address(object):
return hexlify(self._decoded[1:33]).decode() return hexlify(self._decoded[1:33]).decode()
def with_payment_id(self, payment_id=0): def with_payment_id(self, payment_id=0):
payment_id = numbers.payment_id_as_int(payment_id) payment_id = numbers.PaymentID(payment_id)
if payment_id.bit_length() > 64: if not payment_id.is_short():
raise TypeError("Integrated payment_id cannot have more than 64 bits, " raise TypeError("Integrated payment ID {0} has more than 64 bits".format(payment_id))
"has %d" % payment_id.bit_length())
prefix = 54 if self.is_testnet() else 19 prefix = 54 if self.is_testnet() else 19
data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', payment_id) data = bytearray([prefix]) + self._decoded[1:65] + struct.pack('>Q', int(payment_id))
checksum = bytearray(keccak_256(data).digest()[:4]) checksum = bytearray(keccak_256(data).digest()[:4])
return IntegratedAddress(base58.encode(hexlify(data + checksum))) return IntegratedAddress(base58.encode(hexlify(data + checksum)))
@ -79,7 +78,7 @@ class IntegratedAddress(Address):
self._decode(address) self._decode(address)
def get_payment_id(self): def get_payment_id(self):
return hexlify(self._decoded[65:-4]).decode() return numbers.PaymentID(hexlify(self._decoded[65:-4]).decode())
def get_base_address(self): def get_base_address(self):
prefix = 53 if self.is_testnet() else 18 prefix = 53 if self.is_testnet() else 18

View File

@ -8,7 +8,7 @@ import requests
from .. import exceptions from .. import exceptions
from ..account import Account from ..account import Account
from ..address import address, Address from ..address import address, Address
from ..numbers import from_atomic, to_atomic, payment_id_as_int from ..numbers import from_atomic, to_atomic, PaymentID
from ..transaction import Transaction, Payment, Transfer from ..transaction import Transaction, Payment, Transfer
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -62,16 +62,12 @@ class JSONRPCWallet(object):
return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance'])) return (from_atomic(_balance['balance']), from_atomic(_balance['unlocked_balance']))
def get_payments(self, account=0, payment_id=0): def get_payments(self, account=0, payment_id=0):
payment_id = payment_id_as_int(payment_id) payment_id = PaymentID(payment_id)
_log.debug("Getting payments for account {acc}, payment_id {pid}".format( _log.debug("Getting payments for account {acc}, payment_id {pid}".format(
acc=account, pid=payment_id)) acc=account, pid=payment_id))
if payment_id.bit_length() > 64:
_pid = '{:064x}'.format(payment_id)
else:
_pid = '{:016x}'.format(payment_id)
_payments = self.raw_request('get_payments', { _payments = self.raw_request('get_payments', {
'account_index': account, 'account_index': account,
'payment_id': _pid}) 'payment_id': str(payment_id)})
pmts = [] pmts = []
for tx in _payments['payments']: for tx in _payments['payments']:
data = self._tx2dict(tx) data = self._tx2dict(tx)
@ -97,7 +93,7 @@ class JSONRPCWallet(object):
'amount': from_atomic(tx['amount']), 'amount': from_atomic(tx['amount']),
'fee': from_atomic(tx['fee']) if 'fee' in tx else None, 'fee': from_atomic(tx['fee']) if 'fee' in tx else None,
'height': tx.get('height', tx.get('block_height')), 'height': tx.get('height', tx.get('block_height')),
'payment_id': tx['payment_id'], 'payment_id': PaymentID(tx.get('payment_id', 0)),
'note': tx.get('note'), 'note': tx.get('note'),
# NOTE: address will be resolved only after PR#3010 has been merged to Monero # NOTE: address will be resolved only after PR#3010 has been merged to Monero
'local_address': address(tx['address']) if 'address' in tx else None, 'local_address': address(tx['address']) if 'address' in tx else None,
@ -105,7 +101,7 @@ class JSONRPCWallet(object):
'blob': tx.get('blob', None), 'blob': tx.get('blob', None),
} }
def transfer(self, destinations, priority, mixin, unlock_time, account=0): def transfer(self, destinations, priority, mixin, payment_id, unlock_time, account=0):
data = { data = {
'account_index': account, 'account_index': account,
'destinations': list(map( 'destinations': list(map(
@ -114,6 +110,7 @@ class JSONRPCWallet(object):
'mixin': mixin, 'mixin': mixin,
'priority': priority, 'priority': priority,
'unlock_time': 0, 'unlock_time': 0,
'payment_id': payment_id,
'get_tx_keys': True, 'get_tx_keys': True,
'get_tx_hex': True, 'get_tx_hex': True,
'new_algorithm': True, 'new_algorithm': True,

View File

@ -5,8 +5,10 @@ PICONERO = Decimal('0.000000000001')
if sys.version_info < (3,): if sys.version_info < (3,):
_integer_types = (int, long,) _integer_types = (int, long,)
_str_types = (str, bytes, unicode)
else: else:
_integer_types = (int,) _integer_types = (int,)
_str_types = (str, bytes)
def to_atomic(amount): def to_atomic(amount):
@ -21,10 +23,38 @@ def as_monero(amount):
"""Return the amount rounded to maximal Monero precision.""" """Return the amount rounded to maximal Monero precision."""
return Decimal(amount).quantize(PICONERO) return Decimal(amount).quantize(PICONERO)
def payment_id_as_int(payment_id):
if isinstance(payment_id, (bytes, str)): class PaymentID(object):
payment_id = int(payment_id, 16) _payment_id = None
elif not isinstance(payment_id, _integer_types):
raise TypeError("payment_id must be either int or hexadecimal str or bytes, " def __init__(self, payment_id):
"is %r" % payment_id) if isinstance(payment_id, PaymentID):
return payment_id payment_id = int(payment_id)
if isinstance(payment_id, _str_types):
payment_id = int(payment_id, 16)
elif not isinstance(payment_id, _integer_types):
raise TypeError("payment_id must be either int or hexadecimal str or bytes, "
"is %r" % payment_id)
self._payment_id = payment_id
def is_short(self):
"""Returns True if payment ID is short enough to be included
in Integrated Address."""
return self._payment_id.bit_length() <= 64
def __repr__(self):
if self.is_short():
return "{:016x}".format(self._payment_id)
return "{:064x}".format(self._payment_id)
def __int__(self):
return self._payment_id
def __eq__(self, other):
if isinstance(other, PaymentID):
return int(self) == int(other)
elif isinstance(other, _integer_types):
return int(self) == other
elif isinstance(other, _str_types):
return str(self) == other
return super()

View File

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
import unittest import unittest
from monero.numbers import to_atomic, from_atomic, payment_id_as_int from monero.numbers import to_atomic, from_atomic, PaymentID
class NumbersTestCase(unittest.TestCase): class NumbersTestCase(unittest.TestCase):
def test_simple_numbers(self): def test_simple_numbers(self):
@ -16,5 +16,23 @@ class NumbersTestCase(unittest.TestCase):
self.assertEqual(to_atomic(Decimal('1.0000000000004')), 1000000000000) self.assertEqual(to_atomic(Decimal('1.0000000000004')), 1000000000000)
def test_payment_id(self): def test_payment_id(self):
self.assertEqual(payment_id_as_int('0'), 0) pid = PaymentID('0')
self.assertEqual(payment_id_as_int('abcdef'), 0xabcdef) self.assertTrue(pid.is_short())
self.assertEqual(pid, 0)
self.assertEqual(pid, '0000000000000000')
self.assertEqual(PaymentID(pid), pid)
pid = PaymentID('abcdef')
self.assertTrue(pid.is_short())
self.assertEqual(pid, 0xabcdef)
self.assertEqual(pid, '0000000000abcdef')
self.assertEqual(PaymentID(pid), pid)
pid = PaymentID('1234567812345678')
self.assertTrue(pid.is_short())
self.assertEqual(pid, 0x1234567812345678)
self.assertEqual(pid, '1234567812345678')
self.assertEqual(PaymentID(pid), pid)
pid = PaymentID('a1234567812345678')
self.assertFalse(pid.is_short())
self.assertEqual(pid, 0xa1234567812345678)
self.assertEqual(pid, '00000000000000000000000000000000000000000000000a1234567812345678')
self.assertEqual(PaymentID(pid), pid)

View File

@ -31,10 +31,10 @@ def get_wallet():
return Wallet(JSONRPCWallet(**args.daemon_url)) return Wallet(JSONRPCWallet(**args.daemon_url))
_TXHDR = "timestamp height id/hash " \ _TXHDR = "timestamp height id/hash " \
" amount fee payment_id {dir}" " amount fee {dir:95s} payment_id"
def tx2str(tx): def tx2str(tx):
return "{time} {height} {hash} {amount:17.12f} {fee:13.12f} {payment_id} {addr}".format( return "{time} {height} {hash} {amount:17.12f} {fee:13.12f} {addr} {payment_id}".format(
time=tx.timestamp.strftime("%d-%m-%y %H:%M:%S") if getattr(tx, 'timestamp', None) else None, time=tx.timestamp.strftime("%d-%m-%y %H:%M:%S") if getattr(tx, 'timestamp', None) else None,
height=tx.height, height=tx.height,
hash=tx.hash, hash=tx.hash,