bulletproofs: speedup prover

This commit is contained in:
moneromooo-monero 2018-08-25 18:37:21 +00:00
parent 6f9ae5b6eb
commit 8b4767221c
No known key found for this signature in database
GPG key ID: 686F07454D6CEFC3

View file

@ -146,13 +146,14 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b)
} }
/* Compute a custom vector-scalar commitment */ /* Compute a custom vector-scalar commitment */
static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar) static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const rct::keyV *scale, const ge_p3 *extra_point, const rct::key *extra_scalar)
{ {
CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A"); CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A");
CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B"); CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B");
CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a"); CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a");
CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b"); CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b");
CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large"); CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large");
CHECK_AND_ASSERT_THROW_MES(!scale || size == scale->size() / 2, "Incompatible size for scale");
CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present"); CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present");
std::vector<MultiexpData> multiexp_data; std::vector<MultiexpData> multiexp_data;
@ -162,6 +163,8 @@ static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A,
sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);; sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);;
multiexp_data[i*2].point = A[Ao+i]; multiexp_data[i*2].point = A[Ao+i];
sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes); sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes);
if (scale)
sc_mul(multiexp_data[i*2+1].scalar.bytes, multiexp_data[i*2+1].scalar.bytes, (*scale)[Bo+i].bytes);
multiexp_data[i*2+1].point = B[Bo+i]; multiexp_data[i*2+1].point = B[Bo+i];
} }
if (extra_point) if (extra_point)
@ -232,7 +235,7 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b)
} }
/* folds a curvepoint array using a two way scaled Hadamard product */ /* folds a curvepoint array using a two way scaled Hadamard product */
static void hadamard_fold(std::vector<ge_p3> &v, const rct::key &a, const rct::key &b) static void hadamard_fold(std::vector<ge_p3> &v, const rct::keyV *scale, const rct::key &a, const rct::key &b)
{ {
CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even"); CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even");
const size_t sz = v.size() / 2; const size_t sz = v.size() / 2;
@ -241,7 +244,10 @@ static void hadamard_fold(std::vector<ge_p3> &v, const rct::key &a, const rct::k
ge_dsmp c[2]; ge_dsmp c[2];
ge_dsm_precomp(c[0], &v[n]); ge_dsm_precomp(c[0], &v[n]);
ge_dsm_precomp(c[1], &v[sz + n]); ge_dsm_precomp(c[1], &v[sz + n]);
ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]); rct::key sa, sb;
if (scale) sc_mul(sa.bytes, a.bytes, (*scale)[n].bytes); else sa = a;
if (scale) sc_mul(sb.bytes, b.bytes, (*scale)[sz + n].bytes); else sb = b;
ge_double_scalarmult_precomp_vartime2_p3(&v[n], sa.bytes, c[0], sb.bytes, c[1]);
} }
v.resize(sz); v.resize(sz);
} }
@ -258,14 +264,24 @@ static rct::keyV vector_add(const rct::keyV &a, const rct::keyV &b)
return res; return res;
} }
/* Subtract two vectors */ /* Add a scalar to all elements of a vector */
static rct::keyV vector_subtract(const rct::keyV &a, const rct::keyV &b) static rct::keyV vector_add(const rct::keyV &a, const rct::key &b)
{ {
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b");
rct::keyV res(a.size()); rct::keyV res(a.size());
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
{ {
sc_sub(res[i].bytes, a[i].bytes, b[i].bytes); sc_add(res[i].bytes, a[i].bytes, b.bytes);
}
return res;
}
/* Subtract a scalar from all elements of a vector */
static rct::keyV vector_subtract(const rct::keyV &a, const rct::key &b)
{
rct::keyV res(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
sc_sub(res[i].bytes, a[i].bytes, b.bytes);
} }
return res; return res;
} }
@ -549,8 +565,7 @@ try_again:
} }
// Polynomial construction by coefficients // Polynomial construction by coefficients
const auto zMN = vector_dup(z, MN); rct::keyV l0 = vector_subtract(aL, z);
rct::keyV l0 = vector_subtract(aL, zMN);
const rct::keyV &l1 = sL; const rct::keyV &l1 = sL;
// This computes the ugly sum/concatenation from PAPER LINE 65 // This computes the ugly sum/concatenation from PAPER LINE 65
@ -570,7 +585,7 @@ try_again:
} }
} }
rct::keyV r0 = vector_add(aR, zMN); rct::keyV r0 = vector_add(aR, z);
const auto yMN = vector_powers(y, MN); const auto yMN = vector_powers(y, MN);
r0 = hadamard(r0, yMN); r0 = hadamard(r0, yMN);
r0 = vector_add(r0, zero_twos); r0 = vector_add(r0, zero_twos);
@ -658,12 +673,15 @@ try_again:
rct::keyV aprime(MN); rct::keyV aprime(MN);
rct::keyV bprime(MN); rct::keyV bprime(MN);
const rct::key yinv = invert(y); const rct::key yinv = invert(y);
rct::key yinvpow = rct::identity(); rct::keyV yinvpow(MN);
yinvpow[0] = rct::identity();
yinvpow[1] = yinv;
for (size_t i = 0; i < MN; ++i) for (size_t i = 0; i < MN; ++i)
{ {
Gprime[i] = Gi_p3[i]; Gprime[i] = Gi_p3[i];
ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]); Hprime[i] = Hi_p3[i];
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); if (i > 1)
sc_mul(yinvpow[i].bytes, yinvpow[i-1].bytes, yinv.bytes);
aprime[i] = l[i]; aprime[i] = l[i];
bprime[i] = r[i]; bprime[i] = r[i];
} }
@ -675,6 +693,7 @@ try_again:
PERF_TIMER_START_BP(PROVE_step4); PERF_TIMER_START_BP(PROVE_step4);
// PAPER LINE 13 // PAPER LINE 13
const rct::keyV *scale = &yinvpow;
while (nprime > 1) while (nprime > 1)
{ {
// PAPER LINE 15 // PAPER LINE 15
@ -689,9 +708,9 @@ try_again:
// PAPER LINES 18-19 // PAPER LINES 18-19
PERF_TIMER_START_BP(PROVE_LR); PERF_TIMER_START_BP(PROVE_LR);
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp); L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, scale, &ge_p3_H, &tmp);
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp); R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, scale, &ge_p3_H, &tmp);
PERF_TIMER_STOP(PROVE_LR); PERF_TIMER_STOP(PROVE_LR);
// PAPER LINES 21-22 // PAPER LINES 21-22
@ -708,8 +727,8 @@ try_again:
if (nprime > 1) if (nprime > 1)
{ {
PERF_TIMER_START_BP(PROVE_hadamard2); PERF_TIMER_START_BP(PROVE_hadamard2);
hadamard_fold(Gprime, winv, w[round]); hadamard_fold(Gprime, NULL, winv, w[round]);
hadamard_fold(Hprime, w[round], winv); hadamard_fold(Hprime, scale, w[round], winv);
PERF_TIMER_STOP(PROVE_hadamard2); PERF_TIMER_STOP(PROVE_hadamard2);
} }
@ -719,6 +738,7 @@ try_again:
bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round]));
PERF_TIMER_STOP(PROVE_prime); PERF_TIMER_STOP(PROVE_prime);
scale = NULL;
++round; ++round;
} }
PERF_TIMER_STOP(PROVE_step4); PERF_TIMER_STOP(PROVE_step4);