diff --git a/monero/transaction.py b/monero/transaction.py index 4457c8a..20e9bcf 100644 --- a/monero/transaction.py +++ b/monero/transaction.py @@ -112,6 +112,38 @@ class PaymentManager(object): return fetch(self.account_idx, PaymentFilter(**filterparams)) +class _ByHeight(object): + """A helper class used as key in sorting of payments by height. + Mempool goes on top, blockchain payments are ordered with descending block numbers. + + **WARNING:** Integer sorting is reversed here. + """ + def __init__(self, pmt): + self.pmt = pmt + def _cmp(self, other): + sh = self.pmt.transaction.height + oh = other.pmt.transaction.height + if sh is oh is None: + return 0 + if sh is None: + return 1 + if oh is None: + return -1 + return (sh > oh) - (sh < oh) + def __lt__(self, other): + return self._cmp(other) > 0 + def __le__(self, other): + return self._cmp(other) >= 0 + def __eq__(self, other): + return self._cmp(other) == 0 + def __ge__(self, other): + return self._cmp(other) <= 0 + def __gt__(self, other): + return self._cmp(other) < 0 + def __ne__(self, other): + return self._cmp(other) != 0 + + class PaymentFilter(object): """ A helper class that filters payments retrieved by the backend. @@ -176,4 +208,6 @@ class PaymentFilter(object): return True def filter(self, payments): - return filter(self.check, payments) + return sorted( + filter(self.check, payments), + key=_ByHeight) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index d83c2d6..35cc2d9 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,10 +1,12 @@ from datetime import datetime from decimal import Decimal +from operator import attrgetter +import random import unittest from monero.address import address from monero.numbers import PaymentID -from monero.transaction import IncomingPayment, OutgoingPayment, Transaction +from monero.transaction import IncomingPayment, OutgoingPayment, Transaction, _ByHeight class FiltersTestCase(unittest.TestCase): def setUp(self): @@ -26,3 +28,23 @@ class FiltersTestCase(unittest.TestCase): self.assertIn( 'a0b876ebcf7c1d499712d84cedec836f9d50b608bb22d6cb49fd2feae3ffed14', repr(self.pm1)) + + +class SortingTestCase(unittest.TestCase): + def test_sorting(self): + pmts = [ + IncomingPayment(transaction=Transaction(height=10)), + IncomingPayment(transaction=Transaction(height=12)), + IncomingPayment(transaction=Transaction(height=13)), + IncomingPayment(transaction=Transaction(height=None)), + IncomingPayment(transaction=Transaction(height=100)), + IncomingPayment(transaction=Transaction(height=None)), + IncomingPayment(transaction=Transaction(height=1)) + ] + for i in range(1680): # 1/3 of possible permutations + sorted_pmts = sorted(pmts, key=_ByHeight) + self.assertEqual( + list(map(attrgetter('height'), map(attrgetter('transaction'), sorted_pmts))), + [None, None, 100, 13, 12, 10, 1]) + random.shuffle(pmts) +