bulletproofs: speedup PROVE

This commit is contained in:
moneromooo-monero 2018-08-07 08:02:42 +00:00
parent 2287fb9fb4
commit 4564a5d17b
No known key found for this signature in database
GPG key ID: 686F07454D6CEFC3

View file

@ -127,15 +127,6 @@ static void sub_acc_p3(ge_p3 *acc_p3, const rct::key &point)
ge_p1p1_to_p3(acc_p3, &p1); ge_p1p1_to_p3(acc_p3, &p1);
} }
static rct::key scalarmultKey(const ge_p3 &P, const rct::key &a)
{
ge_p2 R;
ge_scalarmult(&R, a.bytes, &P);
rct::key aP;
ge_tobytes(aP.bytes, &R);
return aP;
}
static rct::key get_exponent(const rct::key &base, size_t idx) static rct::key get_exponent(const rct::key &base, size_t idx)
{ {
static const std::string salt("bulletproof"); static const std::string salt("bulletproof");
@ -193,23 +184,28 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b)
} }
/* Compute a custom vector-scalar commitment */ /* Compute a custom vector-scalar commitment */
static rct::key vector_exponent_custom(const rct::keyV &A, const rct::keyV &B, const rct::keyV &a, const rct::keyV &b) static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar)
{ {
CHECK_AND_ASSERT_THROW_MES(A.size() == B.size(), "Incompatible sizes of A and B"); CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A");
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B");
CHECK_AND_ASSERT_THROW_MES(a.size() == A.size(), "Incompatible sizes of a and A"); CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a");
CHECK_AND_ASSERT_THROW_MES(a.size() <= maxN*maxM, "Incompatible sizes of a and maxN"); CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b");
CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large");
CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present");
std::vector<MultiexpData> multiexp_data; std::vector<MultiexpData> multiexp_data;
multiexp_data.reserve(a.size()*2); multiexp_data.resize(size*2 + (!!extra_point));
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < size; ++i)
{ {
multiexp_data.resize(multiexp_data.size() + 1); sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);;
multiexp_data.back().scalar = a[i]; multiexp_data[i*2].point = A[Ao+i];
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, A[i].bytes) == 0, "ge_frombytes_vartime failed"); sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes);
multiexp_data.resize(multiexp_data.size() + 1); multiexp_data[i*2+1].point = B[Bo+i];
multiexp_data.back().scalar = b[i]; }
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, B[i].bytes) == 0, "ge_frombytes_vartime failed"); if (extra_point)
{
sc_mul(multiexp_data.back().scalar.bytes, extra_scalar->bytes, INV_EIGHT.bytes);
multiexp_data.back().point = *extra_point;
} }
return multiexp(multiexp_data, false); return multiexp(multiexp_data, false);
} }
@ -273,16 +269,19 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b)
return res; return res;
} }
/* Given two curvepoint arrays, construct the Hadamard product */ /* folds a curvepoint array using a two way scaled Hadamard product */
static rct::keyV hadamard2(const rct::keyV &a, const rct::keyV &b) static void hadamard_fold(std::vector<ge_p3> &v, const rct::key &a, const rct::key &b)
{ {
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even");
rct::keyV res(a.size()); const size_t sz = v.size() / 2;
for (size_t i = 0; i < a.size(); ++i) for (size_t n = 0; n < sz; ++n)
{ {
rct::addKeys(res[i], a[i], b[i]); ge_dsmp c[2];
ge_dsm_precomp(c[0], &v[n]);
ge_dsm_precomp(c[1], &v[sz + n]);
ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]);
} }
return res; v.resize(sz);
} }
/* Add two vectors */ /* Add two vectors */
@ -326,17 +325,6 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
return rct::keyV(N, x); return rct::keyV(N, x);
} }
/* Exponentiate a curve vector by a scalar */
static rct::keyV vector_scalar2(const rct::keyV &a, const rct::key &x)
{
rct::keyV res(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
rct::scalarmultKey(res[i], a[i], x);
}
return res;
}
/* Get the sum of a vector's elements */ /* Get the sum of a vector's elements */
static rct::key vector_sum(const rct::keyV &a) static rct::key vector_sum(const rct::keyV &a)
{ {
@ -620,16 +608,16 @@ try_again:
// These are used in the inner product rounds // These are used in the inner product rounds
size_t nprime = N; size_t nprime = N;
rct::keyV Gprime(N); std::vector<ge_p3> Gprime(N);
rct::keyV Hprime(N); std::vector<ge_p3> Hprime(N);
rct::keyV aprime(N); rct::keyV aprime(N);
rct::keyV bprime(N); rct::keyV bprime(N);
const rct::key yinv = invert(y); const rct::key yinv = invert(y);
rct::key yinvpow = rct::identity(); rct::key yinvpow = rct::identity();
for (size_t i = 0; i < N; ++i) for (size_t i = 0; i < N; ++i)
{ {
Gprime[i] = Gi[i]; Gprime[i] = Gi_p3[i];
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
aprime[i] = l[i]; aprime[i] = l[i];
bprime[i] = r[i]; bprime[i] = r[i];
@ -652,14 +640,10 @@ try_again:
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
// PAPER LINES 18-19 // PAPER LINES 18-19
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
R[round] = rct::scalarmultKey(R[round], INV_EIGHT);
// PAPER LINES 21-22 // PAPER LINES 21-22
w[round] = hash_cache_mash(hash_cache, L[round], R[round]); w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
@ -672,8 +656,11 @@ try_again:
// PAPER LINES 24-25 // PAPER LINES 24-25
const rct::key winv = invert(w[round]); const rct::key winv = invert(w[round]);
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); if (nprime > 1)
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); {
hadamard_fold(Gprime, winv, w[round]);
hadamard_fold(Hprime, w[round], winv);
}
// PAPER LINES 28-29 // PAPER LINES 28-29
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
@ -914,16 +901,16 @@ try_again:
// These are used in the inner product rounds // These are used in the inner product rounds
size_t nprime = MN; size_t nprime = MN;
rct::keyV Gprime(MN); std::vector<ge_p3> Gprime(MN);
rct::keyV Hprime(MN); std::vector<ge_p3> Hprime(MN);
rct::keyV aprime(MN); rct::keyV aprime(MN);
rct::keyV bprime(MN); rct::keyV bprime(MN);
const rct::key yinv = invert(y); const rct::key yinv = invert(y);
rct::key yinvpow = rct::identity(); rct::key yinvpow = rct::identity();
for (size_t i = 0; i < MN; ++i) for (size_t i = 0; i < MN; ++i)
{ {
Gprime[i] = Gi[i]; Gprime[i] = Gi_p3[i];
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
aprime[i] = l[i]; aprime[i] = l[i];
bprime[i] = r[i]; bprime[i] = r[i];
@ -942,18 +929,18 @@ try_again:
nprime /= 2; nprime /= 2;
// PAPER LINES 16-17 // PAPER LINES 16-17
PERF_TIMER_START_BP(PROVE_inner_product);
rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
PERF_TIMER_STOP(PROVE_inner_product);
// PAPER LINES 18-19 // PAPER LINES 18-19
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); PERF_TIMER_START_BP(PROVE_LR);
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
R[round] = rct::scalarmultKey(R[round], INV_EIGHT); PERF_TIMER_STOP(PROVE_LR);
// PAPER LINES 21-22 // PAPER LINES 21-22
w[round] = hash_cache_mash(hash_cache, L[round], R[round]); w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
@ -966,12 +953,19 @@ try_again:
// PAPER LINES 24-25 // PAPER LINES 24-25
const rct::key winv = invert(w[round]); const rct::key winv = invert(w[round]);
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); if (nprime > 1)
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); {
PERF_TIMER_START_BP(PROVE_hadamard2);
hadamard_fold(Gprime, winv, w[round]);
hadamard_fold(Hprime, w[round], winv);
PERF_TIMER_STOP(PROVE_hadamard2);
}
// PAPER LINES 28-29 // PAPER LINES 28-29
PERF_TIMER_START_BP(PROVE_prime);
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round]));
PERF_TIMER_STOP(PROVE_prime);
++round; ++round;
} }