diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 166 |
1 files changed, 105 insertions, 61 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index d83263ee9..2630c099a 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -831,20 +831,21 @@ Module Columns. (w_multiples : forall i, w (S i) mod w i = 0) (w_divides : forall i : nat, w (S i) / w i > 0). - (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *) + (* takes in inputs in base w, converts to w', multiplies in that + format, converts to w again, then flattens. *) Definition mul_converted n1 n2 (* lengths in original format *) m1 m2 (* lengths in converted format *) (n3 : nat) (* final length *) + (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) (p1 p2 : list Z) := let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in let p1_a := Positional.to_associational w' m1 p1' in let p2_a := Positional.to_associational w' m2 p2' in - (* - let p3_a := Associational.carry (w' 1%nat) (w 1) (Associational.mul p1_a p2_a) in - *) let p3_a := Associational.mul p1_a p2_a in + (* important not to use Positional.carry here; we don't want to accumulate yet. NOTE(review): the fold_right result bound by the next let is never referenced -- the final expression still passes p3_a to from_associational, so the carries selected by idxs appear to have no effect on the result; confirm which value was intended *) + let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in fst (flatten w (from_associational w n3 p3_a)). Hint Rewrite @@ -854,10 +855,10 @@ Module Columns. @Positional.eval_to_associational @BaseConversion.eval_convert_bases using solve [auto] : push_eval. - Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). 
+ Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). Proof. cbv [mul_converted]; intros. rewrite Columns.flatten_mod by auto using Columns.length_from_associational. @@ -867,10 +868,10 @@ Module Columns. Hint Rewrite @length_from_associational : distr_length. - Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). Proof. intros; cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). @@ -879,13 +880,13 @@ Module Columns. reflexivity. Qed. - Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2: + Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2: m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> length p1 = n1 -> length p2 = n2 -> 0 <= Positional.eval w n1 p1 -> 0 <= Positional.eval w n2 p2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). Proof. intros; subst n3; cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). @@ -898,6 +899,12 @@ Module Columns. rewrite w_multiples. omega. Qed. + (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. 
*) + (* the most important feature here is the carries--we carry from all the odd indices after multiplying, + thus pre-aligning everything with the double-size bitwidth *) + Definition mul_converted_halve n n2 := + mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)). + End mul_converted. End Columns. @@ -1287,6 +1294,8 @@ Module Compilers. | primitive {t:type.primitive} (v : interp t) : ident () t | Let_In {tx tC} : ident (tx * (tx -> tC)) tC | Nat_succ : ident nat nat + | Nat_mul : ident (nat * nat) nat + | Nat_add : ident (nat * nat) nat | nil {t} : ident () (list t) | cons {t} : ident (t * list t) (list t) | fst {A B} : ident (A * B) A @@ -1348,6 +1357,8 @@ Module Compilers. | primitive _ v => curry0 v | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) | Nat_succ => Nat.succ + | Nat_add => curry2 Nat.add + | Nat_mul => curry2 Nat.mul | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -1393,6 +1404,8 @@ Module Compilers. (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) lazymatch term with | Nat.succ ?x => mkAppIdent Nat_succ x + | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) + | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -1546,6 +1559,8 @@ Module Compilers. Module Nat. Notation succ := Nat_succ. + Notation add := Nat_add. + Notation mul := Nat_mul. End Nat. Module Export Notations. @@ -1587,6 +1602,8 @@ Module Compilers. | primitive {t : type.primitive} (v : interp t) : ident () t | Let_In {tx tC} : ident (tx * (tx -> tC)) tC | Nat_succ : ident nat nat + | Nat_add : ident (nat * nat) nat + | Nat_mul : ident (nat * nat) nat | nil {t} : ident () (list t) | cons {t} : ident (t * list t) (list t) | fst {A B} : ident (A * B) A @@ -1643,6 +1660,8 @@ Module Compilers. 
| primitive _ v => curry0 v | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) | Nat_succ => Nat.succ + | Nat_add => curry2 Nat.add + | Nat_mul => curry2 Nat.mul | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -1685,6 +1704,8 @@ Module Compilers. (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) lazymatch term with | Nat.succ ?x => mkAppIdent Nat_succ x + | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) + | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -1800,6 +1821,8 @@ Module Compilers. Module Nat. Notation succ := Nat_succ. + Notation add := Nat_add. + Notation mul := Nat_mul. End Nat. Module Export Notations. @@ -1859,6 +1882,10 @@ Module Compilers. => AppIdent ident.Let_In | for_reification.ident.Nat_succ => AppIdent ident.Nat_succ + | for_reification.ident.Nat_add + => AppIdent ident.Nat_add + | for_reification.ident.Nat_mul + => AppIdent ident.Nat_mul | for_reification.ident.nil t => AppIdent ident.nil | for_reification.ident.cons t @@ -2668,6 +2695,8 @@ Module Compilers. match idc in Uncurried.expr.default.ident s d return type.interp R (type.translate s) -> (type.interp R (type.translate d) -> R) -> R with | ident.primitive _ _ as idc | ident.Nat_succ as idc + | ident.Nat_add as idc + | ident.Nat_mul as idc | ident.pred as idc | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc @@ -2852,6 +2881,13 @@ Module Compilers. 
(ident.snd @@ (Var xyk)) @ ((idc : default.ident _ type.nat) @@ (ident.fst @@ (Var xyk))) + | ident.Nat_add as idc + | ident.Nat_mul as idc + => λ (xyk : + (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) , + (ident.snd @@ (Var xyk)) + @ ((idc : default.ident _ type.nat) + @@ (ident.fst @@ (Var xyk))) | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc | ident.Z_land _ as idc @@ -3596,6 +3632,8 @@ Module Compilers. | inr x => inr (ident.interp idc x) | inl x => expr.reflect (AppIdent idc x) end + | ident.Nat_add as idc + | ident.Nat_mul as idc | ident.Z_pow as idc | ident.Z_eqb as idc | ident.Z_leb as idc @@ -4176,6 +4214,8 @@ Module Compilers. | default.ident.primitive _ _ => None | ident.Let_In tx tC => None | ident.Nat_succ => None + | ident.Nat_add => None + | ident.Nat_mul => None | default.ident.nil (Compilers.type.type_primitive t) => Some (@nil (type.primitive.compile t)) | default.ident.nil _ @@ -5643,20 +5683,22 @@ Module MontgomeryReduction. (w_multiples : forall i, w (S i) mod w i = 0) (w_divides : forall i : nat, w (S i) / w i > 0). Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1). - Context (n:nat) (Hn : n = 2%nat). + Context (n:nat) (Hn: n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in - dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in + dlet_nd y := nth_default 0 (Columns.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in + dlet_nd t1_t2 := Columns.mul_converted_halve w w_half 1%nat n [y] [N] in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in dlet_nd lo'' := fst (Z.sub_get_borrow_full R (fst hi'_carry) y') in Z.add_modulo lo'' 0 N. 
- Local Ltac solve_range H := + Context (Hw : forall i, w i = R ^ Z.of_nat i). + + Local Ltac solve_range := repeat match goal with - | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r + | _ => rewrite Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) | |- 0 <= _ => progress Z.zero_bounds | |- 0 <= _ * _ < _ * _ => @@ -5666,17 +5708,20 @@ Module MontgomeryReduction. | _ => nia end. + Hint Rewrite + Columns.mul_converted_mod Columns.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range]) + : mul_conv. + Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) - (Hw : forall i, w i = R ^ Z.of_nat i) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = reduce_via_partial N R N' T. Proof. rewrite <-reduce_via_partial_alt_eq by nia. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. subst n. - assert (0 <= T mod R * N' < w 2) by (solve_range Hw). - rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw). - rewrite !Columns.mul_converted_div by (auto; solve_range Hw). + assert (0 <= T mod R * N' < w 2) by (solve_range). + cbv [Columns.mul_converted_halve]. cbn. + autorewrite with mul_conv. rewrite Hw, ?Z.pow_1_r. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. @@ -5706,7 +5751,6 @@ Module MontgomeryReduction. Qed. Lemma montred'_correct lo_hi T (HT_range: 0 <= T < R * N) - (Hw : forall i, w i = R ^ Z.of_nat i) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = (T * R') mod N. Proof. erewrite montred'_eq by eauto. @@ -5719,7 +5763,7 @@ Module MontgomeryReduction. Derive montred_gen SuchThat (forall (N R N' : Z) (w w_half : nat -> Z) - (n : nat) + (n: nat) (lo_hi : Z * Z), Interp (t:=type.reify_type_of montred') montred_gen N R N' w w_half n lo_hi @@ -5879,26 +5923,26 @@ Module Montgomery256. 
(* expr_let 3 := (uint128)(fst @@ x_1 >> 128) in expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in - expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in - expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in - expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in - expr_let 12 := ADD_256 @@ (x_5, x_6) in - expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in - expr_let 23 := (uint128)(fst @@ x_13 >> 128) in - expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in - expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in - expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in - expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in - expr_let 32 := ADD_128 @@ (x_25, x_26) in - expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in - expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in - expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in - expr_let 42 := ADD_256 @@ (x_34, x_41) in - expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in - expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in - expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in - ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + expr_let 11 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in + expr_let 12 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in + expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in + expr_let 18 := ADD_256 @@ (x_11, x_12) in + expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in + expr_let 29 := (uint128)(fst @@ x_19 >> 128) in + expr_let 30 := ((uint128)fst @@ 
x_19 & 340282366920938463463374607431768211455) in + expr_let 37 := (uint128)(MUL_256 @@ (x_30, (340282366841710300967557013911933812736)) << 128) in + expr_let 38 := (uint128)(MUL_256 @@ (x_29, (79228162514264337593543950335)) << 128) in + expr_let 43 := MUL_256 @@ (x_30, (79228162514264337593543950335)) in + expr_let 44 := ADD_128 @@ (x_37, x_38) in + expr_let 45 := ADD_256 @@ (x_43, fst @@ x_44) in + expr_let 46 := snd @@ x_45 +₁₂₈ snd @@ x_44 in + expr_let 53 := MUL_256 @@ (x_29, (340282366841710300967557013911933812736)) in + expr_let 54 := ADD_256 @@ (x_46, x_53) in + expr_let 55 := ADD_256 @@ (fst @@ x_1, fst @@ x_45) in + expr_let 56 := ADDC_256 @@ (snd @@ x_55, snd @@ x_1, fst @@ x_54) in + expr_let 57 := SELC @@ (snd @@ x_56, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 58 := fst @@ (SUB_256 @@ (fst @@ x_56, x_57)) in + ADDM @@ (x_58, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256. @@ -6001,25 +6045,25 @@ Print Montgomery256.montred256. 
(* c.ShiftR($r3,$r1_lo, 128); c.Lower128($r4,$r1_lo); -c.Mul128x128($r5, $r4, RegPinv >> 128) << 128; -c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128; -c.Mul128x128($r11, $r4, Lower128{RegPinv}); -c.Add256($r12, $r5, $r6); -c.Add256($r13, $r11, $r12_lo); -c.ShiftR($r23,$r13_lo, 128); -c.Lower128($r24,$r13_lo); -c.Mul128x128($r25, $r24, RegMod << 128) << 128; -c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128; -c.Mul128x128($r31, $r24, Lower128{RegMod}); -c.Add128($r32, $r25, $r26); -c.Add256($r33, $r31, $r32_lo); -c.Add64($r34, $r33_hi, $r32_hi); -c.Mul128x128($r41, $r23, RegMod << 128); -c.Add256($r42, $r34, $r41); -c.Add256($r43, $r1_lo, $r33_lo); -c.Addc($r44, $r1_hi, $r42_lo); -c.Selc($r45,RegZero, RegMod); -c.Sub($r46, $r44_lo, $r45); -c.AddM($ret, $r46, RegZero, RegMod); +c.Mul128x128($r11, $r4, RegPinv >> 128) << 128; +c.Mul128x128($r12, $r3, Lower128{RegPinv}) << 128; +c.Mul128x128($r17, $r4, Lower128{RegPinv}); +c.Add256($r18, $r11, $r12); +c.Add256($r19, $r17, $r18_lo); +c.ShiftR($r29,$r19_lo, 128); +c.Lower128($r30,$r19_lo); +c.Mul128x128($r37, $r30, RegMod << 128) << 128; +c.Mul128x128($r38, $r29, Lower128{RegMod}) << 128; +c.Mul128x128($r43, $r30, Lower128{RegMod}); +c.Add128($r44, $r37, $r38); +c.Add256($r45, $r43, $r44_lo); +c.Add64($r46, $r45_hi, $r44_hi); +c.Mul128x128($r53, $r29, RegMod << 128); +c.Add256($r54, $r46, $r53); +c.Add256($r55, $r1_lo, $r45_lo); +c.Addc($r56, $r1_hi, $r54_lo); +c.Selc($r57,RegZero, RegMod); +c.Sub($r58, $r56_lo, $r57); +c.AddM($ret, $r58, RegZero, RegMod); : expr uint256 *) |