Simplify the Seed class interface, raise ValueError instead of assertions

This commit is contained in:
Michał Sałaban 2018-05-26 17:58:39 +02:00
parent 83eca24ce0
commit ce9427180b
2 changed files with 39 additions and 56 deletions

View File

@ -29,6 +29,11 @@
# Copied 17 February 2018 from MoneroPy, originally from Electrum: # Copied 17 February 2018 from MoneroPy, originally from Electrum:
# https://github.com/bigreddmachine/MoneroPy/blob/master/moneropy/mnemonic.py ch: 80cc16c39b16c55a8d052fbf7fae68644f7a5f02 # https://github.com/bigreddmachine/MoneroPy/blob/master/moneropy/mnemonic.py ch: 80cc16c39b16c55a8d052fbf7fae68644f7a5f02
# https://github.com/spesmilo/electrum/blob/master/lib/old_mnemonic.py ch:9a0aa9b4783ea03ea13c6d668e080e0cdf261c5b # https://github.com/spesmilo/electrum/blob/master/lib/old_mnemonic.py ch:9a0aa9b4783ea03ea13c6d668e080e0cdf261c5b
#
# Significantly modified on 26 May 2018 by Michal Salaban:
# + support for 12/13-word seeds
# + simplified interface, changed exceptions (assertions -> explicit raise)
# + optimization
from monero import address from monero import address
from monero import wordlists from monero import wordlists
@ -50,40 +55,40 @@ class Seed(object):
phrase = "" #13 or 25 word mnemonic word string phrase = "" #13 or 25 word mnemonic word string
hex = "" # hexadecimal hex = "" # hexadecimal
def __init__(self, phrase=""): def __init__(self, phrase_or_hex=""):
"""If user supplied a seed string to the class, break it down and determine """If user supplied a seed string to the class, break it down and determine
if it's hexadecimal or mnemonic word string. Gather the values and store them. if it's hexadecimal or mnemonic word string. Gather the values and store them.
If no seed is passed, automatically generate a new one from local system randomness. If no seed is passed, automatically generate a new one from local system randomness.
:rtype: :class:`Seed <monero.seed.Seed>` :rtype: :class:`Seed <monero.seed.Seed>`
""" """
if phrase: if phrase_or_hex:
seed_split = phrase.split(" ") seed_split = phrase_or_hex.split(" ")
if len(seed_split) >= 24: if len(seed_split) >= 24:
# standard mnemonic # standard mnemonic
self.phrase = phrase self.phrase = phrase_or_hex
if len(seed_split) == 25: if len(seed_split) == 25:
# with checksum # with checksum
self.validate_checksum() self._validate_checksum()
self.hex = self.decode_seed() self._decode_seed()
elif len(seed_split) >= 12: elif len(seed_split) >= 12:
# mymonero mnemonic # mymonero mnemonic
self.phrase = phrase self.phrase = phrase_or_hex
if len(seed_split) == 13: if len(seed_split) == 13:
# with checksum # with checksum
self.validate_checksum() self._validate_checksum()
self.hex = self.decode_seed() self._decode_seed()
elif len(seed_split) == 1: elif len(seed_split) == 1:
# single string, probably hex, but confirm # single string, probably hex, but confirm
if not len(phrase) % 8 == 0: if not len(phrase_or_hex) % 8 == 0:
raise ValueError("Not valid hexadecimal: {hex}".format(hex=phrase)) raise ValueError("Not valid hexadecimal: {hex}".format(hex=phrase_or_hex))
self.hex = phrase self.hex = phrase_or_hex
self.phrase = self.encode_seed() self._encode_seed()
else: else:
raise ValueError("Not valid mnemonic phrase: {phrase}".format(phrase=phrase)) raise ValueError("Not valid mnemonic phrase or hex: {arg}".format(arg=phrase_or_hex))
else: else:
self.hex = generate_hex() self.hex = generate_hex()
self.encode_seed() self._encode_seed()
def is_mymonero(self): def is_mymonero(self):
"""Returns True if the seed is MyMonero-style (12/13-word).""" """Returns True if the seed is MyMonero-style (12/13-word)."""
@ -96,13 +101,9 @@ class Seed(object):
""" """
return "".join([word[i:i+2] for i in [6, 4, 2, 0]]) return "".join([word[i:i+2] for i in [6, 4, 2, 0]])
def encode_seed(self): def _encode_seed(self):
"""Given a hexadecimal string, return it's mnemonic word representation with checksum. """Convert hexadecimal string to mnemonic word representation with checksum.
:rtype: str
""" """
assert self.hex, "Seed hex not set"
assert len(self.hex) % 8 == 0, "Not valid hexadecimal"
out = [] out = []
for i in range(len(self.hex) // 8): for i in range(len(self.hex) // 8):
word = self.endian_swap(self.hex[8*i:8*i+8]) word = self.endian_swap(self.hex[8*i:8*i+8])
@ -114,16 +115,11 @@ class Seed(object):
checksum = get_checksum(" ".join(out)) checksum = get_checksum(" ".join(out))
out.append(checksum) out.append(checksum)
self.phrase = " ".join(out) self.phrase = " ".join(out)
return self.phrase
def decode_seed(self): def _decode_seed(self):
"""Given a mnemonic word string, return it's hexadecimal representation. """Calculate hexadecimal representation of the phrase.
:rtype: str
""" """
assert self.phrase, "Seed phrase not set"
phrase = self.phrase.split(" ") phrase = self.phrase.split(" ")
assert len(phrase) >= 12, "Not valid mnemonic phrase"
out = "" out = ""
for i in range(len(phrase) // 3): for i in range(len(phrase) // 3):
word1, word2, word3 = phrase[3*i:3*i+3] word1, word2, word3 = phrase[3*i:3*i+3]
@ -133,19 +129,16 @@ class Seed(object):
x = w1 + self.n *((w2 - w1) % self.n) + self.n * self.n * ((w3 - w2) % self.n) x = w1 + self.n *((w2 - w1) % self.n) + self.n * self.n * ((w3 - w2) % self.n)
out += self.endian_swap("%08x" % x) out += self.endian_swap("%08x" % x)
self.hex = out self.hex = out
return self.hex
def validate_checksum(self): def _validate_checksum(self):
"""Given a mnemonic word string, confirm seed checksum (last word) matches the computed checksum. """Given a mnemonic word string, confirm seed checksum (last word) matches the computed checksum.
:rtype: bool :rtype: bool
""" """
assert self.phrase, "Seed phrase not set"
phrase = self.phrase.split(" ") phrase = self.phrase.split(" ")
assert len(phrase) > 12, "Not valid mnemonic phrase" if get_checksum(self.phrase) == phrase[-1]:
is_match = get_checksum(self.phrase) == phrase[-1] return True
assert is_match, "Not valid checksum" raise ValueError("Invalid checksum")
return is_match
def sc_reduce(self, input): def sc_reduce(self, input):
integer = ed25519.decodeint(unhexlify(input)) integer = ed25519.decodeint(unhexlify(input))
@ -204,7 +197,8 @@ def get_checksum(phrase):
:rtype: str :rtype: str
""" """
phrase_split = phrase.split(" ") phrase_split = phrase.split(" ")
assert len(phrase_split) >= 12, "Not valid mnemonic phrase" if len(phrase_split) < 12:
raise ValueError("Invalid mnemonic phrase")
if len(phrase_split) > 13: if len(phrase_split) > 13:
# Standard format # Standard format
phrase = phrase_split[:24] phrase = phrase_split[:24]

View File

@ -6,46 +6,37 @@ from monero.seed import Seed, get_checksum
class SeedTestCase(unittest.TestCase): class SeedTestCase(unittest.TestCase):
def test_mnemonic_seed(self): def test_mnemonic_seed(self):
# Known good 25 word seed phrases should construct a class, validate checksum, and register valid hex # Known good 25 word seed phrases should construct a class and register valid hex
seed = Seed("wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding square") seed = Seed("wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding square")
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.hex, "8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06") self.assertEqual(seed.hex, "8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06")
# Known good 24 word seed phrases should construct a class, store phrase with valid checksum, and register valid hex # Known good 24 word seed phrases should construct a class, store phrase, and register valid hex
seed = Seed("wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding") seed = Seed("wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding")
seed.encode_seed()
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.hex, "8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06") self.assertEqual(seed.hex, "8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06")
# Known good 25 word hexadecimal strings should construct a class, store phrase with valid checksum, and register valid hex # Known good 25 word hexadecimal strings should construct a class, store phrase, and register valid hex
seed = Seed("8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06") seed = Seed("8ffa9f586b86d294d93731765d192765311bddc76a4fa60311f8af36bbf6fb06")
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.phrase, "wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding square") self.assertEqual(seed.phrase, "wedge going quick racetrack auburn physics lectures light waist axes whipped habitat square awkward together injury niece nugget guarded hive obnoxious waxing faked folding square")
self.assertTrue(len(seed.hex) % 8 == 0)
# Known good 13 word seed phrases should construct a class, validate checksum, and register valid hex # Known good 13 word seed phrases should construct a class and register valid hex
seed = Seed("ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin fowls") seed = Seed("ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin fowls")
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.hex, "932d70711acc2d536ca11fcb79e05516") self.assertEqual(seed.hex, "932d70711acc2d536ca11fcb79e05516")
# Known good 12 word seed phrases should construct a class, store phrase with valid checksum, and register valid hex # Known good 12 word seed phrases should construct a class, store phrase, and register valid hex
seed = Seed("ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin") seed = Seed("ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin")
seed.encode_seed()
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.hex, "932d70711acc2d536ca11fcb79e05516") self.assertEqual(seed.hex, "932d70711acc2d536ca11fcb79e05516")
# Known good 13 word hexadecimal strings should construct a class, store phrase with valid checksum, and register valid hex # Known good 13 word hexadecimal strings should construct a class, store phrase, and register valid hex
seed = Seed("932d70711acc2d536ca11fcb79e05516") seed = Seed("932d70711acc2d536ca11fcb79e05516")
self.assertTrue(seed.validate_checksum())
self.assertEqual(seed.phrase, "ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin fowls") self.assertEqual(seed.phrase, "ought knowledge upright innocent eldest nerves gopher fowls below exquisite aces basin fowls")
self.assertTrue(len(seed.hex) % 8 == 0)
# Generated seed phrases should be 25 words, validate checksum, register valid hex, and return matching outputs for decode/encode # Generated seed phrases should be 25 words, register valid hex
seed = Seed() seed = Seed()
seed_split = seed.phrase.split(" ") seed_split = seed.phrase.split(" ")
self.assertTrue(len(seed_split) == 25) self.assertTrue(len(seed_split) == 25)
self.assertTrue(seed.validate_checksum())
self.assertTrue(len(seed.hex) % 8 == 0) self.assertTrue(len(seed.hex) % 8 == 0)
self.assertEqual(seed.hex, seed.decode_seed())
self.assertEqual(seed.phrase, seed.encode_seed())
# Invalid phrases should not be allowed # Invalid phrases should not be allowed
with self.assertRaises(ValueError) as ts: with self.assertRaises(ValueError) as ts:
@ -67,7 +58,6 @@ class SeedTestCase(unittest.TestCase):
"framed succeed fuzzy return demonstrate nucleus album noises peculiar virtual "\ "framed succeed fuzzy return demonstrate nucleus album noises peculiar virtual "\
"rowboat inorganic jester fuzzy") "rowboat inorganic jester fuzzy")
self.assertFalse(seed.is_mymonero()) self.assertFalse(seed.is_mymonero())
self.assertTrue(seed.validate_checksum())
self.assertEqual( self.assertEqual(
seed.secret_spend_key(), seed.secret_spend_key(),
'482700617ba810f94035d7f4d7ccc1a29878e165b4867872b705204c85406906') '482700617ba810f94035d7f4d7ccc1a29878e165b4867872b705204c85406906')
@ -89,7 +79,6 @@ class SeedTestCase(unittest.TestCase):
seed = Seed("dwelt idols lopped blender haggled rabbits piloted value swagger taunts toolbox upgrade swagger") seed = Seed("dwelt idols lopped blender haggled rabbits piloted value swagger taunts toolbox upgrade swagger")
self.assertTrue(seed.is_mymonero()) self.assertTrue(seed.is_mymonero())
self.assertTrue(seed.validate_checksum())
# the following fails, #21 addresses that # the following fails, #21 addresses that
self.assertEqual( self.assertEqual(
seed.secret_spend_key(), seed.secret_spend_key(),