#include #include #include #include static const mpz_class q = (1_mpz<<255)-19; static const size_t modulus_bytes = 32; static const unsigned int a24 = 0x01db41; static void fe_print(const mpz_class &x) { printf("0x"); for (size_t i = modulus_bytes-1; i>(8*i)).get_ui()&0xff); } } static void fe_print_frac(mpz_class x, mpz_class z) { // remainder -> modulo if (z < 0) { z += q; } if (mpz_invert(z.get_mpz_t(), z.get_mpz_t(), q.get_mpz_t())) { // remainder -> modulo if (x < 0) { x += q; } x = x*z % q; fe_print(x); } else { printf("inf "); } } using std::pair; using std::make_pair; static const pair, pair> ladderstep(const mpz_class &x1, const mpz_class &x, const mpz_class &z, const mpz_class &x_p, const mpz_class &z_p) { mpz_class t; { t = x; mpz_class origx = t; { t = (x + z)%q; mpz_class x = t; { t = (origx - z)%q; mpz_class z = t; { t = x_p; mpz_class origx_p = t; { t = (x_p + z_p)%q; mpz_class x_p = t; { t = (origx_p - z_p)%q; mpz_class z_p = t; { t = (x_p * z)%q; mpz_class xx_p = t; { t = (x * z_p)%q; mpz_class zz_p = t; { t = xx_p; mpz_class origx_p = t; { t = (xx_p + zz_p)%q; mpz_class xx_p = t; { t = (origx_p - zz_p)%q; mpz_class zz_p = t; { t = (xx_p*xx_p)%q; mpz_class x3 = t; { t = (zz_p*zz_p)%q; mpz_class zzz_p = t; { t = (zzz_p * x1)%q; mpz_class z3 = t; { t = (x*x)%q; mpz_class xx = t; { t = (z*z)%q; mpz_class zz = t; { t = (xx * zz)%q; mpz_class x2 = t; { t = (xx - zz)%q; mpz_class zz = t; { t = (zz * a24)%q; mpz_class zzz = t; { t = (zzz + xx)%q; mpz_class zzz = t; { t = (zz * zzz)%q; mpz_class z2 = t; return make_pair(make_pair(x2, z2), make_pair(x3, z3)); }}}}}}}}}}}}}}}}}}}}} } static void crypto_scalarmult(uint8_t *out, const uint8_t *secret, size_t secretbits, const uint8_t *point) { mpz_class x1; for (size_t i = 0; i> (i%8))&1; // printf("%d ", bit); fe_print_frac(x, z); printf(" "); fe_print_frac(x_p, z_p); printf("\n"); if (swap ^ bit) { std::swap(x, x_p); std::swap(z, z_p); } swap = bit; auto pp = ladderstep(x1, x, z, x_p, z_p); x = pp.first.first; z = pp.first.second; x_p = pp.second.first; z_p = pp.second.second; } if (swap) { std::swap(x, x_p); std::swap(z, z_p); } // remainder -> modulo if (z < 0) { z += q; } if (mpz_invert(z.get_mpz_t(), z.get_mpz_t(), q.get_mpz_t())) { x = x*z % q; } else { x = 0; } // remainder -> modulo if (x < 0) { x += q; } for (size_t i = 0; i>(8*i)).get_ui()&0xff; } } int main() { { uint8_t out[modulus_bytes] = {0}; uint8_t point[modulus_bytes] = {9}; uint8_t secret[modulus_bytes] = {1}; crypto_scalarmult(out, secret, 256, point); // printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", out[i]); }; printf("\n"); } { const uint8_t expected[32] = {0x89, 0x16, 0x1f, 0xde, 0x88, 0x7b, 0x2b, 0x53, 0xde, 0x54, 0x9a, 0xf4, 0x83, 0x94, 0x01, 0x06, 0xec, 0xc1, 0x14, 0xd6, 0x98, 0x2d, 0xaa, 0x98, 0x25, 0x6d, 0xe2, 0x3b, 0xdf, 0x77, 0x66, 0x1a}; const uint8_t basepoint[32] = {9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; uint8_t a[32] = {0}, b[32] = {0}; uint8_t* in = a; uint8_t* out = b; a[0] = 1; for (int i = 0; i < 200; i++) { in[0] &= 248; in[31] &= 127; in[31] |= 64; crypto_scalarmult(out, in, 256, basepoint); uint8_t* t = out; out = in; in = t; } for (int i = 0; i < 32; i++) { if (in[i] != expected[i]) { return (i+1); } } return 0; } }