aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-09 10:06:10 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-04-09 10:06:10 +0200
commit76f2195bfbb31caccac87203f65e4538f3a0dafc (patch)
tree632e119f01f09fbb49e2593e6e99e701c9664aa9
parent3e779870f34a39bcb5c43d8d2f2c1749ac830575 (diff)
reorganization: move more things into BaseConversion
m---------coqprime0
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v321
2 files changed, 161 insertions, 160 deletions
diff --git a/coqprime b/coqprime
-Subproject 59e3bf69a84c593ad733b83dbcfa90036f5d052
+Subproject bd626ee330cc28aadfc2d675772f5077b098f71
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index dd52e96f4..8be01866a 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -595,30 +595,6 @@ Section mod_ops.
Qed.
End mod_ops.
-Module BaseConversion.
- Import Positional.
- Section BaseConversion.
- Context (sw dw : nat -> Z) (* source/destination weight functions *)
- (dw_0 : dw 0%nat = 1)
- (dw_nz : forall i, dw i <> 0).
- Context (dw_divides : forall i : nat, dw (S i) / dw i > 0).
-
- Definition convert_bases (sn dn : nat) (p : list Z) : list Z :=
- let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
- chained_carries_no_reduce dw dn p' (seq 0 (pred dn)).
-
- Lemma eval_convert_bases sn dn p :
- (dn <> 0%nat) -> length p = sn ->
- eval dw dn (convert_bases sn dn p) = eval sw sn p.
- Proof.
- cbv [convert_bases]; intros.
- rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero.
- rewrite eval_from_associational; auto.
- Qed.
-
- End BaseConversion.
-End BaseConversion.
-
Module Saturated.
Section Weight.
Context (weight : nat->Z)
@@ -640,7 +616,7 @@ Module Saturated.
| _ => reflexivity
end.
Qed.
-
+
Lemma weight_multiples_full j i : (i <= j)%nat -> weight j mod weight i = 0.
Proof.
intros; replace j with (i + (j - i))%nat by omega.
@@ -714,7 +690,7 @@ Module Saturated.
rewrite Z.mod_pull_div, Z.mul_div_eq' by auto using Z.gt_lt.
repeat (f_equal; try omega).
Qed.
-
+
Lemma add_mod_l_multiple a b n m:
0 < n / m -> m <> 0 -> n mod m = 0 ->
(a mod n + b) mod m = (a + b) mod m.
@@ -727,7 +703,7 @@ Module Saturated.
Qed.
Definition is_div_mod {T} (evalf : T -> Z) dm y n :=
- evalf (fst dm) = y mod n /\ snd dm = y / n.
+ evalf (fst dm) = y mod n /\ snd dm = y / n.
Lemma is_div_mod_step {T} evalf1 evalf2 dm1 dm2 y1 y2 n1 n2 x :
n1 > 0 ->
@@ -842,14 +818,14 @@ Module Columns.
| _ => progress autorewrite with pull_Zmod pull_Zdiv zsimplify_fast
| _ => progress autorewrite with list distr_length push_eval
end.
-
+
Lemma flatten_column_mod fw (xs : list Z) :
fst (flatten_column fw xs) = sum xs mod fw.
Proof.
induction xs; simpl flatten_column; cbv [Let_In];
repeat match goal with
| _ => rewrite IHxs
- | _ => progress push
+ | _ => progress push
end.
Qed. Hint Rewrite flatten_column_mod : to_div_mod.
@@ -964,7 +940,7 @@ Module Columns.
Hint Rewrite Positional.eval_zeros : push_eval.
Hint Rewrite Positional.length_from_associational : distr_length.
Hint Rewrite Positional.eval_add_to_nth using (solve [distr_length]): push_eval.
-
+
(* from_associational *)
Definition from_associational n (p:list (Z*Z)) : list (list Z) :=
List.fold_right (fun t ls =>
@@ -1062,7 +1038,7 @@ Module Rows.
| _ => progress distr_length
| _ => rewrite Positional.eval_snoc with (n:=n) by distr_length
| _ => progress autorewrite with cancel_pair push_eval push_map in *
- | _ => ring
+ | _ => ring
end.
rewrite IHinp by distr_length.
destruct x; cbn [hd tl]; rewrite ?sum_nil, ?sum_cons; ring.
@@ -1072,7 +1048,7 @@ Module Rows.
length inp = n -> length (fst (extract_row inp)) = n.
Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
Hint Rewrite length_fst_extract_row : distr_length.
-
+
Lemma length_snd_extract_row n (inp : cols) :
length inp = n -> length (snd (extract_row inp)) = n.
Proof. cbv [extract_row]; autorewrite with cancel_pair; distr_length. Qed.
@@ -1084,7 +1060,7 @@ Module Rows.
(* TODO: move to where list is defined *)
Hint Rewrite @app_nil_l : list.
Hint Rewrite <-@app_comm_cons: list.
-
+
Lemma max_column_size_nil : max_column_size nil = 0%nat.
Proof. reflexivity. Qed. Hint Rewrite max_column_size_nil : push_max_column_size.
Lemma max_column_size_cons col (inp : cols) :
@@ -1142,7 +1118,7 @@ Module Rows.
Proof. apply eval_from_columns'_with_length. reflexivity. Qed.
Hint Rewrite length_snd_from_columns' : distr_length.
Lemma eval_from_columns' m st n :
- (length (fst st) = n) ->
+ (length (fst st) = n) ->
eval n (snd (from_columns' m st)) = Columns.eval weight n (fst st) + eval n (snd st)
- Columns.eval weight n (fst (from_columns' m st)).
Proof. apply eval_from_columns'_with_length. Qed.
@@ -1191,7 +1167,7 @@ Module Rows.
(* from associational *)
Definition from_associational n (p : list (Z * Z)) := from_columns (Columns.from_associational weight n p).
-
+
Lemma eval_from_associational n p: (n <> 0%nat \/ p = nil) ->
eval n (from_associational n p) = Associational.eval p.
Proof.
@@ -1253,7 +1229,7 @@ Module Rows.
rewrite <-nth_default_eq in *.
autorewrite with push_nth_default in *.
rewrite eq_nat_dec_refl in *.
- congruence.
+ congruence.
Qed.
Lemma from_associational_nonnil n p :
@@ -1334,7 +1310,7 @@ Module Rows.
apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length in *; try omega.
eapply is_div_mod_step with (x := x1 + x2); try eassumption; push.
Qed.
-
+
Lemma sum_rows_div_mod n row1 row2 :
length row1 = n -> length row2 = n ->
let eval := Positional.eval weight in
@@ -1373,7 +1349,7 @@ Module Rows.
= ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
Proof.
induction row1 as [|x1 row1]; destruct row2 as [|x2 row2]; intros; subst nm; push; [].
-
+
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
apply IHrow1; clear IHrow1; push;
repeat match goal with
@@ -1390,7 +1366,7 @@ Module Rows.
push. rewrite add_mod_div_multiple by auto using Z.lt_le_incl.
push. }
Qed.
-
+
Lemma sum_rows_partitions row1: forall row2 n i,
length row1 = n -> length row2 = n -> (i < n)%nat ->
nth_default 0 (fst (sum_rows row1 row2)) i
@@ -1493,7 +1469,7 @@ Module Rows.
(forall row, In row inp -> length row = n) ->
length (fst (flatten' start_state inp)) = n.
Proof. apply flatten'_div_mod_length. Qed.
- Hint Rewrite length_flatten' : distr_length.
+ Hint Rewrite length_flatten' : distr_length.
Lemma length_flatten n inp :
(forall row, In row inp -> length row = n) ->
@@ -1551,8 +1527,8 @@ Module Rows.
match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ =>
rewrite <-H by omega end.
{ autorewrite with push_nth_default natsimplify. reflexivity. }
- { autorewrite with push_nth_default natsimplify.
- break_match; omega. }
+ { autorewrite with push_nth_default natsimplify.
+ break_match; omega. }
Qed.
Lemma flatten_partitions' inp n :
@@ -1595,31 +1571,96 @@ Module Rows.
End Rows.
End Rows.
-Module MulConverted.
- Section mul_converted.
- Context (w w' : nat -> Z).
- Context (w'_0 : w' 0%nat = 1)
- (w'_nonzero : forall i, w' i <> 0)
- (w'_positive : forall i, w' i > 0)
- (w'_divides : forall i : nat, w' (S i) / w' i > 0).
- Context (w_0 : w 0%nat = 1)
- (w_nonzero : forall i, w i <> 0)
- (w_positive : forall i, w i > 0)
- (w_multiples : forall i, w (S i) mod w i = 0)
- (w_divides : forall i : nat, w (S i) / w i > 0).
-
- (* TODO : move this stuff to BaseConversion *)
- Definition to_associational n m p : list (Z * Z) :=
- let p' := BaseConversion.convert_bases w w' n m p in
- Positional.to_associational w' m p'.
-
- Definition from_associational idxs n (p : list (Z * Z)) : list Z :=
- (* important not to use Positional.carry here; we don't want to accumulate yet *)
- let p' := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p (rev idxs) in
- fst (Rows.flatten w n (Rows.from_associational w n p')).
-
- (* takes in inputs in base w, converts to w', multiplies in that
- format, converts to w again, then flattens. *)
+Module BaseConversion.
+ Import Positional.
+ Section BaseConversion.
+ Context (sw dw : nat -> Z) (* source/destination weight functions *)
+ (dw_0 : dw 0%nat = 1)
+ (sw_0 : sw 0%nat = 1)
+ (dw_nz : forall i, dw i <> 0)
+ (sw_nz : forall i, sw i <> 0)
+ (sw_pos : forall i, sw i > 0)
+ (sw_multiples : forall i, sw (S i) mod sw i = 0)
+ (sw_divides : forall i, sw (S i) / sw i > 0)
+ (dw_divides : forall i, dw (S i) / dw i > 0).
+
+ Definition convert_bases (sn dn : nat) (p : list Z) : list Z :=
+ let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in
+ chained_carries_no_reduce dw dn p' (seq 0 (pred dn)).
+
+ Lemma eval_convert_bases sn dn p :
+ (dn <> 0%nat) -> length p = sn ->
+ eval dw dn (convert_bases sn dn p) = eval sw sn p.
+ Proof.
+ cbv [convert_bases]; intros.
+ rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero.
+ rewrite eval_from_associational; auto.
+ Qed.
+
+ Hint Rewrite
+ @Rows.eval_from_associational
+ @Associational.eval_carry
+ @Associational.eval_mul
+ @Positional.eval_to_associational
+ @eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
+
+ Ltac push_eval := intros; autorewrite with push_eval; auto with zarith.
+
+ (* convert from positional in one weight to the other, then to associational *)
+ Definition to_associational n m p : list (Z * Z) :=
+ let p' := convert_bases n m p in
+ Positional.to_associational dw m p'.
+
+ (* carry at specified indices in dw, then use Rows.flatten to convert to Positional with sw *)
+ Definition from_associational idxs n (p : list (Z * Z)) : list Z :=
+ (* important not to use Positional.carry here; we don't want to accumulate yet *)
+ let p' := fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p (rev idxs) in
+ fst (Rows.flatten sw n (Rows.from_associational sw n p')).
+
+ Lemma eval_carries p idxs :
+ Associational.eval (fold_right (fun i acc => Associational.carry (dw i) (dw (S i) / dw i) acc) p idxs) =
+ Associational.eval p.
+ Proof. apply fold_right_invariant; push_eval. Qed.
+ Hint Rewrite eval_carries: push_eval.
+
+ Lemma eval_to_associational n m p :
+ m <> 0%nat -> length p = n ->
+ Associational.eval (to_associational n m p) = Positional.eval sw n p.
+ Proof. cbv [to_associational]; push_eval. Qed.
+ Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval.
+
+ Lemma eval_from_associational idxs n p :
+ n <> 0%nat -> 0 <= Associational.eval p < sw n ->
+ Positional.eval sw n (from_associational idxs n p) = Associational.eval p.
+ Proof.
+ cbv [from_associational]; intros.
+ rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
+ push_eval.
+ Qed.
+ Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval.
+
+ Lemma from_associational_partitions n idxs p (_:n<>0%nat):
+ forall i, (i < n)%nat ->
+ nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (sw (S i)) / sw i.
+ Proof.
+ intros; cbv [from_associational].
+ rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega).
+ push_eval.
+ Qed.
+
+ Lemma from_associational_eq n idxs p (_:n<>0%nat):
+ from_associational idxs n p = Rows.partition sw n (Associational.eval p).
+ Proof.
+ intros. cbv [from_associational].
+ rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational.
+ push_eval.
+ Qed.
+
+ (* carry chain that aligns terms in the intermediate weight with the final weight *)
+ Definition aligned_carries (log_dw_sw nout : nat)
+ := (map (fun i => ((log_dw_sw * (i + 1)) - 1))%nat (seq 0 nout)).
+
+ Section mul_converted.
Definition mul_converted
n1 n2 (* lengths in original format *)
m1 m2 (* lengths in converted format *)
@@ -1631,104 +1672,63 @@ Module MulConverted.
let p3_a := Associational.mul p1_a p2_a in
from_associational idxs n3 p3_a.
- Hint Rewrite
- @Rows.eval_from_associational
- @Associational.eval_carry
- @Associational.eval_mul
- @Positional.eval_to_associational
- @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
-
- Ltac push_eval := intros; autorewrite with push_eval; auto with zarith.
-
- Lemma eval_carries p idxs :
- Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) =
- Associational.eval p.
- Proof. apply fold_right_invariant; push_eval. Qed.
- Hint Rewrite eval_carries: push_eval.
-
- Lemma eval_to_associational n m p :
- m <> 0%nat -> length p = n ->
- Associational.eval (to_associational n m p) = Positional.eval w n p.
- Proof. cbv [to_associational]; push_eval. Qed.
- Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval.
-
- Lemma eval_from_associational idxs n p :
- n <> 0%nat -> 0 <= Associational.eval p < w n ->
- Positional.eval w n (from_associational idxs n p) = Associational.eval p.
- Proof.
- cbv [from_associational]; intros.
- rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
- push_eval.
- Qed.
- Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval.
-
Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
+ 0 <= (Positional.eval sw n1 p1 * Positional.eval sw n2 p2) < sw n3 ->
+ Positional.eval sw n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval sw n1 p1) * (Positional.eval sw n2 p2).
Proof. cbv [mul_converted]; push_eval. Qed.
Hint Rewrite eval_mul_converted : push_eval.
- Lemma from_associational_partitions n idxs p (_:n<>0%nat):
- forall i, (i < n)%nat ->
- nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (w (S i)) / w i.
- Proof.
- intros; cbv [from_associational].
- rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega).
- push_eval.
- Qed.
-
Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
- forall i, (i < n3)%nat ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) i = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w (S i)) / w i.
- Proof.
- intros; cbv [mul_converted].
- rewrite from_associational_partitions by auto. push_eval.
- Qed.
-
- Lemma from_associational_eq n idxs p (_:n<>0%nat):
- from_associational idxs n p = Rows.partition w n (Associational.eval p).
- Proof.
- intros. cbv [from_associational].
- rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational.
- push_eval.
- Qed.
-
- (* TODO: convert all _partitions proofs to this form? *)
- Lemma mul_converted_eq n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
- length p1 = n1 -> length p2 = n2 ->
- mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition w n3 (Positional.eval w n1 p1 * Positional.eval w n2 p2).
+ mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition sw n3 (Positional.eval sw n1 p1 * Positional.eval sw n2 p2).
Proof.
intros; cbv [mul_converted].
rewrite from_associational_eq by auto. push_eval.
Qed.
+ End mul_converted.
+ End BaseConversion.
- Section aligned.
- Context (log_w'_w : nat) (Hlog_w'_w : forall i, (w' i) ^ Z.of_nat log_w'_w = w i).
- Context (n n' nout : nat) (Hn : n <> 0%nat) (Hn' : n' <> 0%nat) (Hnout : nout <> 0%nat).
-
- (* carry chain that aligns terms in the intermediate weight with the final weight *)
- Definition aligned_carries := (map (fun i => ((log_w'_w * (i + 1)) - 1))%nat (seq 0 nout)).
-
- Definition from_associational_aligned := from_associational aligned_carries nout.
-
- (* assumes both inputs have the same length *)
- Definition mul_aligned := mul_converted n n n' n' nout aligned_carries.
-
- Lemma mul_aligned_eq p1 p2 :
- length p1 = n -> length p2 = n ->
- mul_aligned p1 p2 = Rows.partition w nout (Positional.eval w n p1 * Positional.eval w n p2).
- Proof. cbv [mul_aligned aligned_carries]. auto using mul_converted_eq. Qed.
-
- Lemma eval_mul_aligned p1 p2 :
- length p1 = n -> length p2 = n ->
- 0 <= Positional.eval w n p1 * Positional.eval w n p2 < w nout ->
- Positional.eval w nout (mul_aligned p1 p2) = Positional.eval w n p1 * Positional.eval w n p2.
- Proof. cbv [mul_aligned aligned_carries]. auto using eval_mul_converted. Qed.
- End aligned.
- End mul_converted.
-End MulConverted.
+ (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *)
+ Section widemul.
+ Context (dw : nat -> Z) (n : nat) (n_nz : n <> 0%nat).
+ Context (dw_0 : dw 0%nat = 1)
+ (dw_pos : forall i, dw i > 0)
+ (dw_multiples : forall i, dw (S i) mod dw i = 0)
+ (dw_divides : forall i, dw (S i) / dw i > 0).
+ Context (nout : nat) (nout_nz : nout <> 0%nat). (* this is always 2, but reification has trouble if it's a constant *)
+ Let sw i := (dw i) ^ Z.of_nat n.
+
+ Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg.
+ (* TODO : There has got to be a cleaner way to do weight functions *)
+ Lemma sw_0 : sw 0 = 1.
+ Proof. subst sw; cbv beta. rewrite dw_0. auto using Z.pow_1_l. Qed.
+ Lemma sw_pos : forall i, sw i > 0.
+ Proof. subst sw; intros; apply Z.lt_gt. Z.zero_bounds. Qed.
+ Lemma sw_nz : forall i, sw i <> 0.
+ Proof. auto using sw_pos. Qed.
+ Lemma sw_multiples : forall i, sw (S i) mod sw i = 0.
+ Proof.
+ subst sw; cbv beta; intros.
+ rewrite Saturated.weight_div_mod with (weight := dw) (j:=S i) (i:=i) by auto.
+ rewrite Z.pow_mul_l.
+ push_Zmod. autorewrite with zsimplify. reflexivity.
+ Qed.
+ Lemma sw_divides : forall i, sw (S i) / sw i > 0.
+ Proof. intros; apply Z.div_positive_gt_0; auto using sw_pos, sw_multiples. Qed.
+ Hint Resolve sw_0 sw_pos sw_nz sw_multiples sw_divides.
+
+ Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b].
+
+ Lemma widemul_partitions a b :
+ widemul a b = Rows.partition sw nout (a * b).
+ Proof.
+ cbv [widemul].
+ rewrite mul_converted_partitions by auto with zarith.
+ cbn; rewrite sw_0; ring_simplify_subterms. reflexivity.
+ Qed.
+ End widemul.
+End BaseConversion.
Module Import MOVEME.
Fixpoint fold_andb_map {A B} (f : A -> B -> bool) (ls1 : list A) (ls2 : list B)
@@ -7810,6 +7810,7 @@ Module MontgomeryReduction.
:= weight_multiples _ _ half_log2R_good.
Let w_mul_divides : forall i : nat, w_mul (S i) / w_mul i > 0
:= weight_divides _ _ half_log2R_good.
+(*
Let w_0 : w 0%nat = 1 := weight_0 _ _.
Let w_nonzero : forall i, w i <> 0
:= weight_nz _ _ log2R_good.
@@ -7820,14 +7821,13 @@ Module MontgomeryReduction.
Let w_divides : forall i : nat, w (S i) / w i > 0
:= weight_divides _ _ log2R_good.
Let w_1_gt1 : w 1 > 1 := weight_1_gt_1 _ _ log2R_good.
+*)
Let w_mul_1_gt1 : w_mul 1 > 1 := weight_1_gt_1 _ _ half_log2R_good.
Context (nout : nat) (Hnout : nout = 2%nat).
- Definition widemul a b := MulConverted.mul_aligned w w_mul n 1%nat n nout [a] [b].
-
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (widemul (fst lo_hi) N') 0 in
- dlet_nd t1_t2 := widemul y N in
+ dlet_nd y := nth_default 0 (BaseConversion.widemul w_mul n nout (fst lo_hi) N') 0 in
+ dlet_nd t1_t2 := (BaseConversion.widemul w_mul n nout y N) in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -7841,7 +7841,7 @@ Module MontgomeryReduction.
rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity.
Qed.
- Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *.
+ Local Ltac change_weight := rewrite ?w_mul_pown, !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *.
Local Ltac solve_range :=
repeat match goal with
| _ => progress change_weight
@@ -7855,9 +7855,10 @@ Module MontgomeryReduction.
end.
Lemma widemul_correct x y :
- 0 <= x * y < w 2 -> widemul x y = [(x * y) mod R; (x * y) / R].
+ 0 <= x * y < w 2 -> BaseConversion.widemul w_mul n nout x y = [(x * y) mod R; (x * y) / R].
Proof.
- intros; cbv [widemul]. rewrite MulConverted.mul_aligned_eq by (auto; distr_length).
+ intros.
+ rewrite BaseConversion.widemul_partitions by (auto; omega).
subst nout. cbn. change_weight.
autorewrite with zsimplify.
Z.rewrite_mod_small. reflexivity.