aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-02-27 13:15:34 +0100
committerGravatar Jason Gross <jasongross9@gmail.com>2018-03-07 12:36:29 -0500
commit9e27a38fccfb19ae7a04f3e7826d39eabe861662 (patch)
treed6fbf9b2dd115b2649fb39ced8278490b8ffa220 /src/Experiments/SimplyTypedArithmetic.v
parent503cdadf97e5c436390905adfb8ce6824a58cdbf (diff)
factor out convert-mul-convert and prove correctness
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v407
1 files changed, 326 insertions, 81 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 92cf35eee..c6c12827d 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -137,6 +137,9 @@ Module Positional. Section Positional.
Lemma eval_nil n : eval n [] = 0.
Proof. cbv [eval to_associational]. rewrite combine_nil_r. reflexivity. Qed.
Hint Rewrite eval_nil : push_eval.
+ Lemma eval0 p : eval 0 p = 0.
+ Proof. cbv [eval to_associational]. reflexivity. Qed.
+ Hint Rewrite eval0 : push_eval.
Lemma eval_snoc n m x y : n = length x -> m = S n -> eval m (x ++ [y]) = eval n x + weight n * y.
Proof.
@@ -262,6 +265,8 @@ Module Positional. Section Positional.
(weight (S index) / weight index)
(to_associational n p)).
+ Lemma length_carry n m index p : length (carry n m index p) = m.
+ Proof. cbv [carry]; distr_length. Qed.
Lemma eval_carry n m i p: (n <> 0%nat) -> (m <> 0%nat) ->
weight (S i) / weight i <> 0 ->
eval m (carry n m i p) = eval n p.
@@ -316,6 +321,19 @@ Module Positional. Section Positional.
cbn [fold_right]; distr_length.
Qed. Hint Rewrite @length_chained_carries : distr_length.
+ (* carries without modular reduction; useful for converting between bases *)
+ Definition chained_carries_no_reduce n p (idxs : list nat) :=
+ fold_right (fun a b => carry n n a b) p (rev idxs).
+ Lemma eval_chained_carries_no_reduce n p idxs:
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval n (chained_carries_no_reduce n p idxs) = eval n p.
+ Proof.
+ cbv [chained_carries_no_reduce]; intros.
+ destruct n; [push;reflexivity|].
+ apply fold_right_invariant; [|intro; rewrite <-in_rev];
+ intros; push; auto.
+ Qed. Hint Rewrite @eval_chained_carries_no_reduce : push_eval.
+
(* Reverse of [eval]; translate from Z to basesystem by putting
everything in first digit and then carrying. *)
Definition encode n s c (x : Z) : list Z :=
@@ -329,6 +347,7 @@ Module Positional. Section Positional.
Lemma length_encode n s c x
: length (encode n s c x) = n.
Proof. cbv [encode]; repeat distr_length. Qed.
+
End Carries.
Hint Rewrite @eval_encode : push_eval.
Hint Rewrite @length_encode : distr_length.
@@ -428,6 +447,30 @@ End Positional. End Positional.
Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit.
+Module BaseConversion.
+ Import Positional.
+ Section BaseConversion.
+ Context (sw dw : nat -> Z) (* source/destination weight functions *)
+ (dw_0 : dw 0%nat = 1)
+ (dw_nz : forall i, dw i <> 0).
+ Context (dw_divides : forall i : nat, dw (S i) / dw i > 0).
+
+ Definition convert_bases (sn dn : nat) (p : list Z) : list Z :=
+ let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
+ chained_carries_no_reduce dw dn p' (seq 0 (pred dn)).
+
+ Lemma eval_convert_bases sn dn p :
+ (dn <> 0%nat) -> length p = sn ->
+ eval dw dn (convert_bases sn dn p) = eval sw sn p.
+ Proof.
+ cbv [convert_bases]; intros.
+ rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero.
+ rewrite eval_from_associational; auto.
+ Qed.
+
+ End BaseConversion.
+End BaseConversion.
+
(* Non-CPS version of Arithmetic/Saturated/MulSplit.v *)
Module MulSplit.
Module Associational.
@@ -702,6 +745,71 @@ Module Columns.
rewrite IHp by tauto. ring. }
Qed.
+ Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp).
+ Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed.
+
+ Lemma weight_multiples_full j : forall i, (i <= j)%nat -> weight j mod weight i = 0.
+ Proof.
+ induction j; intros; [replace i with 0%nat by omega
+ | destruct (dec (i <= j)%nat); [ rewrite (Z.div_mod (weight (S j)) (weight j)) by auto
+ | replace i with (S j) by omega ] ];
+ repeat match goal with
+ | _ => rewrite weight_0
+ | _ => rewrite weight_multiples
+ | _ => rewrite IHj by omega
+ | _ => progress autorewrite with push_Zmod zsimplify
+ | _ => reflexivity
+ end.
+ Qed.
+
+ (* TODO: move to ZUtil *)
+ Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b).
+ Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed.
+
+ Lemma flatten_partitions inp:
+ forall n i, length inp = n -> (i < n)%nat ->
+ nth_default 0 (fst (flatten inp)) i = (((eval n inp) / weight i)) mod (weight (S i) / weight i).
+ Proof.
+ induction inp using rev_ind; distr_length; intros.
+ { cbn.
+ autorewrite with push_eval push_nth_default zsimplify.
+ reflexivity. }
+ {
+ destruct n as [| n]; [omega|].
+ rewrite flatten_snoc, eval_snoc by omega.
+ cbv [flatten_step Let_In]. cbn [fst].
+ rewrite nth_default_app.
+ break_match; distr_length.
+ { rewrite IHinp with (n:=n) by omega.
+ rewrite (Z.div_mod (weight n) (weight i)) by auto.
+ rewrite weight_multiples_full by omega.
+ rewrite (Z.div_mod (weight n) (weight (S i))) by auto.
+ rewrite weight_multiples_full by omega.
+ autorewrite with zsimplify.
+ repeat match goal with
+ | _ => rewrite Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto)
+ | |- context [ (_ + ?a * ?b * ?c) / ?a ] =>
+ replace (a * b * c) with (a * (b * c)) by ring;
+ rewrite Z.div_add' by auto
+ | |- context [ (_ + ?a * ?b * ?c) mod ?b ] =>
+ replace (a * b * c) with (a * c * b) by ring;
+ rewrite Z.mod_add by auto using ZUtil.Z.positive_is_nonzero
+ | _ => reflexivity
+ end.
+ }
+ { repeat match goal with
+ | _ => progress replace (Datatypes.length inp) with n by omega
+ | _ => progress replace i with n by omega
+ | _ => rewrite nth_default_cons
+ | _ => rewrite sum_cons
+ | _ => rewrite flatten_column_mod
+ | _ => erewrite flatten_div by eauto
+ | _ => progress autorewrite with natsimplify
+ end.
+ rewrite Z.div_add' by auto.
+ reflexivity. } }
+ Qed.
+
Section mul.
Definition mul s n m (p q : list Z) : list Z :=
let p_a := Positional.to_associational weight n p in
@@ -710,6 +818,207 @@ Module Columns.
fst (flatten (from_associational m pq_a)).
End mul.
End Columns.
+
+ Section mul_converted.
+ Context (w w' : nat -> Z).
+ Context (w'_0 : w' 0%nat = 1)
+ (w'_nonzero : forall i, w' i <> 0)
+ (w'_positive : forall i, w' i > 0)
+ (w'_divides : forall i : nat, w' (S i) / w' i > 0).
+ Context (w_0 : w 0%nat = 1)
+ (w_nonzero : forall i, w i <> 0)
+ (w_positive : forall i, w i > 0)
+ (w_multiples : forall i, w (S i) mod w i = 0)
+ (w_divides : forall i : nat, w (S i) / w i > 0).
+
+ (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *)
+ Definition mul_converted
+ n1 n2 (* lengths in original format *)
+ m1 m2 (* lengths in converted format *)
+ (n3 : nat) (* final length *)
+ (p1 p2 : list Z) :=
+ let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in
+ let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in
+ let p1_a := Positional.to_associational w' m1 p1' in
+ let p2_a := Positional.to_associational w' m2 p2' in
+ (*
+ let p3_a := Associational.carry (w' 1%nat) (w 1) (Associational.mul p1_a p2_a) in
+ *)
+ let p3_a := Associational.mul p1_a p2_a in
+ fst (flatten w (from_associational w n3 p3_a)).
+
+ Hint Rewrite
+ @Columns.eval_from_associational
+ @Associational.eval_carry
+ @Associational.eval_mul
+ @Positional.eval_to_associational
+ @BaseConversion.eval_convert_bases using solve [auto] : push_eval.
+
+ Lemma mul_converted_correct n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ length p1 = n1 -> length p2 = n2 ->
+ 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
+ Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
+ Proof.
+ cbv [mul_converted]; intros.
+ rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
+ autorewrite with push_eval. auto using Z.mod_small.
+ Qed.
+
+ (* TODO: this section specializes to one-element lists in which
+ the intermediate weight is the square root of the old. It would
+ be better to specialize just to the relationship between
+ weights, rather than the size of the input. However, partial
+ reduction/CPS transform seems to take forever when dynamic list
+ allocation is happening. *)
+ Section single.
+ Context (w'_sq : forall i, (w' i) * (w' i) = w i).
+ Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1).
+
+ Derive convert_single
+ SuchThat (forall p, convert_single p = BaseConversion.convert_bases w w' 1 2 [p])
+ As convert_single_correct.
+ Proof.
+ intros.
+ cbv - [Z.add Z.div Z.mul Z.eqb Z.modulo].
+ assert (w 0 mod w' 1 = 1) as P0 by (rewrite w_0, Z.mod_1_l; omega).
+ assert (w' 1 =? 1 = false) as P1 by (apply Z.eqb_neq; omega).
+ assert (1 =? 0 = false) as P2 by reflexivity.
+ repeat match goal with
+ | _ => progress rewrite ?w_0, ?w'_0
+ | _ => progress rewrite ?P0, ?P1, ?P2
+ | _ => progress rewrite ?Z.mod_1_l, ?Z.eqb_refl by omega
+ | _ => progress autorewrite with zsimplify_fast
+ end.
+ autorewrite with zsimplify.
+ reflexivity.
+ Qed.
+
+ Derive mul_converted_single
+ SuchThat (forall (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) ->
+ mul_converted_single p1 p2 = mul_converted 1 1 2 2 2 [p1] [p2])
+ As mul_converted_single_eq.
+ Proof.
+ intros.
+ cbv [mul_converted].
+ rewrite <-!convert_single_correct.
+ cbv [convert_single].
+
+ (*
+ (* assert some things for omega to use later *)
+ rewrite <-(w'_sq 1) in *.
+ pose proof (Z.mod_pos_bound p1 (w' 1) ltac:(auto using Z.gt_lt)).
+ pose proof (Z.mod_pos_bound p2 (w' 1) ltac:(auto using Z.gt_lt)).
+ assert (0 <= p1 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]).
+ assert (0 <= p2 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]).
+ assert (w' 1 < w' 1 * w' 1) by (apply Z.lt_mul_diag_r; omega).
+ assert (w' 1 =? 0 = false) by (apply Z.eqb_neq; omega).
+ assert (1 =? 0 = false) by reflexivity.
+ assert (0 < w' 1 * w' 1) by Z.zero_bounds.
+
+ (* simplify carry *)
+ match goal with |- context [Associational.carry ?w ?fw ?x] =>
+ remember (Associational.carry w fw x) as X eqn:HeqX
+ end.
+ cbv - [Z.modulo Z.div Z.eqb Z.mul app] in HeqX. cbn [app] in HeqX.
+ rewrite w'_0 in HeqX; autorewrite with zsimplify_fast in HeqX.
+ rewrite Z.eqb_refl in HeqX.
+ repeat match type of HeqX with context [if ?x =? ?y then _ else _] =>
+ let H := fresh "H" in
+ case_eq (x =? y); intro H; rewrite H in HeqX;
+ rewrite ?Z.eqb_eq, ?Z.eqb_neq in H; try omega
+ end.
+ cbn [app] in HeqX.
+ rewrite !Z.div_small with (b:= w' 1 * w' 1) in HeqX by nia.
+ rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia.
+ subst X.
+
+ (* simplify from_associational *)
+ match goal with |- context [from_associational ?w ?n ?x] =>
+ remember (from_associational w n x) as X eqn:HeqX
+ end.
+ cbv - [Z.modulo Z.div Z.eqb Z.mul cons_to_nth] in HeqX. cbn [app] in HeqX.
+ rewrite <-w'_sq in HeqX.
+ autorewrite with zsimplify_fast in HeqX.
+ rewrite !Z.mod_1_l in HeqX by omega.
+ rewrite !Z.mod_mul in HeqX by omega.
+ rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia.
+ rewrite Z.eqb_refl in HeqX.
+ repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end.
+ cbv - [Z.modulo Z.div Z.mul] in HeqX.
+ autorewrite with zsimplify in HeqX.
+ subst X.
+
+ (* simplify flatten *)
+ match goal with |- context [flatten ?w ?x] =>
+ remember (flatten w x) as X eqn:HeqX
+ end.
+ cbn in HeqX.
+ cbv [flatten_step] in HeqX. cbn in HeqX.
+ autorewrite with to_div_mod in HeqX.
+ cbn [fst snd] in HeqX.
+ rewrite w_0 in HeqX.
+ autorewrite with zsimplify in HeqX.
+ Check Z.div_small.
+ match type of HeqX with context [
+
+ cbv [Let_In] in HeqX.
+ autorewrite with to_div_mod in HeqX.
+ cbn [fst snd] in HeqX.
+ cbv - [flatten_column Z.div Z.modulo Z.mul] in HeqX.
+ cbv [flatten_step] in HeqX.
+ cbv - [Z.modulo Z.div Z.eqb Z.mul Z.add_get_carry_full Z.add fst snd] in HeqX. cbn [app] in HeqX.
+ rewrite <-w'_sq in HeqX.
+ autorewrite with zsimplify_fast in HeqX.
+ rewrite !Z.mod_1_l in HeqX by omega.
+ rewrite !Z.mod_mul in HeqX by omega.
+ rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia.
+ rewrite Z.eqb_refl in HeqX.
+ repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end.
+ cbv - [Z.modulo Z.div Z.mul] in HeqX.
+ autorewrite with zsimplify in HeqX.
+ subst X.
+ *)
+
+ subst mul_converted_single.
+ reflexivity.
+ Qed.
+
+ Lemma eval_mul_converted_single p1 p2 (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w 2) :
+ Positional.eval w 2 (mul_converted_single p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]).
+ Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed.
+
+ Hint Rewrite @length_from_associational : distr_length.
+
+ Lemma mul_converted_single_mod x y :
+ 0 <= x < w 1 -> 0 <= y < w 1 ->
+ nth_default 0 (mul_converted_single x y) 0 = (x * y) mod (w 1).
+ Proof.
+ intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted].
+ erewrite flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval. cbn.
+ rewrite w_0; autorewrite with zsimplify.
+ reflexivity.
+ Qed.
+
+ Lemma mul_converted_single_div x y :
+ 0 <= x < w 1 -> 0 <= y < w 1 ->
+ 0 <= x * y < w 2 ->
+ nth_default 0 (mul_converted_single x y) 1 = (x * y) / (w 1).
+ Proof.
+ intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted].
+ erewrite flatten_partitions by (auto; distr_length).
+ autorewrite with distr_length push_eval. cbn.
+ rewrite w_0; autorewrite with zsimplify.
+ apply Z.mod_small.
+ split.
+ { apply Z.div_nonneg; auto; omega. }
+ { apply Z.div_lt_upper_bound. omega.
+ rewrite Z.mul_div_eq_full by auto.
+ rewrite w_multiples. omega. }
+ Qed.
+
+ End single.
+ End mul_converted.
End Columns.
Module Compilers.
@@ -1006,9 +1315,9 @@ Module Compilers.
| false
=>
let rT := type.reify T in
- let not_x := refresh x ltac:(fun n => fresh n) in
- let not_x2 := refresh not_x ltac:(fun n => fresh n) in
- let not_x3 := refresh not_x2 ltac:(fun n => fresh n) in
+ let not_x := fresh in
+ let not_x2 := fresh in
+ let not_x3 := fresh in
(*let dummy := match goal with _ => idtac "reify_in_context: λ case:" term "using vars:" not_x not_x2 not_x3 end in*)
let rf0 :=
constr:(
@@ -5438,81 +5747,16 @@ Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo.
Module MontgomeryReduction.
Section MontRed'.
Context (N R N' R' : Z).
- Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0)
+ Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1)
(N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1).
- Section mul_converted.
- Context (w w' : nat -> Z).
- Context (w'_sq : forall i, (w' i) * (w' i) = w i).
- Context (w'_0 : w' 0%nat = 1)
- (w'_positive : forall i, w' i > 0).
- Context (w_0 : w 0%nat = 1)
- (w_nonzero : forall i, w i <> 0)
- (w_positive : forall i, w i > 0)
- (w_multiples : forall i, w (S i) mod w i = 0)
- (w_divides : forall i : nat, w (S i) / w i > 0).
- Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1).
-
- (*
- (* TODO: get a version of convert-multiply-convert strategy working in
- general form, not specialized to one-element lists, and add it to
- arithmetic development in an appropriate place. May need to
- specialize it to the case where (forall i, (w' i)^2 = w i) in
- order for base conversion to simplify as expected. *)
- Definition chained_carries_noreduce weight n p idxs : list Z :=
- fold_right (fun a b => carry weight n n a b) p (rev idxs).
- Definition convert_bases (sw dw : nat -> Z) (sn dn : nat) (p : list Z) : list Z :=
- let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
- chained_carries_noreduce dw dn p' (seq 0 dn).
-
- (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *)
- Definition mul_converted
- w w' (* two different weight functions, initial/final and intermediate *)
- n1 n2 (* lengths in original format *)
- m1 m2 (* lengths in converted format *)
- (n3 : nat) (* final length *)
- (p1 p2 : list Z) :=
- let p1' := convert_bases w w' n1 m1 p1 in
- let p2' := convert_bases w w' n2 m2 p2 in
- let p1_a := Positional.to_associational w' m1 p1' in
- let p2_a := Positional.to_associational w' m2 p2' in
- let p3_a := Associational.mul p1_a p2_a in
- fst (Columns.flatten w (Columns.from_associational w n3 p3_a)).
- *)
-
- (* specialized version equivalent to [test_mul w w' 1 1 2 2 2 [p1] [p2] *)
- (* takes in 2 1-digit inputs in base w, produces a 2-digit output--same spec as mul_split *)
- Definition mul_converted_single (w w' : nat ->Z) (p1 p2 : Z) :=
- let p1' := [p1 mod w' 1%nat; p1 / w' 1%nat] in
- let p2' := [p2 mod w' 1%nat; p2 / w' 1%nat] in
- let p1_a := Positional.to_associational w' 2 p1' in
- let p2_a := Positional.to_associational w' 2 p2' in
- let p3_a := Associational.mul p1_a p2_a in
- fst (Columns.flatten w (Columns.from_associational w 2 p3_a)).
-
- Lemma mul_converted_single_eq p1 p2 :
- (0 <= p1 * p2 < (w 2)) ->
- mul_converted_single w w' p1 p2 = [(p1 * p2) mod (w 1); (p1 * p2) / (w 1) ].
- Proof.
- Admitted.
- Lemma mul_converted_single_correct p1 p2 :
- Positional.eval w 2 (mul_converted_single w w' p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]) mod (w 2).
- Proof.
- intros. cbv [mul_converted_single].
- rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
- rewrite Columns.eval_from_associational by auto.
- rewrite Associational.eval_mul.
- cbv [Positional.eval Positional.to_associational Associational.eval].
- simpl [map seq combine fold_right]. rewrite w_0, w'_0.
- rewrite !Z.mul_div_eq by auto.
- f_equal; ring.
- Qed.
- End mul_converted.
-
Context (w w_half : nat -> Z).
Context (w_half_sq : forall i, (w_half i) * (w_half i) = w i).
Context (w_half_0 : w_half 0%nat = 1)
- (w_half_positive : forall i, w_half i > 0).
+ (w_half_nonzero : forall i, w_half i <> 0)
+ (w_half_positive : forall i, w_half i > 0)
+ (w_half_multiples : forall i, w_half (S i) mod w_half i = 0)
+ (w_half_divides : forall i : nat, w_half (S i) / w_half i > 0).
Context (w_0 : w 0%nat = 1)
(w_nonzero : forall i, w i <> 0)
(w_positive : forall i, w i > 0)
@@ -5521,8 +5765,8 @@ Module MontgomeryReduction.
Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1).
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (mul_converted_single w w_half (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := mul_converted_single w w_half y N in
+ dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half (fst lo_hi) N') 0 in
+ dlet_nd t1_t2 := Columns.mul_converted_single w w_half y N in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -5531,10 +5775,11 @@ Module MontgomeryReduction.
Local Ltac solve_range H :=
repeat match goal with
- | _ => rewrite H, ?Z.pow_1_r, ?Z.pow_2_r
+ | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r
| |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega))
| |- 0 <= _ * _ < _ * _ =>
split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ]
+ | _ => solve [auto]
end.
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
@@ -5546,10 +5791,10 @@ Module MontgomeryReduction.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
rewrite Hlo, Hhi.
assert (0 <= T mod R * N' < w 2) by (solve_range Hw).
- rewrite !mul_converted_single_eq
- by (rewrite ?mul_converted_single_eq; try assumption; cbv [nth_default nth_error]; solve_range Hw).
+ rewrite !Columns.mul_converted_single_mod;
+ (auto; rewrite ?Columns.mul_converted_single_mod; solve_range Hw).
+ rewrite !Columns.mul_converted_single_div by (auto; solve_range Hw).
rewrite Hw, ?Z.pow_1_r.
- cbv [nth_default nth_error].
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
(* pull out value before last modular reduction *)
@@ -5760,7 +6005,7 @@ Module Montgomery256.
expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in
expr_let 37 := ADD_256 @@ (x_29, x_36) in
expr_let 39 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in
- expr_let 40 := ADDC_256 @@ (fst @@ x_39, snd @@ x_1, fst @@ x_37) in
+ expr_let 40 := ADDC_256 @@ (snd @@ x_39, snd @@ x_1, fst @@ x_37) in
expr_let 41 := SELC @@ (snd @@ x_40, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
expr_let 42 := fst @@ (SUB_256 @@ (fst @@ x_40, x_41)) in
ADDM @@ (x_42, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
@@ -5879,4 +6124,4 @@ c.Add256($r10, $r8, $r9_lo);
c.Sub($r42, $r40_lo, $r41);
c.AddM($ret, $r42, RegZero, RegMod);)))
: expr uint256
-*) \ No newline at end of file
+ *)