aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@google.com>2017-10-20 10:07:10 -0400
committerGravatar Andres Erbsen <andreser@google.com>2017-10-20 10:09:41 -0400
commit9bd7d8e4a2ef9ac5944162ae8b3ba46f530980ce (patch)
tree981b45f8dd53ee21bfa7d1c9247aceba97de5484
parentd90bcfa0c4969908d3fa2fb8cb4a2bef74d6f111 (diff)
gmpsec.c: generic constant-time montgomery ladder implementation using mpn_sec_* functions
-rw-r--r--gmpsec.c118
-rw-r--r--src/Specific/X25519/C64/scalarmult.c6
-rw-r--r--src/Specific/X25519/x25519_test.c1
3 files changed, 71 insertions, 54 deletions
diff --git a/gmpsec.c b/gmpsec.c
index 3b536109a..614dcf1e6 100644
--- a/gmpsec.c
+++ b/gmpsec.c
@@ -6,7 +6,7 @@
// modulus, encoded as big-endian bytes
static const unsigned char modulus[] = {0x7f,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xed};
static const unsigned char a_minus_two_over_four[] = {0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x01,0xdb,0x41};
-static const size_t modulus_bytes = sizeof(modulus);
+#define modulus_bytes (sizeof(modulus))
#define modulus_limbs ((8*sizeof(modulus) + GMP_LIMB_BITS-1)/GMP_LIMB_BITS)
@@ -16,7 +16,7 @@ static void fe_print(mp_limb_t* fe) {
printf("%016lx", fe[0]);
}
-static void crypto_scalarmult(uint8_t *out, const uint8_t *point, const uint8_t *secret, size_t secretbits) {
+static void crypto_scalarmult(uint8_t *out, const uint8_t *secret, size_t secretbits, const uint8_t *point) {
// curve constants
mp_limb_t m[modulus_limbs+1];
mp_limb_t a24[modulus_limbs+1];
@@ -37,22 +37,49 @@ static void crypto_scalarmult(uint8_t *out, const uint8_t *point, const uint8_t
scratch_sz = (modscratch_sz > scratch_sz) ? modscratch_sz : scratch_sz;
scratch_sz = (invscratch_sz > scratch_sz) ? invscratch_sz : scratch_sz;
mp_limb_t scratch[scratch_sz];
+ for (size_t i = 0; i<scratch_sz; ++i) { scratch[i] = 0; }
// allocate scratch space for use by the field operation macros.
mp_limb_t _product_tmp[modulus_limbs+modulus_limbs];
+
+ #define fe_mul(out, x, y) do { \
+ mpn_sec_mul(_product_tmp, x, modulus_limbs, y, modulus_limbs, scratch); \
+ mpn_sec_div_r(_product_tmp, modulus_limbs+modulus_limbs, m, modulus_limbs, scratch); \
+ for (size_t i = 0; i<modulus_limbs; i++) { out[i] = _product_tmp[i]; } \
+ } while (0)
+
+ #define fe_sqr(out, x) do { \
+ mpn_sec_sqr(_product_tmp, x, modulus_limbs, scratch); \
+ mpn_sec_div_r(_product_tmp, modulus_limbs+modulus_limbs, m, modulus_limbs, scratch); \
+ for (size_t i = 0; i<modulus_limbs; i++) { out[i] = _product_tmp[i]; } \
+ } while (0)
+
+ #define fe_add(out, x, y) do { \
+ mpn_cnd_sub_n(mpn_add_n(out, x, y, modulus_limbs), out, out, m, modulus_limbs); \
+ } while (0)
+
+ #define fe_sub(out, x, y) do { \
+ mpn_cnd_add_n(mpn_sub_n(out, x, y, modulus_limbs), out, out, m, modulus_limbs); \
+ } while (0)
+
+ #define fe_inv(out, x) do { \
+ for (size_t i = 0; i<modulus_limbs; i++) { _product_tmp[i] = x[i]; } \
+ mp_size_t invertible = mpn_sec_invert(out, _product_tmp, m, modulus_limbs, 2*modulus_limbs*GMP_NUMB_BITS, scratch); \
+ mpn_cnd_sub_n(1-invertible, out, out, out, modulus_limbs); \
+ } while (0)
- mp_limb_t a[modulus_limbs] = {0}; mp_limb_t *nqpqx = a;
+ mp_limb_t a[modulus_limbs] = {0}; mp_limb_t *nqpqx = a;
mp_limb_t b[modulus_limbs] = {1}; mp_limb_t *nqpqz = b;
mp_limb_t c[modulus_limbs] ={1}; mp_limb_t *nqx = c;
mp_limb_t d[modulus_limbs] = {0}; mp_limb_t *nqz = d;
- mp_limb_t e[modulus_limbs] = {0}; mp_limb_t *nqpqx2 = e;
+ mp_limb_t e[modulus_limbs] = {0}; mp_limb_t *nqpqx2 = e;
mp_limb_t f[modulus_limbs] = {1}; mp_limb_t *nqpqz2 = f;
mp_limb_t g[modulus_limbs] = {0}; mp_limb_t *nqx2 = g;
mp_limb_t h[modulus_limbs] = {1}; mp_limb_t *nqz2 = h;
mp_limb_t *t;
uint8_t revpoint[modulus_bytes];
- for (size_t i = 0; i<secretbits/8; i++) { revpoint[i] = point[secretbits/8-1-i]; }
+ for (size_t i = 0; i<modulus_bytes; i++) { revpoint[i] = point[modulus_bytes-1-i]; }
for (size_t i = 0; i<modulus_limbs; i++) { nqpqx[i] = 0; }
assert(mpn_set_str(nqpqx, revpoint, modulus_bytes, 256) <= (mp_size_t)modulus_limbs);
@@ -61,75 +88,50 @@ static void crypto_scalarmult(uint8_t *out, const uint8_t *point, const uint8_t
for (size_t i = secretbits-1; i < secretbits; --i) {
mp_limb_t bit = (secret[i/8] >> (i%8))&1;
+ // printf("%01d ", bit);
mpn_cnd_swap(bit, nqx, nqpqx, modulus_limbs);
mpn_cnd_swap(bit, nqz, nqpqz, modulus_limbs);
mp_limb_t *x2 = nqx2;
- mp_limb_t *z2 = nqz2;
+ mp_limb_t *z2 = nqz2;
mp_limb_t *x3 = nqpqx2;
mp_limb_t *z3 = nqpqz2;
mp_limb_t *x = nqx;
mp_limb_t *z = nqz;
mp_limb_t *xprime = nqpqx;
- mp_limb_t *zprime = nqpqz;
+ mp_limb_t *zprime = nqpqz;
// fmonty(mp_limb_t *x2, mp_limb_t 0*z2, /* output 2Q */
// mp_limb_t *x3, mp_limb_t *z3, /* output Q + Q' */
// mp_limb_t *x, mp_limb_t *z, /* input Q */
// mp_limb_t *xprime, mp_limb_t *zprime, /* input Q' */
// const mp_limb_t *qmqp /* input Q - Q' */) {
-
- #define fe_mul(out, x, y) do { \
- mpn_sec_mul(_product_tmp, x, modulus_limbs, y, modulus_limbs, scratch); \
- mpn_sec_div_r(_product_tmp, modulus_limbs+modulus_limbs, m, modulus_limbs, scratch); \
- for (size_t i = 0; i<modulus_limbs; i++) { out[i] = _product_tmp[i]; } \
- } while (0)
-
- #define fe_sqr(out, x) do { \
- mpn_sec_sqr(_product_tmp, x, modulus_limbs, scratch); \
- mpn_sec_div_r(_product_tmp, modulus_limbs+modulus_limbs, m, modulus_limbs, scratch); \
- for (size_t i = 0; i<modulus_limbs; i++) { out[i] = _product_tmp[i]; } \
- } while (0)
-
- #define fe_add(out, x, y) do { \
- mpn_cnd_sub_n(mpn_add_n(out, x, y, modulus_limbs), out, out, m, modulus_limbs); \
- } while (0)
-
- #define fe_sub(out, x, y) do { \
- mpn_cnd_add_n(mpn_sub_n(out, x, y, modulus_limbs), out, out, m, modulus_limbs); \
- } while (0)
-
- #define fe_inv(out, x) do { \
- for (size_t i = 0; i<modulus_limbs; i++) { _product_tmp[i] = x[i]; } \
- mp_size_t invertible = mpn_sec_invert(out, _product_tmp, m, modulus_limbs, 2*modulus_limbs*GMP_NUMB_BITS, scratch); \
- mpn_cnd_sub_n(1-invertible, out, out, out, modulus_limbs); \
- } while (0)
-
- mp_limb_t origx[modulus_limbs], origxprime[modulus_limbs], zzz[modulus_limbs], xx[modulus_limbs], zz[modulus_limbs], xxprime[modulus_limbs], zzprime[modulus_limbs], zzzprime[modulus_limbs];
-
+
+ mp_limb_t origx[modulus_limbs], origxprime[modulus_limbs], zzz[modulus_limbs], xx[modulus_limbs], zz[modulus_limbs], xxprime[modulus_limbs], zzprime[modulus_limbs], zzzprime[modulus_limbs];
+
for (size_t i = 0; i<modulus_limbs; i++) { origx[i] = x[i]; }
- fe_add(x, x, z);
+ fe_add(x, x, z);
fe_sub(z, origx, z);
-
+
for (size_t i = 0; i<modulus_limbs; i++) { origxprime[i] = xprime[i]; }
fe_add(xprime, xprime, zprime);
fe_sub(zprime, origxprime, zprime);
- fe_mul(xxprime, xprime, z);
- fe_mul(zzprime, x, zprime);
+ fe_mul(xxprime, xprime, z);
+ fe_mul(zzprime, x, zprime);
for (size_t i = 0; i<modulus_limbs; i++) { origxprime[i] = xxprime[i]; }
fe_add(xxprime, xxprime, zzprime);
fe_sub(zzprime, origxprime, zzprime);
fe_sqr(x3, xxprime);
fe_sqr(zzzprime, zzprime);
- fe_mul(z3, zzzprime, qmqp);
-
- fe_sqr(xx, x);
- fe_sqr(zz, z);
- fe_mul(x2, xx, zz);
- fe_sub(zz, xx, zz);
- fe_mul(zzz, zz, a24);
- fe_add(zzz, zzz, xx);
- fe_mul(z2, zz, zzz);
+ fe_mul(z3, zzzprime, qmqp);
+
+ fe_sqr(xx, x);
+ fe_sqr(zz, z);
+ fe_mul(x2, xx, zz);
+ fe_sub(zz, xx, zz);
+ fe_mul(zzz, zz, a24);
+ fe_add(zzz, zzz, xx);
+ fe_mul(z2, zz, zzz);
// } fmonty
@@ -159,7 +161,8 @@ static void crypto_scalarmult(uint8_t *out, const uint8_t *point, const uint8_t
fe_inv(nqz, nqz);
fe_mul(nqx, nqx, nqz);
- for (size_t i = 0; i < 8*sizeof(modulus); i++) {
+ for (size_t i = 0; i < modulus_bytes; i++) { out[i] = 0; }
+ for (size_t i = 0; i < 8*modulus_bytes; i++) {
mp_limb_t bit = (nqx[i/GMP_LIMB_BITS] >> (i%GMP_LIMB_BITS))&1;
out [i/8] |= bit<<(i%8);
}
@@ -186,14 +189,21 @@ int main() {
a[0] = 1;
for (int i = 0; i < 200; i++) {
- printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", in[i]); }; printf("\n");
- crypto_scalarmult(out, basepoint, in, 254);
+ in[0] &= 248;
+ in[31] &= 127;
+ in[31] |= 64;
+
+ crypto_scalarmult(out, in, 256, basepoint);
uint8_t* t = out;
out = in;
in = t;
}
- printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", out[i]); }; printf("\n");
- printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", expected[i]); }; printf("\n");
+ for (int i = 0; i < 32; i++) {
+ if (in[i] != expected[i]) {
+ return (i+1);
+ }
+ }
+ return 0;
}
}
diff --git a/src/Specific/X25519/C64/scalarmult.c b/src/Specific/X25519/C64/scalarmult.c
index ffe015012..bde9a9b22 100644
--- a/src/Specific/X25519/C64/scalarmult.c
+++ b/src/Specific/X25519/C64/scalarmult.c
@@ -41,6 +41,7 @@ typedef unsigned int uint128_t __attribute__((mode(TI)));
typedef uint8_t u8;
typedef uint64_t limb;
typedef limb felem[5];
+//static void crecip(felem out, const felem z);
static void force_inline
fmul(felem output, const felem in2, const felem in) {
@@ -200,6 +201,7 @@ swap_conditional(limb a[5], limb b[5], limb iswap) {
}
}
+
/* Calculates nQ where Q is the x-coordinate of a point on the curve
*
* resultx/resultz: the x coordinate of the resulting curve point (short form)
@@ -221,6 +223,7 @@ cmult(limb *resultx, limb *resultz, const u8 *n, const limb *q) {
u8 byte = n[31 - i];
for (j = 0; j < 8; ++j) {
const limb bit = byte >> 7;
+ // printf("%01d ", bit);
swap_conditional(nqx, nqpqx, bit);
swap_conditional(nqz, nqpqz, bit);
@@ -246,6 +249,9 @@ cmult(limb *resultx, limb *resultz, const u8 *n, const limb *q) {
nqpqz2 = t;
byte <<= 1;
+
+ // { felem pr; crecip(pr, nqz); fmul(pr, pr, nqx); uint8_t s[32]; fcontract(s, pr); printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", s[i]); }; printf(" "); }
+ // { felem pr; crecip(pr, nqpqz); fmul(pr, pr, nqpqx); uint8_t s[32]; fcontract(s, pr); printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", s[i]); }; printf("\n"); }
}
}
diff --git a/src/Specific/X25519/x25519_test.c b/src/Specific/X25519/x25519_test.c
index 1a4334932..11bdb7acb 100644
--- a/src/Specific/X25519/x25519_test.c
+++ b/src/Specific/X25519/x25519_test.c
@@ -13,6 +13,7 @@ int main() {
a[0] = 1;
for (int i = 0; i < 200; i++) {
+ // printf("0x"); for (int i = 31; i>=0; --i) { printf("%02x", in[i]); }; printf("\n");
crypto_scalarmult(out, in, basepoint);
uint8_t* t = out;
out = in;