diff options
author | 2018-02-28 11:54:20 +0100 | |
---|---|---|
committer | 2018-03-07 12:36:29 -0500 | |
commit | 497cda884a5816fc0a955e637ce666768f28417f (patch) | |
tree | 3675d0e7d60e884c5e4bd543765788734d66a3bc | |
parent | cddfdafe2fb7187c5a124927ff1c44eeb0b1211d (diff) |
remove special-case convert-mul-convert implementation and use generalized one in Montgomery example
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 230 |
1 files changed, 100 insertions, 130 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 9a5d12be5..d83263ee9 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -854,7 +854,7 @@ Module Columns. @Positional.eval_to_associational @BaseConversion.eval_convert_bases using solve [auto] : push_eval. - Lemma mul_converted_correct n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). @@ -863,85 +863,41 @@ Module Columns. rewrite Columns.flatten_mod by auto using Columns.length_from_associational. autorewrite with push_eval. auto using Z.mod_small. Qed. + Hint Rewrite eval_mul_converted : push_eval. + + Hint Rewrite @length_from_associational : distr_length. + + Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). + Proof. + intros; cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval natsimplify. + rewrite w_0; autorewrite with zsimplify. + reflexivity. + Qed. + + Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2: + m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> + length p1 = n1 -> length p2 = n2 -> + 0 <= Positional.eval w n1 p1 -> + 0 <= Positional.eval w n2 p2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). + Proof. + intros; subst n3; cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval. + pose proof (w_positive 1). + apply Z.mod_small. + split; [ solve[Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; [omega|]. + rewrite Z.mul_div_eq_full by auto. + rewrite w_multiples. omega. + Qed. - (* TODO: this section specializes to one-element lists in which - the intermediate weight is the square root of the old. It would - be better to specialize just to the relationship between - weights, rather than the size of the input. However, partial - reduction/CPS transform seems to take forever when dynamic list - allocation is happening. *) - Section single. - Context (w'_sq : forall i, (w' i) * (w' i) = w i). - Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1). - - Derive convert_single - SuchThat (forall p, convert_single p = BaseConversion.convert_bases w w' 1 2 [p]) - As convert_single_correct. - Proof. - intros. - cbv - [Z.add Z.div Z.mul Z.eqb Z.modulo]. - assert (w 0 mod w' 1 = 1) as P0 by (rewrite w_0, Z.mod_1_l; omega). - assert (w' 1 =? 1 = false) as P1 by (apply Z.eqb_neq; omega). - assert (1 =? 0 = false) as P2 by reflexivity. - repeat match goal with - | _ => progress rewrite ?w_0, ?w'_0 - | _ => progress rewrite ?P0, ?P1, ?P2 - | _ => progress rewrite ?Z.mod_1_l, ?Z.eqb_refl by omega - | _ => progress autorewrite with zsimplify_fast - end. - autorewrite with zsimplify. - reflexivity. - Qed. - - Derive mul_converted_single - SuchThat (forall n (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) -> - mul_converted_single n p1 p2 = mul_converted 1 1 2 2 n [p1] [p2]) - As mul_converted_single_eq. - Proof. - intros. - cbv [mul_converted]. - rewrite <-!convert_single_correct. - cbv [convert_single]. - subst mul_converted_single. - reflexivity. - Qed. - - Lemma eval_mul_converted_single n p1 p2 (_: n <> 0%nat) (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w n) : - Positional.eval w n (mul_converted_single n p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]). - Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed. - - Hint Rewrite @length_from_associational : distr_length. - - Lemma mul_converted_single_mod n x y : - n = 2%nat -> 0 <= x < w 1 -> 0 <= y < w 1 -> - nth_default 0 (mul_converted_single n x y) 0 = (x * y) mod (w 1). - Proof. - intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval. cbn. - rewrite w_0; autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma mul_converted_single_div n x y : - n = 2%nat -> - 0 <= x < w 1 -> 0 <= y < w 1 -> - 0 <= x * y < w 2 -> - nth_default 0 (mul_converted_single n x y) 1 = (x * y) / (w 1). - Proof. - intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval. cbn. - rewrite w_0; autorewrite with zsimplify. - apply Z.mod_small. - split. - { apply Z.div_nonneg; auto; omega. } - { apply Z.div_lt_upper_bound. omega. - rewrite Z.mul_div_eq_full by auto. - rewrite w_multiples. omega. } - Qed. - - End single. End mul_converted. End Columns. @@ -5640,7 +5596,7 @@ Module RemoveDeadLets. | Let_In s T n x f => Let_In n (inline_let idx _ new _ x) (inline_let idx _ new _ f) end. - (* inlines lets that just re-bind a variable or half a variable with type prod *) + (* inlines lets that just re-bind a variable or the output of a specified operation on a single variable *) Fixpoint inline_silly_lets t (e : @expr ident t) : @expr ident t := match e in (expr t') return expr t' with | Var T n => Var T n @@ -5652,7 +5608,7 @@ Module RemoveDeadLets. match x with | Var T' m => inline_let n _ (Var T' m) _ f | AppIdent _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m) => - inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) + inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) | _ => Let_In n (inline_silly_lets _ x) (inline_silly_lets _ f) end end. @@ -5690,8 +5646,8 @@ Module MontgomeryReduction. Context (n:nat) (Hn : n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half n (fst lo_hi) N') 0 in - dlet_nd t1_t2 := Columns.mul_converted_single w w_half n y N in + dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in + dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -5702,9 +5658,12 @@ Module MontgomeryReduction. repeat match goal with | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) + | |- 0 <= _ => progress Z.zero_bounds | |- 0 <= _ * _ < _ * _ => split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] | _ => solve [auto] + | _ => cbn + | _ => nia end. Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) @@ -5714,11 +5673,10 @@ Module MontgomeryReduction. Proof. rewrite <-reduce_via_partial_alt_eq by nia. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. - rewrite Hlo, Hhi. + rewrite Hlo, Hhi. subst n. assert (0 <= T mod R * N' < w 2) by (solve_range Hw). - rewrite !Columns.mul_converted_single_mod; - (auto; rewrite ?Columns.mul_converted_single_mod; solve_range Hw). - rewrite !Columns.mul_converted_single_div by (auto; solve_range Hw). + rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw). + rewrite !Columns.mul_converted_div by (auto; solve_range Hw). rewrite Hw, ?Z.pow_1_r. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. @@ -5730,7 +5688,9 @@ Module MontgomeryReduction. |- context [if R * R <=? ?x then _ else _] => match goal with |- context [if dec (?xHigh / R = 0) then _ else _] => assert (x / R = xHigh) as cond_equiv end end. - { apply Z.mul_cancel_r with (p:=R); [omega|]. autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. } + { apply Z.mul_cancel_r with (p:=R); [omega|]. cbn. + rewrite w_0. autorewrite with zsimplify_fast. + autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. } rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega. assert (0 < R * R)%Z by Z.zero_bounds. @@ -5917,26 +5877,28 @@ Module Montgomery256. Open Scope nexpr_scope. Print montred256. (* - expr_let 2 := (uint128)(MUL_256 @@ - (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (340282366841710300986003757985643364352)) << 128) in - expr_let 3 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_1 >> 128), (79228162514264337593543950337)) << 128) in - expr_let 8 := MUL_256 @@ (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (79228162514264337593543950337)) in - expr_let 9 := ADD_256 @@ (x_2, x_3) in - expr_let 10 := ADD_256 @@ (x_8, fst @@ x_9) in - expr_let 20 := (uint128)(MUL_256 @@ - (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (340282366841710300967557013911933812736)) << 128) in - expr_let 21 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (79228162514264337593543950335)) << 128) in - expr_let 26 := MUL_256 @@ (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (79228162514264337593543950335)) in - expr_let 27 := ADD_128 @@ (x_20, x_21) in - expr_let 28 := ADD_256 @@ (x_26, fst @@ x_27) in - expr_let 29 := snd @@ x_28 +₁₂₈ snd @@ x_27 in - expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in - expr_let 37 := ADD_256 @@ (x_29, x_36) in - expr_let 38 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in - expr_let 39 := ADDC_256 @@ (snd @@ x_38, snd @@ x_1, fst @@ x_37) in - expr_let 40 := SELC @@ (snd @@ x_39, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 41 := fst @@ (SUB_256 @@ (fst @@ x_39, x_40)) in - ADDM @@ (x_41, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + expr_let 3 := (uint128)(fst @@ x_1 >> 128) in + expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in + expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in + expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in + expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in + expr_let 12 := ADD_256 @@ (x_5, x_6) in + expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in + expr_let 23 := (uint128)(fst @@ x_13 >> 128) in + expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in + expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in + expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in + expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in + expr_let 32 := ADD_128 @@ (x_25, x_26) in + expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in + expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in + expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in + expr_let 42 := ADD_256 @@ (x_34, x_41) in + expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in + expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in + expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in + ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256. @@ -6014,18 +5976,22 @@ Module Montgomery256PrintingNotations. f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope. + Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" := + (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + Notation "'c.Lower128(' '$r' n ',' x ');' f" := + (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope. Notation "'Lower128'" := ((land uint256 uint128 340282366920938463463374607431768211455)) (at level 10, only printing, format "Lower128") : nexpr_scope. - Notation "( v >> count )" - := ((shiftr _ _ count @@ v)%nexpr) - (format "( v >> count )") - : nexpr_scope. Notation "( v << count )" := ((shiftl _ _ count @@ v)%nexpr) (format "( v << count )") : nexpr_scope. + Notation "( x >> count )" + := ((shiftr _ _ count @@ x)%nexpr) + (format "( x >> count )") + : nexpr_scope. End Montgomery256PrintingNotations. Import Montgomery256PrintingNotations. @@ -6033,23 +5999,27 @@ Local Open Scope nexpr_scope. Print Montgomery256.montred256. (* -c.Mul128x128($r2, Lower128 @@ $r1_lo, RegPinv >> 128) << 128; -c.Mul128x128($r3, ($r1_lo >> 128), Lower128{RegPinv}) << 128; -c.Mul128x128($r8, Lower128 @@ $r1_lo, Lower128{RegPinv}); -c.Add256($r9, $r2, $r3); -c.Add256($r10, $r8, $r9_lo); -c.Mul128x128($r20, Lower128 @@ $r10_lo, RegMod << 128) << 128; -c.Mul128x128($r21, ($r10_lo >> 128), Lower128{RegMod}) << 128; -c.Mul128x128($r26, Lower128 @@ $r10_lo, Lower128{RegMod}); -c.Add128($r27, $r20, $r21); -c.Add256($r28, $r26, $r27_lo); -c.Add64($r29, $r28_hi, $r27_hi); -c.Mul128x128($r36, ($r10_lo >> 128), RegMod << 128); -c.Add256($r37, $r29, $r36); -c.Add256($r38, $r1_lo, $r28_lo); -c.Addc($r39, $r1_hi, $r37_lo); -c.Selc($r40,RegZero, RegMod); -c.Sub($r41, $r39_lo, $r40); -c.AddM($ret, $r41, RegZero, RegMod); +c.ShiftR($r3,$r1_lo, 128); +c.Lower128($r4,$r1_lo); +c.Mul128x128($r5, $r4, RegPinv >> 128) << 128; +c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128; +c.Mul128x128($r11, $r4, Lower128{RegPinv}); +c.Add256($r12, $r5, $r6); +c.Add256($r13, $r11, $r12_lo); +c.ShiftR($r23,$r13_lo, 128); +c.Lower128($r24,$r13_lo); +c.Mul128x128($r25, $r24, RegMod << 128) << 128; +c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128; +c.Mul128x128($r31, $r24, Lower128{RegMod}); +c.Add128($r32, $r25, $r26); +c.Add256($r33, $r31, $r32_lo); +c.Add64($r34, $r33_hi, $r32_hi); +c.Mul128x128($r41, $r23, RegMod << 128); +c.Add256($r42, $r34, $r41); +c.Add256($r43, $r1_lo, $r33_lo); +c.Addc($r44, $r1_hi, $r42_lo); +c.Selc($r45,RegZero, RegMod); +c.Sub($r46, $r44_lo, $r45); +c.AddM($ret, $r46, RegZero, RegMod); : expr uint256 *) |