aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-02-28 16:45:34 +0100
committerGravatar Jason Gross <jasongross9@gmail.com>2018-03-07 12:36:29 -0500
commitdf5b34e2b9ea79f897a7a7b3d78e83edd6806cdd (patch)
treee3db4660c459cac7d3933358a7e401cc8f75d3b1
parent497cda884a5816fc0a955e637ce666768f28417f (diff)
make Montgomery do associational carries in a generalized way
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v166
1 files changed, 105 insertions, 61 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index d83263ee9..2630c099a 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -831,20 +831,21 @@ Module Columns.
(w_multiples : forall i, w (S i) mod w i = 0)
(w_divides : forall i : nat, w (S i) / w i > 0).
- (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *)
+ (* takes in inputs in base w, converts to w', multiplies in that
+ format, converts to w again, then flattens. *)
Definition mul_converted
n1 n2 (* lengths in original format *)
m1 m2 (* lengths in converted format *)
(n3 : nat) (* final length *)
+ (idxs : list nat) (* carries to do -- this helps preemptively line up weights *)
(p1 p2 : list Z) :=
let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in
let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in
let p1_a := Positional.to_associational w' m1 p1' in
let p2_a := Positional.to_associational w' m2 p2' in
- (*
- let p3_a := Associational.carry (w' 1%nat) (w 1) (Associational.mul p1_a p2_a) in
- *)
let p3_a := Associational.mul p1_a p2_a in
+ (* important not to use Positional.carry here; we don't want to accumulate yet *)
+ let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in
fst (flatten w (from_associational w n3 p3_a)).
Hint Rewrite
@@ -854,10 +855,10 @@ Module Columns.
@Positional.eval_to_associational
@BaseConversion.eval_convert_bases using solve [auto] : push_eval.
- Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
+ Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
Proof.
cbv [mul_converted]; intros.
rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
@@ -867,10 +868,10 @@ Module Columns.
Hint Rewrite @length_from_associational : distr_length.
- Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
Proof.
intros; cbv [mul_converted].
erewrite flatten_partitions by (auto; distr_length).
@@ -879,13 +880,13 @@ Module Columns.
reflexivity.
Qed.
- Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2:
+ Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2:
m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat ->
length p1 = n1 -> length p2 = n2 ->
0 <= Positional.eval w n1 p1 ->
0 <= Positional.eval w n2 p2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
Proof.
intros; subst n3; cbv [mul_converted].
erewrite flatten_partitions by (auto; distr_length).
@@ -898,6 +899,12 @@ Module Columns.
rewrite w_multiples. omega.
Qed.
+ (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *)
+ (* the most important feature here is the carries--we carry from all the odd indices after multiplying,
+ thus pre-aligning everything with the double-size bitwidth *)
+ Definition mul_converted_halve n n2 :=
+ mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)).
+
End mul_converted.
End Columns.
@@ -1287,6 +1294,8 @@ Module Compilers.
| primitive {t:type.primitive} (v : interp t) : ident () t
| Let_In {tx tC} : ident (tx * (tx -> tC)) tC
| Nat_succ : ident nat nat
+ | Nat_mul : ident (nat * nat) nat
+ | Nat_add : ident (nat * nat) nat
| nil {t} : ident () (list t)
| cons {t} : ident (t * list t) (list t)
| fst {A B} : ident (A * B) A
@@ -1348,6 +1357,8 @@ Module Compilers.
| primitive _ v => curry0 v
| Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC))
| Nat_succ => Nat.succ
+ | Nat_add => curry2 Nat.add
+ | Nat_mul => curry2 Nat.mul
| nil t => curry0 (@Datatypes.nil (type.interp t))
| cons t => curry2 (@Datatypes.cons (type.interp t))
| fst A B => @Datatypes.fst (type.interp A) (type.interp B)
@@ -1393,6 +1404,8 @@ Module Compilers.
(*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*)
lazymatch term with
| Nat.succ ?x => mkAppIdent Nat_succ x
+ | Nat.add ?x ?y => mkAppIdent Nat_add (x, y)
+ | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y)
| S ?x => mkAppIdent Nat_succ x
| @Datatypes.nil ?T
=> let rT := type.reify T in
@@ -1546,6 +1559,8 @@ Module Compilers.
Module Nat.
Notation succ := Nat_succ.
+ Notation add := Nat_add.
+ Notation mul := Nat_mul.
End Nat.
Module Export Notations.
@@ -1587,6 +1602,8 @@ Module Compilers.
| primitive {t : type.primitive} (v : interp t) : ident () t
| Let_In {tx tC} : ident (tx * (tx -> tC)) tC
| Nat_succ : ident nat nat
+ | Nat_add : ident (nat * nat) nat
+ | Nat_mul : ident (nat * nat) nat
| nil {t} : ident () (list t)
| cons {t} : ident (t * list t) (list t)
| fst {A B} : ident (A * B) A
@@ -1643,6 +1660,8 @@ Module Compilers.
| primitive _ v => curry0 v
| Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC))
| Nat_succ => Nat.succ
+ | Nat_add => curry2 Nat.add
+ | Nat_mul => curry2 Nat.mul
| nil t => curry0 (@Datatypes.nil (type.interp t))
| cons t => curry2 (@Datatypes.cons (type.interp t))
| fst A B => @Datatypes.fst (type.interp A) (type.interp B)
@@ -1685,6 +1704,8 @@ Module Compilers.
(*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*)
lazymatch term with
| Nat.succ ?x => mkAppIdent Nat_succ x
+ | Nat.add ?x ?y => mkAppIdent Nat_add (x, y)
+ | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y)
| S ?x => mkAppIdent Nat_succ x
| @Datatypes.nil ?T
=> let rT := type.reify T in
@@ -1800,6 +1821,8 @@ Module Compilers.
Module Nat.
Notation succ := Nat_succ.
+ Notation add := Nat_add.
+ Notation mul := Nat_mul.
End Nat.
Module Export Notations.
@@ -1859,6 +1882,10 @@ Module Compilers.
=> AppIdent ident.Let_In
| for_reification.ident.Nat_succ
=> AppIdent ident.Nat_succ
+ | for_reification.ident.Nat_add
+ => AppIdent ident.Nat_add
+ | for_reification.ident.Nat_mul
+ => AppIdent ident.Nat_mul
| for_reification.ident.nil t
=> AppIdent ident.nil
| for_reification.ident.cons t
@@ -2668,6 +2695,8 @@ Module Compilers.
match idc in Uncurried.expr.default.ident s d return type.interp R (type.translate s) -> (type.interp R (type.translate d) -> R) -> R with
| ident.primitive _ _ as idc
| ident.Nat_succ as idc
+ | ident.Nat_add as idc
+ | ident.Nat_mul as idc
| ident.pred as idc
| ident.Z_shiftr _ as idc
| ident.Z_shiftl _ as idc
@@ -2852,6 +2881,13 @@ Module Compilers.
(ident.snd @@ (Var xyk))
@ ((idc : default.ident _ type.nat)
@@ (ident.fst @@ (Var xyk)))
+ | ident.Nat_add as idc
+ | ident.Nat_mul as idc
+ => λ (xyk :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) ,
+ (ident.snd @@ (Var xyk))
+ @ ((idc : default.ident _ type.nat)
+ @@ (ident.fst @@ (Var xyk)))
| ident.Z_shiftr _ as idc
| ident.Z_shiftl _ as idc
| ident.Z_land _ as idc
@@ -3596,6 +3632,8 @@ Module Compilers.
| inr x => inr (ident.interp idc x)
| inl x => expr.reflect (AppIdent idc x)
end
+ | ident.Nat_add as idc
+ | ident.Nat_mul as idc
| ident.Z_pow as idc
| ident.Z_eqb as idc
| ident.Z_leb as idc
@@ -4176,6 +4214,8 @@ Module Compilers.
| default.ident.primitive _ _ => None
| ident.Let_In tx tC => None
| ident.Nat_succ => None
+ | ident.Nat_add => None
+ | ident.Nat_mul => None
| default.ident.nil (Compilers.type.type_primitive t)
=> Some (@nil (type.primitive.compile t))
| default.ident.nil _
@@ -5643,20 +5683,22 @@ Module MontgomeryReduction.
(w_multiples : forall i, w (S i) mod w i = 0)
(w_divides : forall i : nat, w (S i) / w i > 0).
Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1).
- Context (n:nat) (Hn : n = 2%nat).
+ Context (n:nat) (Hn: n = 2%nat).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in
- dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in
+ dlet_nd y := nth_default 0 (Columns.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in
+ dlet_nd t1_t2 := Columns.mul_converted_halve w w_half 1%nat n [y] [N] in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
dlet_nd lo'' := fst (Z.sub_get_borrow_full R (fst hi'_carry) y') in
Z.add_modulo lo'' 0 N.
- Local Ltac solve_range H :=
+ Context (Hw : forall i, w i = R ^ Z.of_nat i).
+
+ Local Ltac solve_range :=
repeat match goal with
- | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r
+ | _ => rewrite Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r
| |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega))
| |- 0 <= _ => progress Z.zero_bounds
| |- 0 <= _ * _ < _ * _ =>
@@ -5666,17 +5708,20 @@ Module MontgomeryReduction.
| _ => nia
end.
+ Hint Rewrite
+ Columns.mul_converted_mod Columns.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range])
+ : mul_conv.
+
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
- (Hw : forall i, w i = R ^ Z.of_nat i)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R):
montred' lo_hi = reduce_via_partial N R N' T.
Proof.
rewrite <-reduce_via_partial_alt_eq by nia.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
rewrite Hlo, Hhi. subst n.
- assert (0 <= T mod R * N' < w 2) by (solve_range Hw).
- rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw).
- rewrite !Columns.mul_converted_div by (auto; solve_range Hw).
+ assert (0 <= T mod R * N' < w 2) by (solve_range).
+ cbv [Columns.mul_converted_halve]. cbn.
+ autorewrite with mul_conv.
rewrite Hw, ?Z.pow_1_r.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
@@ -5706,7 +5751,6 @@ Module MontgomeryReduction.
Qed.
Lemma montred'_correct lo_hi T (HT_range: 0 <= T < R * N)
- (Hw : forall i, w i = R ^ Z.of_nat i)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = (T * R') mod N.
Proof.
erewrite montred'_eq by eauto.
@@ -5719,7 +5763,7 @@ Module MontgomeryReduction.
Derive montred_gen
SuchThat (forall (N R N' : Z)
(w w_half : nat -> Z)
- (n : nat)
+ (n: nat)
(lo_hi : Z * Z),
Interp (t:=type.reify_type_of montred')
montred_gen N R N' w w_half n lo_hi
@@ -5879,26 +5923,26 @@ Module Montgomery256.
(*
expr_let 3 := (uint128)(fst @@ x_1 >> 128) in
expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in
- expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in
- expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in
- expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in
- expr_let 12 := ADD_256 @@ (x_5, x_6) in
- expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in
- expr_let 23 := (uint128)(fst @@ x_13 >> 128) in
- expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in
- expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in
- expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in
- expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in
- expr_let 32 := ADD_128 @@ (x_25, x_26) in
- expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in
- expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in
- expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in
- expr_let 42 := ADD_256 @@ (x_34, x_41) in
- expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in
- expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in
- expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
- expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in
- ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
+ expr_let 11 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in
+ expr_let 12 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in
+ expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in
+ expr_let 18 := ADD_256 @@ (x_11, x_12) in
+ expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in
+ expr_let 29 := (uint128)(fst @@ x_19 >> 128) in
+ expr_let 30 := ((uint128)fst @@ x_19 & 340282366920938463463374607431768211455) in
+ expr_let 37 := (uint128)(MUL_256 @@ (x_30, (340282366841710300967557013911933812736)) << 128) in
+ expr_let 38 := (uint128)(MUL_256 @@ (x_29, (79228162514264337593543950335)) << 128) in
+ expr_let 43 := MUL_256 @@ (x_30, (79228162514264337593543950335)) in
+ expr_let 44 := ADD_128 @@ (x_37, x_38) in
+ expr_let 45 := ADD_256 @@ (x_43, fst @@ x_44) in
+ expr_let 46 := snd @@ x_45 +₁₂₈ snd @@ x_44 in
+ expr_let 53 := MUL_256 @@ (x_29, (340282366841710300967557013911933812736)) in
+ expr_let 54 := ADD_256 @@ (x_46, x_53) in
+ expr_let 55 := ADD_256 @@ (fst @@ x_1, fst @@ x_45) in
+ expr_let 56 := ADDC_256 @@ (snd @@ x_55, snd @@ x_1, fst @@ x_54) in
+ expr_let 57 := SELC @@ (snd @@ x_56, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
+ expr_let 58 := fst @@ (SUB_256 @@ (fst @@ x_56, x_57)) in
+ ADDM @@ (x_58, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
: expr uint256
*)
End Montgomery256.
@@ -6001,25 +6045,25 @@ Print Montgomery256.montred256.
(*
c.ShiftR($r3,$r1_lo, 128);
c.Lower128($r4,$r1_lo);
-c.Mul128x128($r5, $r4, RegPinv >> 128) << 128;
-c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128;
-c.Mul128x128($r11, $r4, Lower128{RegPinv});
-c.Add256($r12, $r5, $r6);
-c.Add256($r13, $r11, $r12_lo);
-c.ShiftR($r23,$r13_lo, 128);
-c.Lower128($r24,$r13_lo);
-c.Mul128x128($r25, $r24, RegMod << 128) << 128;
-c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128;
-c.Mul128x128($r31, $r24, Lower128{RegMod});
-c.Add128($r32, $r25, $r26);
-c.Add256($r33, $r31, $r32_lo);
-c.Add64($r34, $r33_hi, $r32_hi);
-c.Mul128x128($r41, $r23, RegMod << 128);
-c.Add256($r42, $r34, $r41);
-c.Add256($r43, $r1_lo, $r33_lo);
-c.Addc($r44, $r1_hi, $r42_lo);
-c.Selc($r45,RegZero, RegMod);
-c.Sub($r46, $r44_lo, $r45);
-c.AddM($ret, $r46, RegZero, RegMod);
+c.Mul128x128($r11, $r4, RegPinv >> 128) << 128;
+c.Mul128x128($r12, $r3, Lower128{RegPinv}) << 128;
+c.Mul128x128($r17, $r4, Lower128{RegPinv});
+c.Add256($r18, $r11, $r12);
+c.Add256($r19, $r17, $r18_lo);
+c.ShiftR($r29,$r19_lo, 128);
+c.Lower128($r30,$r19_lo);
+c.Mul128x128($r37, $r30, RegMod << 128) << 128;
+c.Mul128x128($r38, $r29, Lower128{RegMod}) << 128;
+c.Mul128x128($r43, $r30, Lower128{RegMod});
+c.Add128($r44, $r37, $r38);
+c.Add256($r45, $r43, $r44_lo);
+c.Add64($r46, $r45_hi, $r44_hi);
+c.Mul128x128($r53, $r29, RegMod << 128);
+c.Add256($r54, $r46, $r53);
+c.Add256($r55, $r1_lo, $r45_lo);
+c.Addc($r56, $r1_hi, $r54_lo);
+c.Selc($r57,RegZero, RegMod);
+c.Sub($r58, $r56_lo, $r57);
+c.AddM($ret, $r58, RegZero, RegMod);
: expr uint256
*)