Simplify the Seed class interface, raise ValueError instead of assertions

This commit is contained in:
Michał Sałaban 2018-05-26 17:58:39 +02:00
parent 83eca24ce0
commit ce9427180b
2 changed files with 39 additions and 56 deletions

View file

@ -29,6 +29,11 @@
# Copied 17 February 2018 from MoneroPy, originally from Electrum:
# https://github.com/bigreddmachine/MoneroPy/blob/master/moneropy/mnemonic.py ch: 80cc16c39b16c55a8d052fbf7fae68644f7a5f02
# https://github.com/spesmilo/electrum/blob/master/lib/old_mnemonic.py ch:9a0aa9b4783ea03ea13c6d668e080e0cdf261c5b
#
# Significantly modified on 26 May 2018 by Michal Salaban:
# + support for 12/13-word seeds
# + simplified interface, changed exceptions (assertions -> explicit raise)
# + optimization
from monero import address
from monero import wordlists
@ -50,40 +55,40 @@ class Seed(object):
phrase = "" #13 or 25 word mnemonic word string
hex = "" # hexadecimal
def __init__(self, phrase=""):
def __init__(self, phrase_or_hex=""):
"""If user supplied a seed string to the class, break it down and determine
if it's hexadecimal or mnemonic word string. Gather the values and store them.
If no seed is passed, automatically generate a new one from local system randomness.
:rtype: :class:`Seed <monero.seed.Seed>`
"""
if phrase:
seed_split = phrase.split(" ")
if phrase_or_hex:
seed_split = phrase_or_hex.split(" ")
if len(seed_split) >= 24:
# standard mnemonic
self.phrase = phrase
self.phrase = phrase_or_hex
if len(seed_split) == 25:
# with checksum
self.validate_checksum()
self.hex = self.decode_seed()
self._validate_checksum()
self._decode_seed()
elif len(seed_split) >= 12:
# mymonero mnemonic
self.phrase = phrase
self.phrase = phrase_or_hex
if len(seed_split) == 13:
# with checksum
self.validate_checksum()
self.hex = self.decode_seed()
self._validate_checksum()
self._decode_seed()
elif len(seed_split) == 1:
# single string, probably hex, but confirm
if not len(phrase) % 8 == 0:
raise ValueError("Not valid hexadecimal: {hex}".format(hex=phrase))
self.hex = phrase
self.phrase = self.encode_seed()
if not len(phrase_or_hex) % 8 == 0:
raise ValueError("Not valid hexadecimal: {hex}".format(hex=phrase_or_hex))
self.hex = phrase_or_hex
self._encode_seed()
else:
raise ValueError("Not valid mnemonic phrase: {phrase}".format(phrase=phrase))
raise ValueError("Not valid mnemonic phrase or hex: {arg}".format(arg=phrase_or_hex))
else:
self.hex = generate_hex()
self.encode_seed()
self._encode_seed()
def is_mymonero(self):
"""Returns True if the seed is MyMonero-style (12/13-word)."""
@ -96,13 +101,9 @@ class Seed(object):
"""
return "".join([word[i:i+2] for i in [6, 4, 2, 0]])
def encode_seed(self):
"""Given a hexadecimal string, return it's mnemonic word representation with checksum.
:rtype: str
def _encode_seed(self):
"""Convert hexadecimal string to mnemonic word representation with checksum.
"""
assert self.hex, "Seed hex not set"
assert len(self.hex) % 8 == 0, "Not valid hexadecimal"
out = []
for i in range(len(self.hex) // 8):
word = self.endian_swap(self.hex[8*i:8*i+8])
@ -114,16 +115,11 @@ class Seed(object):
checksum = get_checksum(" ".join(out))
out.append(checksum)
self.phrase = " ".join(out)
return self.phrase
def decode_seed(self):
"""Given a mnemonic word string, return it's hexadecimal representation.
:rtype: str
def _decode_seed(self):
"""Calculate hexadecimal representation of the phrase.
"""
assert self.phrase, "Seed phrase not set"
phrase = self.phrase.split(" ")
assert len(phrase) >= 12, "Not valid mnemonic phrase"
out = ""
for i in range(len(phrase) // 3):
word1, word2, word3 = phrase[3*i:3*i+3]
@ -133,19 +129,16 @@ class Seed(object):
x = w1 + self.n *((w2 - w1) % self.n) + self.n * self.n * ((w3 - w2) % self.n)
out += self.endian_swap("%08x" % x)
self.hex = out
return self.hex
def validate_checksum(self):
def _validate_checksum(self):
"""Given a mnemonic word string, confirm seed checksum (last word) matches the computed checksum.
:rtype: bool
"""
assert self.phrase, "Seed phrase not set"
phrase = self.phrase.split(" ")
assert len(phrase) > 12, "Not valid mnemonic phrase"
is_match = get_checksum(self.phrase) == phrase[-1]
assert is_match, "Not valid checksum"
return is_match
if get_checksum(self.phrase) == phrase[-1]:
return True
raise ValueError("Invalid checksum")
def sc_reduce(self, input):
integer = ed25519.decodeint(unhexlify(input))
@ -204,7 +197,8 @@ def get_checksum(phrase):
:rtype: str
"""
phrase_split = phrase.split(" ")
assert len(phrase_split) >= 12, "Not valid mnemonic phrase"
if len(phrase_split) < 12:
raise ValueError("Invalid mnemonic phrase")
if len(phrase_split) > 13:
# Standard format
phrase = phrase_split[:24]