aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-02-28 11:54:20 +0100
committerGravatar Jason Gross <jasongross9@gmail.com>2018-03-07 12:36:29 -0500
commit497cda884a5816fc0a955e637ce666768f28417f (patch)
tree3675d0e7d60e884c5e4bd543765788734d66a3bc
parentcddfdafe2fb7187c5a124927ff1c44eeb0b1211d (diff)
remove special-case convert-mul-convert implementation and use generalized one in Montgomery example
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v230
1 files changed, 100 insertions, 130 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 9a5d12be5..d83263ee9 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -854,7 +854,7 @@ Module Columns.
@Positional.eval_to_associational
@BaseConversion.eval_convert_bases using solve [auto] : push_eval.
- Lemma mul_converted_correct n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
@@ -863,85 +863,41 @@ Module Columns.
rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
autorewrite with push_eval. auto using Z.mod_small.
Qed.
+ Hint Rewrite eval_mul_converted : push_eval.
+
+ Hint Rewrite @length_from_associational : distr_length.
+
+ Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
+ Proof.
+ intros; cbv [mul_converted].
+ erewrite flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval natsimplify.
+ rewrite w_0; autorewrite with zsimplify.
+ reflexivity.
+ Qed.
+
+ Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2:
+ m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat ->
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= Positional.eval w n1 p1 ->
+ 0 <= Positional.eval w n2 p2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
+ Proof.
+ intros; subst n3; cbv [mul_converted].
+ erewrite flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval.
+ pose proof (w_positive 1).
+ apply Z.mod_small.
+ split; [ solve[Z.zero_bounds] | ].
+ apply Z.div_lt_upper_bound; [omega|].
+ rewrite Z.mul_div_eq_full by auto.
+ rewrite w_multiples. omega.
+ Qed.
- (* TODO: this section specializes to one-element lists in which
- the intermediate weight is the square root of the old. It would
- be better to specialize just to the relationship between
- weights, rather than the size of the input. However, partial
- reduction/CPS transform seems to take forever when dynamic list
- allocation is happening. *)
- Section single.
- Context (w'_sq : forall i, (w' i) * (w' i) = w i).
- Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1).
-
- Derive convert_single
- SuchThat (forall p, convert_single p = BaseConversion.convert_bases w w' 1 2 [p])
- As convert_single_correct.
- Proof.
- intros.
- cbv - [Z.add Z.div Z.mul Z.eqb Z.modulo].
- assert (w 0 mod w' 1 = 1) as P0 by (rewrite w_0, Z.mod_1_l; omega).
- assert (w' 1 =? 1 = false) as P1 by (apply Z.eqb_neq; omega).
- assert (1 =? 0 = false) as P2 by reflexivity.
- repeat match goal with
- | _ => progress rewrite ?w_0, ?w'_0
- | _ => progress rewrite ?P0, ?P1, ?P2
- | _ => progress rewrite ?Z.mod_1_l, ?Z.eqb_refl by omega
- | _ => progress autorewrite with zsimplify_fast
- end.
- autorewrite with zsimplify.
- reflexivity.
- Qed.
-
- Derive mul_converted_single
- SuchThat (forall n (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) ->
- mul_converted_single n p1 p2 = mul_converted 1 1 2 2 n [p1] [p2])
- As mul_converted_single_eq.
- Proof.
- intros.
- cbv [mul_converted].
- rewrite <-!convert_single_correct.
- cbv [convert_single].
- subst mul_converted_single.
- reflexivity.
- Qed.
-
- Lemma eval_mul_converted_single n p1 p2 (_: n <> 0%nat) (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w n) :
- Positional.eval w n (mul_converted_single n p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]).
- Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed.
-
- Hint Rewrite @length_from_associational : distr_length.
-
- Lemma mul_converted_single_mod n x y :
- n = 2%nat -> 0 <= x < w 1 -> 0 <= y < w 1 ->
- nth_default 0 (mul_converted_single n x y) 0 = (x * y) mod (w 1).
- Proof.
- intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted].
- erewrite flatten_partitions by (auto; distr_length).
- autorewrite with distr_length push_eval. cbn.
- rewrite w_0; autorewrite with zsimplify.
- reflexivity.
- Qed.
-
- Lemma mul_converted_single_div n x y :
- n = 2%nat ->
- 0 <= x < w 1 -> 0 <= y < w 1 ->
- 0 <= x * y < w 2 ->
- nth_default 0 (mul_converted_single n x y) 1 = (x * y) / (w 1).
- Proof.
- intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted].
- erewrite flatten_partitions by (auto; distr_length).
- autorewrite with distr_length push_eval. cbn.
- rewrite w_0; autorewrite with zsimplify.
- apply Z.mod_small.
- split.
- { apply Z.div_nonneg; auto; omega. }
- { apply Z.div_lt_upper_bound. omega.
- rewrite Z.mul_div_eq_full by auto.
- rewrite w_multiples. omega. }
- Qed.
-
- End single.
End mul_converted.
End Columns.
@@ -5640,7 +5596,7 @@ Module RemoveDeadLets.
| Let_In s T n x f => Let_In n (inline_let idx _ new _ x) (inline_let idx _ new _ f)
end.
- (* inlines lets that just re-bind a variable or half a variable with type prod *)
+ (* inlines lets that just re-bind a variable or the output of a specified operation on a single variable *)
Fixpoint inline_silly_lets t (e : @expr ident t) : @expr ident t :=
match e in (expr t') return expr t' with
| Var T n => Var T n
@@ -5652,7 +5608,7 @@ Module RemoveDeadLets.
match x with
| Var T' m => inline_let n _ (Var T' m) _ f
| AppIdent _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m) =>
- inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f)
+ inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f)
| _ => Let_In n (inline_silly_lets _ x) (inline_silly_lets _ f)
end
end.
@@ -5690,8 +5646,8 @@ Module MontgomeryReduction.
Context (n:nat) (Hn : n = 2%nat).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half n (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := Columns.mul_converted_single w w_half n y N in
+ dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in
+ dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -5702,9 +5658,12 @@ Module MontgomeryReduction.
repeat match goal with
| _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r
| |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega))
+ | |- 0 <= _ => progress Z.zero_bounds
| |- 0 <= _ * _ < _ * _ =>
split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ]
| _ => solve [auto]
+ | _ => cbn
+ | _ => nia
end.
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
@@ -5714,11 +5673,10 @@ Module MontgomeryReduction.
Proof.
rewrite <-reduce_via_partial_alt_eq by nia.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
- rewrite Hlo, Hhi.
+ rewrite Hlo, Hhi. subst n.
assert (0 <= T mod R * N' < w 2) by (solve_range Hw).
- rewrite !Columns.mul_converted_single_mod;
- (auto; rewrite ?Columns.mul_converted_single_mod; solve_range Hw).
- rewrite !Columns.mul_converted_single_div by (auto; solve_range Hw).
+ rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw).
+ rewrite !Columns.mul_converted_div by (auto; solve_range Hw).
rewrite Hw, ?Z.pow_1_r.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
@@ -5730,7 +5688,9 @@ Module MontgomeryReduction.
|- context [if R * R <=? ?x then _ else _] =>
match goal with |- context [if dec (?xHigh / R = 0) then _ else _] =>
assert (x / R = xHigh) as cond_equiv end end.
- { apply Z.mul_cancel_r with (p:=R); [omega|]. autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. }
+ { apply Z.mul_cancel_r with (p:=R); [omega|]. cbn.
+ rewrite w_0. autorewrite with zsimplify_fast.
+ autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. }
rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega.
assert (0 < R * R)%Z by Z.zero_bounds.
@@ -5917,26 +5877,28 @@ Module Montgomery256.
Open Scope nexpr_scope.
Print montred256.
(*
- expr_let 2 := (uint128)(MUL_256 @@
- (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (340282366841710300986003757985643364352)) << 128) in
- expr_let 3 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_1 >> 128), (79228162514264337593543950337)) << 128) in
- expr_let 8 := MUL_256 @@ (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (79228162514264337593543950337)) in
- expr_let 9 := ADD_256 @@ (x_2, x_3) in
- expr_let 10 := ADD_256 @@ (x_8, fst @@ x_9) in
- expr_let 20 := (uint128)(MUL_256 @@
- (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (340282366841710300967557013911933812736)) << 128) in
- expr_let 21 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (79228162514264337593543950335)) << 128) in
- expr_let 26 := MUL_256 @@ (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (79228162514264337593543950335)) in
- expr_let 27 := ADD_128 @@ (x_20, x_21) in
- expr_let 28 := ADD_256 @@ (x_26, fst @@ x_27) in
- expr_let 29 := snd @@ x_28 +₁₂₈ snd @@ x_27 in
- expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in
- expr_let 37 := ADD_256 @@ (x_29, x_36) in
- expr_let 38 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in
- expr_let 39 := ADDC_256 @@ (snd @@ x_38, snd @@ x_1, fst @@ x_37) in
- expr_let 40 := SELC @@ (snd @@ x_39, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
- expr_let 41 := fst @@ (SUB_256 @@ (fst @@ x_39, x_40)) in
- ADDM @@ (x_41, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
+ expr_let 3 := (uint128)(fst @@ x_1 >> 128) in
+ expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in
+ expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in
+ expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in
+ expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in
+ expr_let 12 := ADD_256 @@ (x_5, x_6) in
+ expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in
+ expr_let 23 := (uint128)(fst @@ x_13 >> 128) in
+ expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in
+ expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in
+ expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in
+ expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in
+ expr_let 32 := ADD_128 @@ (x_25, x_26) in
+ expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in
+ expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in
+ expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in
+ expr_let 42 := ADD_256 @@ (x_34, x_41) in
+ expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in
+ expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in
+ expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
+ expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in
+ ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
: expr uint256
*)
End Montgomery256.
@@ -6014,18 +5976,22 @@ Module Montgomery256PrintingNotations.
f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope.
Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" :=
(add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope.
+ Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" :=
+ (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
+ Notation "'c.Lower128(' '$r' n ',' x ');' f" :=
+ (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope.
Notation "'Lower128'"
:= ((land uint256 uint128 340282366920938463463374607431768211455))
(at level 10, only printing, format "Lower128")
: nexpr_scope.
- Notation "( v >> count )"
- := ((shiftr _ _ count @@ v)%nexpr)
- (format "( v >> count )")
- : nexpr_scope.
Notation "( v << count )"
:= ((shiftl _ _ count @@ v)%nexpr)
(format "( v << count )")
: nexpr_scope.
+ Notation "( x >> count )"
+ := ((shiftr _ _ count @@ x)%nexpr)
+ (format "( x >> count )")
+ : nexpr_scope.
End Montgomery256PrintingNotations.
Import Montgomery256PrintingNotations.
@@ -6033,23 +5999,27 @@ Local Open Scope nexpr_scope.
Print Montgomery256.montred256.
(*
-c.Mul128x128($r2, Lower128 @@ $r1_lo, RegPinv >> 128) << 128;
-c.Mul128x128($r3, ($r1_lo >> 128), Lower128{RegPinv}) << 128;
-c.Mul128x128($r8, Lower128 @@ $r1_lo, Lower128{RegPinv});
-c.Add256($r9, $r2, $r3);
-c.Add256($r10, $r8, $r9_lo);
-c.Mul128x128($r20, Lower128 @@ $r10_lo, RegMod << 128) << 128;
-c.Mul128x128($r21, ($r10_lo >> 128), Lower128{RegMod}) << 128;
-c.Mul128x128($r26, Lower128 @@ $r10_lo, Lower128{RegMod});
-c.Add128($r27, $r20, $r21);
-c.Add256($r28, $r26, $r27_lo);
-c.Add64($r29, $r28_hi, $r27_hi);
-c.Mul128x128($r36, ($r10_lo >> 128), RegMod << 128);
-c.Add256($r37, $r29, $r36);
-c.Add256($r38, $r1_lo, $r28_lo);
-c.Addc($r39, $r1_hi, $r37_lo);
-c.Selc($r40,RegZero, RegMod);
-c.Sub($r41, $r39_lo, $r40);
-c.AddM($ret, $r41, RegZero, RegMod);
+c.ShiftR($r3,$r1_lo, 128);
+c.Lower128($r4,$r1_lo);
+c.Mul128x128($r5, $r4, RegPinv >> 128) << 128;
+c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128;
+c.Mul128x128($r11, $r4, Lower128{RegPinv});
+c.Add256($r12, $r5, $r6);
+c.Add256($r13, $r11, $r12_lo);
+c.ShiftR($r23,$r13_lo, 128);
+c.Lower128($r24,$r13_lo);
+c.Mul128x128($r25, $r24, RegMod << 128) << 128;
+c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128;
+c.Mul128x128($r31, $r24, Lower128{RegMod});
+c.Add128($r32, $r25, $r26);
+c.Add256($r33, $r31, $r32_lo);
+c.Add64($r34, $r33_hi, $r32_hi);
+c.Mul128x128($r41, $r23, RegMod << 128);
+c.Add256($r42, $r34, $r41);
+c.Add256($r43, $r1_lo, $r33_lo);
+c.Addc($r44, $r1_hi, $r42_lo);
+c.Selc($r45,RegZero, RegMod);
+c.Sub($r46, $r44_lo, $r45);
+c.AddM($ret, $r46, RegZero, RegMod);
: expr uint256
*)