aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Jade Philipoom <jadep@google.com>, 2018-04-06 15:09:43 +0200
committer: Jade Philipoom <jadep@google.com>, 2018-04-06 15:09:43 +0200
commit: 3e779870f34a39bcb5c43d8d2f2c1749ac830575 (patch)
tree: 80be8856ac2bc023b36a5459e960b9a19b703b88
parent: 59e1cd39e3adc7212e9acb4995d092400b94f7ab (diff)
better factoring-out of mul_converted stuff, define saturated arith operations
-rw-r--r-- src/Experiments/SimplyTypedArithmetic.v 322
1 file changed, 246 insertions, 76 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index bbddecd02..dd52e96f4 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -1017,6 +1017,7 @@ Module Rows.
Local Notation cols := (list (list Z)) (only parsing).
Hint Rewrite Positional.eval_nil Positional.eval0 @Positional.eval_snoc
+ Positional.eval_to_associational
Columns.eval_nil Columns.eval_snoc using (auto; solve [distr_length]) : push_eval.
Hint Resolve in_eq in_cons.
Hint Resolve Z.gt_lt.
@@ -1206,6 +1207,65 @@ Module Rows.
match goal with H: _ |- _ => apply length_from_columns in H end.
rewrite Columns.length_from_associational in *; auto.
Qed.
+
+ (* TODO : move
+    Two helper facts about when a maximum is zero:
+    - [max_0_iff]: [Nat.max a b] is 0 exactly when both [a] and [b] are 0.
+    - [max_column_size_zero_iff]: [max_column_size x] is 0 exactly when every
+      element [c] of [x] is [nil]; proved by induction on [x], using
+      [max_0_iff] to split the head from the tail. *)
+ Lemma max_0_iff a b : Nat.max a b = 0%nat <-> (a = 0%nat /\ b = 0%nat).
+ Proof.
+ destruct a, b; try tauto.
+ rewrite <-Nat.succ_max_distr.
+ split; [ | destruct 1]; congruence.
+ Qed.
+ Lemma max_column_size_zero_iff x :
+ max_column_size x = 0%nat <-> (forall c, In c x -> c = nil).
+ Proof.
+ cbv [max_column_size]; induction x; intros; [ cbn; tauto | ].
+ autorewrite with push_fold_right push_map.
+ rewrite max_0_iff, IHx.
+ split; intros; [ | rewrite length_zero_iff_nil; solve [auto] ].
+ match goal with H : _ /\ _ |- _ => destruct H end.
+ In_cases; subst; auto using length0_nil.
+ Qed.
+ (* One-step unfolding of [Columns.from_associational] on a nonempty input
+    [t :: p]: the result for the tail [p] is updated with
+    [Columns.cons_to_nth], whose first two arguments are the components of
+    [Positional.place weight t (Nat.pred n)] (presumably the column index
+    and the value to cons there -- confirm against [Columns.cons_to_nth]).
+    NOTE(review): this lemma is [Admitted], i.e. it is currently an
+    unproven assumption; any lemma that rewrites with it inherits that
+    assumption until a proof is supplied. *)
+ Lemma Columns_from_associational_step n t p :
+ Columns.from_associational weight n (t :: p) =
+ Columns.cons_to_nth (fst (Positional.place weight t (Nat.pred n)))
+ (snd (Positional.place weight t (Nat.pred n)))
+ (Columns.from_associational weight n p).
+ Admitted.
+ (* For a nonzero length [n] and a nonempty associational list [p], the
+    columns built by [Columns.from_associational] have a nonzero
+    [max_column_size].
+    Proof sketch: argue by contradiction through [max_column_size_zero_iff]
+    (all columns would be [nil]), unfold one step with
+    [Columns_from_associational_step], and inspect the column that
+    [update_nth] just modified, which cannot be [nil].
+    NOTE(review): the proof rewrites with [Columns_from_associational_step],
+    which is [Admitted]. *)
+ Lemma max_column_size_Columns_from_associational n p :
+ n <> 0%nat -> p <> nil ->
+ max_column_size (Columns.from_associational weight n p) <> 0%nat.
+ Proof.
+ intros.
+ rewrite max_column_size_zero_iff.
+ intro. destruct p; [congruence | ].
+ rewrite Columns_from_associational_step in *.
+ cbv [Columns.cons_to_nth] in *.
+ match goal with H : forall c, In c (update_nth ?n ?f ?ls) -> _ |- _ =>
+ assert (n < length (update_nth n f ls))%nat;
+ [ | specialize (H (nth n (update_nth n f ls) nil) ltac:(auto using nth_In)) ]
+ end.
+ { distr_length.
+ rewrite Columns.length_from_associational.
+ remember (Nat.pred n) as m. replace n with (S m) by omega.
+ apply Positional.place_in_range. }
+ rewrite <-nth_default_eq in *.
+ autorewrite with push_nth_default in *.
+ rewrite eq_nat_dec_refl in *.
+ congruence.
+ Qed.
+ (* For a nonzero length [n] and a nonempty associational list [p],
+    [from_associational n p] is not [nil].
+    The proof uses [max_column_size_Columns_from_associational] to rule out
+    the case where the maximum column size is 0, then finishes by a length
+    computation in the successor case. *)
+ Lemma from_associational_nonnil n p :
+ n <> 0%nat -> p <> nil ->
+ from_associational n p <> nil.
+ Proof.
+ intros; cbv [from_associational from_columns from_columns'].
+ pose proof (max_column_size_Columns_from_associational n p ltac:(auto) ltac:(auto)).
+ case_eq (max_column_size (Columns.from_associational weight n p)); [omega|].
+ intros; cbn.
+ rewrite <-length_zero_iff_nil. distr_length.
+ Qed.
End FromAssociational.
Section Flatten.
@@ -1360,11 +1420,11 @@ Module Rows.
let out_carry := sum_rows next_row (fst state) in
(fst out_carry, snd state + snd out_carry)) start_state (rev inp).
- (* For correctness if there is only one row, we add a row of
- zeroes with the same length so that the add loop still happens. *)
- Definition flatten (inp : rows) : list Z * Z :=
- let first_row := hd nil inp in
- flatten' (first_row, 0) (hd (Positional.zeros (length first_row)) (tl inp) :: tl (tl inp)).
+ (* In order for the output to have the right length and bounds,
+ we insert rows of zeroes if there are fewer than two rows.
+ [n] is the length of the default row [Positional.zeros n].
+ The first row of [inp] (or the default if [inp] is empty) becomes the
+ starting accumulator, paired with an initial carry of 0. The second row
+ (or the default if there is none) is always placed at the head of the
+ list handed to [flatten'], so that list is never empty. *)
+ Definition flatten n (inp : rows) : list Z * Z :=
+ let default := Positional.zeros n in
+ flatten' (hd default inp, 0) (hd default (tl inp) :: tl (tl inp)).
Lemma flatten'_cons state r inp :
flatten' state (r :: inp) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp.
@@ -1408,20 +1468,24 @@ Module Rows.
Lemma flatten_div_mod inp n :
(forall row, In row inp -> length row = n) ->
- is_div_mod (Positional.eval weight n) (flatten inp) (eval n inp) (weight n).
+ is_div_mod (Positional.eval weight n) (flatten n inp) (eval n inp) (weight n).
Proof.
intros; cbv [flatten].
- destruct inp; [|destruct inp]; cbn [hd tl]; try solve [cbv [is_div_mod]; push].
- eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push.
+ destruct inp; [|destruct inp]; cbn [hd tl].
+ { cbv [is_div_mod]; push.
+ erewrite sum_rows_div by (distr_length; reflexivity).
+ push. }
+ { cbv [is_div_mod]; push. }
+ { eapply is_div_mod_result_equal; try apply flatten'_div_mod_length; push. }
Qed.
Lemma flatten_mod inp n :
(forall row, In row inp -> length row = n) ->
- Positional.eval weight n (fst (flatten inp)) = (eval n inp) mod (weight n).
+ Positional.eval weight n (fst (flatten n inp)) = (eval n inp) mod (weight n).
Proof. apply flatten_div_mod. Qed.
Lemma flatten_div inp n :
(forall row, In row inp -> length row = n) ->
- snd (flatten inp) = (eval n inp) / (weight n).
+ snd (flatten n inp) = (eval n inp) / (weight n).
Proof. apply flatten_div_mod. Qed.
Lemma length_flatten' n start_state inp :
@@ -1433,18 +1497,18 @@ Module Rows.
Lemma length_flatten n inp :
(forall row, In row inp -> length row = n) ->
- inp <> nil ->
- length (fst (flatten inp)) = n.
+ length (fst (flatten n inp)) = n.
Proof.
- intros. apply flatten'_div_mod_length; push;
- destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push.
- subst row; distr_length; auto.
+ intros.
+ apply length_flatten'; push;
+ destruct inp as [|? [|? ?] ]; try congruence; cbn [hd tl] in *; push;
+ subst row; distr_length.
Qed. Hint Rewrite length_flatten : distr_length.
Lemma flatten'_partitions n inp : forall start_state,
+ inp <> nil ->
length (fst start_state) = n ->
(forall row, In row inp -> length row = n) ->
- inp <> nil ->
forall i, (i < n)%nat ->
nth_default 0 (fst (flatten' start_state inp)) i
= ((Positional.eval weight n (fst start_state) + eval n inp) mod weight (S i)) / (weight i).
@@ -1460,16 +1524,74 @@ Module Rows.
Lemma flatten_partitions inp n :
(forall row, In row inp -> length row = n) ->
forall i, (i < n)%nat ->
- nth_default 0 (fst (flatten inp)) i = (eval n inp mod weight (S i)) / (weight i).
+ nth_default 0 (fst (flatten n inp)) i = (eval n inp mod weight (S i)) / (weight i).
Proof.
intros; cbv [flatten].
intros; destruct inp as [| ? [| ? ?] ]; try congruence; cbn [hd tl] in *; try solve [push].
- { cbn. autorewrite with push_nth_default. reflexivity. }
+ { cbn. autorewrite with push_nth_default.
+ rewrite sum_rows_partitions with (n:=n) by distr_length.
+ autorewrite with push_eval zsimplify_fast.
+ auto with zarith. }
{ push. rewrite sum_rows_partitions with (n:=n) by distr_length; push. }
{ rewrite flatten'_partitions with (n:=n); push. }
Qed.
+ (* [partition n x] is the list, for [i] ranging over [seq 0 n]
+    (i.e. 0 up to n-1), of [(x mod weight (S i)) / weight i]. *)
+ Definition partition n x :=
+ map (fun i => (x mod weight (S i)) / weight i) (seq 0 n).
+ (* Extensionality principle for [partition]: a list [p] of length [n] whose
+    [i]-th entry (for every [i < n]) equals
+    [(x mod weight (S i)) / weight i] is exactly [partition n x].
+    Proved by reverse induction on [p] ([rev_ind]), peeling off the last
+    element and matching it against the last element of [seq 0 n]. *)
+ Lemma nth_default_partitions x : forall p n,
+ (forall i, (i < n)%nat -> nth_default 0 p i = (x mod weight (S i)) / weight i) ->
+ length p = n ->
+ p = partition n x.
+ Proof.
+ cbv [partition]; induction p using rev_ind; intros; distr_length; subst n; [reflexivity|].
+ rewrite Nat.add_1_r, seq_snoc.
+ autorewrite with natsimplify push_map.
+ rewrite <-IHp; auto; intros;
+ match goal with H : context [nth_default _ (p ++ [ _ ])] |- _ =>
+ rewrite <-H by omega end.
+ { autorewrite with push_nth_default natsimplify. reflexivity. }
+ { autorewrite with push_nth_default natsimplify.
+ break_match; omega. }
+ Qed.
+ (* Whole-list form of [flatten_partitions]: when every row of [inp] has
+    length [n], the first component of [flatten n inp] equals
+    [partition n (eval n inp)]. Obtained by combining
+    [nth_default_partitions], [flatten_partitions] and [length_flatten]. *)
+ Lemma flatten_partitions' inp n :
+ (forall row, In row inp -> length row = n) ->
+ fst (flatten n inp) = partition n (eval n inp).
+ Proof. auto using nth_default_partitions, flatten_partitions, length_flatten. Qed.
End Flatten.
+ Section Ops. (* Addition and subtraction built on [flatten].
+    [add n p q] converts [p] and [q] to associational form with
+    [Positional.to_associational weight n], appends the two lists, rebuilds
+    rows with [from_associational n] and flattens them with [flatten n].
+    The result is the pair returned by [flatten] (a list and a [Z]). *)
+ Definition add n p q :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ flatten n (from_associational n (p_a ++ q_a)).
+ (* [sub n p q] is built like [add], except that the associational form of
+    [q] is passed through [Associational.negate_snd] before appending. *)
+ Definition sub n p q :=
+ let p_a := Positional.to_associational weight n p in
+ let q_a := Positional.to_associational weight n q in
+ flatten n (from_associational n (p_a ++ Associational.negate_snd q_a)).
+ (* For [n <> 0] and inputs of length [n], the first component of
+    [add n p q] is [partition n] of the sum of the two evaluations. *)
+ Lemma add_partitions n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ fst (add n p q) = partition n (Positional.eval weight n p + Positional.eval weight n q).
+ Proof.
+ intros; cbv [add].
+ rewrite flatten_partitions' by eauto using length_from_associational.
+ rewrite eval_from_associational by auto.
+ autorewrite with push_eval; reflexivity.
+ Qed.
+ (* For [n <> 0] and inputs of length [n], the first component of
+    [sub n p q] is [partition n] of the difference of the two evaluations. *)
+ Lemma sub_partitions n p q :
+ n <> 0%nat -> length p = n -> length q = n ->
+ fst (sub n p q) = partition n (Positional.eval weight n p - Positional.eval weight n q).
+ Proof.
+ intros; cbv [sub].
+ rewrite flatten_partitions' by eauto using length_from_associational.
+ rewrite eval_from_associational by auto.
+ autorewrite with push_eval; reflexivity.
+ Qed.
+ End Ops.
End Rows.
End Rows.
@@ -1486,6 +1608,16 @@ Module MulConverted.
(w_multiples : forall i, w (S i) mod w i = 0)
(w_divides : forall i : nat, w (S i) / w i > 0).
+ (* TODO : move this stuff to BaseConversion
+    [to_associational n m p] first converts [p] with
+    [BaseConversion.convert_bases w w' n m], then turns the converted list
+    into associational form with [Positional.to_associational w' m]. *)
+ Definition to_associational n m p : list (Z * Z) :=
+ let p' := BaseConversion.convert_bases w w' n m p in
+ Positional.to_associational w' m p'.
+ (* [from_associational idxs n p] applies [Associational.carry] to [p] once
+    per index [i] in [idxs], at [w' i] with factor [w' (S i) / w' i]; since
+    the [fold_right] runs over [rev idxs], the carries are applied in the
+    order the indices appear in [idxs]. The carried list is then turned
+    into rows with [Rows.from_associational w n] and flattened with
+    [Rows.flatten w n]; only the first component of the flattened pair is
+    returned. *)
+ Definition from_associational idxs n (p : list (Z * Z)) : list Z :=
+ (* important not to use Positional.carry here; we don't want to accumulate yet *)
+ let p' := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p (rev idxs) in
+ fst (Rows.flatten w n (Rows.from_associational w n p')).
+
(* takes in inputs in base w, converts to w', multiplies in that
format, converts to w again, then flattens. *)
Definition mul_converted
@@ -1494,14 +1626,10 @@ Module MulConverted.
(n3 : nat) (* final length *)
(idxs : list nat) (* carries to do -- this helps preemptively line up weights *)
(p1 p2 : list Z) :=
- let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in
- let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in
- let p1_a := Positional.to_associational w' m1 p1' in
- let p2_a := Positional.to_associational w' m2 p2' in
- let p3_a := Associational.mul p1_a p2_a in
- (* important not to use Positional.carry here; we don't want to accumulate yet *)
- let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in
- fst (Rows.flatten w (Rows.from_associational w n3 p3'_a)).
+ let p1_a := to_associational n1 m1 p1 in
+ let p2_a := to_associational n2 m2 p2 in
+ let p3_a := Associational.mul p1_a p2_a in
+ from_associational idxs n3 p3_a.
Hint Rewrite
@Rows.eval_from_associational
@@ -1510,55 +1638,95 @@ Module MulConverted.
@Positional.eval_to_associational
@BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval.
+ Ltac push_eval := intros; autorewrite with push_eval; auto with zarith.
+
Lemma eval_carries p idxs :
Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) =
Associational.eval p.
- Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed.
+ Proof. apply fold_right_invariant; push_eval. Qed.
Hint Rewrite eval_carries: push_eval.
+ Lemma eval_to_associational n m p : (* for m <> 0 and p of length n, the
+    associational form produced by [to_associational] evaluates to the
+    same value as [p] evaluated positionally in weight [w] *)
+ m <> 0%nat -> length p = n ->
+ Associational.eval (to_associational n m p) = Positional.eval w n p.
+ Proof. cbv [to_associational]; push_eval. Qed.
+ Hint Rewrite eval_to_associational using solve [push_eval; distr_length] : push_eval.
+ (* For [n <> 0] and an associational list [p] whose value lies in
+    [0, w n), evaluating [from_associational idxs n p] positionally in
+    weight [w] gives back [Associational.eval p]. The proof goes through
+    [Rows.flatten_mod], which yields the value modulo [w n]; the range
+    hypothesis is what lets the remaining goal be discharged by [push_eval]. *)
+ Lemma eval_from_associational idxs n p :
+ n <> 0%nat -> 0 <= Associational.eval p < w n ->
+ Positional.eval w n (from_associational idxs n p) = Associational.eval p.
+ Proof.
+ cbv [from_associational]; intros.
+ rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
+ push_eval.
+ Qed.
+ Hint Rewrite eval_from_associational using solve [push_eval; distr_length] : push_eval.
+
Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
+ Proof. cbv [mul_converted]; push_eval. Qed.
+ Hint Rewrite eval_mul_converted : push_eval.
+
+ Lemma from_associational_partitions n idxs p (_:n<>0%nat):
+ forall i, (i < n)%nat ->
+ nth_default 0 (from_associational idxs n p) i = (Associational.eval p) mod (w (S i)) / w i.
Proof.
- cbv [mul_converted]; intros.
- rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
- autorewrite with push_eval. auto using Z.mod_small.
+ intros; cbv [from_associational].
+ rewrite Rows.flatten_partitions with (n:=n) by (eauto using Rows.length_from_associational; omega).
+ push_eval.
Qed.
- Hint Rewrite eval_mul_converted : push_eval.
- Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
+ Lemma mul_converted_partitions n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
+ forall i, (i < n3)%nat ->
+ nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) i = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w (S i)) / w i.
Proof.
intros; cbv [mul_converted].
- rewrite Rows.flatten_partitions with (n:=n3) by (eauto using Rows.length_from_associational; omega).
- autorewrite with distr_length push_eval natsimplify.
- rewrite w_0; autorewrite with zsimplify.
- reflexivity.
+ rewrite from_associational_partitions by auto. push_eval.
Qed.
- Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2:
- m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat ->
+ Lemma from_associational_eq n idxs p (_:n<>0%nat):
+ from_associational idxs n p = Rows.partition w n (Associational.eval p).
+ Proof.
+ intros. cbv [from_associational].
+ rewrite Rows.flatten_partitions' with (n:=n) by eauto using Rows.length_from_associational.
+ push_eval.
+ Qed.
+
+ (* TODO: convert all _partitions proofs to this form? *)
+ Lemma mul_converted_eq n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
- 0 <= Positional.eval w n1 p1 ->
- 0 <= Positional.eval w n2 p2 ->
- 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
- nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
+ mul_converted n1 n2 m1 m2 n3 idxs p1 p2 = Rows.partition w n3 (Positional.eval w n1 p1 * Positional.eval w n2 p2).
Proof.
- intros; subst n3; cbv [mul_converted].
- rewrite Rows.flatten_partitions with (n:=2%nat) by (eauto using Rows.length_from_associational; omega).
- autorewrite with distr_length push_eval.
- rewrite Z.mod_small; omega.
+ intros; cbv [mul_converted].
+ rewrite from_associational_eq by auto. push_eval.
Qed.
- (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *)
- (* the most important feature here is the carries--we carry from all the odd indices after multiplying,
- thus pre-aligning everything with the double-size bitwidth *)
- Definition mul_converted_halve n n2 :=
- mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)).
-
+ Section aligned. (* Specialisation of [mul_converted] with a fixed carry
+    chain. Assumes an exponent [log_w'_w] with
+    [(w' i) ^ log_w'_w = w i] for every [i], and nonzero lengths
+    [n], [n'] and [nout]. *)
+ Context (log_w'_w : nat) (Hlog_w'_w : forall i, (w' i) ^ Z.of_nat log_w'_w = w i).
+ Context (n n' nout : nat) (Hn : n <> 0%nat) (Hn' : n' <> 0%nat) (Hnout : nout <> 0%nat).
+ (* The carry indices are [log_w'_w * (i + 1) - 1] for [i] in [seq 0 nout]
+    (natural-number subtraction). *)
+ (* carry chain that aligns terms in the intermediate weight with the final weight *)
+ Definition aligned_carries := (map (fun i => ((log_w'_w * (i + 1)) - 1))%nat (seq 0 nout)).
+ (* [from_associational] with the aligned carry chain and output length [nout]. *)
+ Definition from_associational_aligned := from_associational aligned_carries nout.
+ (* [mul_converted] where both inputs use length [n] before conversion and
+    [n'] after conversion, with output length [nout] and the aligned carries. *)
+ (* assumes both inputs have the same length *)
+ Definition mul_aligned := mul_converted n n n' n' nout aligned_carries.
+ (* For inputs of length [n], [mul_aligned p1 p2] is [Rows.partition w nout]
+    of the product of the two evaluations; direct from [mul_converted_eq]. *)
+ Lemma mul_aligned_eq p1 p2 :
+ length p1 = n -> length p2 = n ->
+ mul_aligned p1 p2 = Rows.partition w nout (Positional.eval w n p1 * Positional.eval w n p2).
+ Proof. cbv [mul_aligned aligned_carries]. auto using mul_converted_eq. Qed.
+ (* If additionally the product lies in [0, w nout), evaluating the result
+    gives the product itself; direct from [eval_mul_converted]. *)
+ Lemma eval_mul_aligned p1 p2 :
+ length p1 = n -> length p2 = n ->
+ 0 <= Positional.eval w n p1 * Positional.eval w n p2 < w nout ->
+ Positional.eval w nout (mul_aligned p1 p2) = Positional.eval w n p1 * Positional.eval w n p2.
+ Proof. cbv [mul_aligned aligned_carries]. auto using eval_mul_converted. Qed.
+ End aligned.
End mul_converted.
End MulConverted.
@@ -7655,14 +7823,11 @@ Module MontgomeryReduction.
Let w_mul_1_gt1 : w_mul 1 > 1 := weight_1_gt_1 _ _ half_log2R_good.
Context (nout : nat) (Hnout : nout = 2%nat).
- (* simpler version of mul_converted with a carry chain that aligns
- terms in the intermediate weight with the final weight *)
- Definition mul_converted_aligned w w' (log_w'_w : nat) n m nout :=
- MulConverted.mul_converted w w' n n m m nout (map (fun i => ((log_w'_w * (i + 1)) - 1))%nat (seq 0 nout)).
+ Definition widemul a b := MulConverted.mul_aligned w w_mul n 1%nat n nout [a] [b].
Definition montred' (lo_hi : (Z * Z)) :=
- dlet_nd y := nth_default 0 (mul_converted_aligned w w_mul n 1%nat n nout [fst lo_hi] [N']) 0 in
- dlet_nd t1_t2 := mul_converted_aligned w w_mul n 1%nat n nout [y] [N] in
+ dlet_nd y := nth_default 0 (widemul (fst lo_hi) N') 0 in
+ dlet_nd t1_t2 := widemul y N in
dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in
dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in
dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in
@@ -7676,9 +7841,10 @@ Module MontgomeryReduction.
rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity.
Qed.
+ Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r in *.
Local Ltac solve_range :=
repeat match goal with
- | _ => rewrite Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r
+ | _ => progress change_weight
| |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega))
| |- 0 <= _ => progress Z.zero_bounds
| |- 0 <= _ * _ < _ * _ =>
@@ -7688,9 +7854,14 @@ Module MontgomeryReduction.
| _ => nia
end.
- Hint Rewrite
- MulConverted.mul_converted_mod MulConverted.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range])
- : mul_conv.
+ Lemma widemul_correct x y : (* when the product x * y lies in [0, w 2),
+    [widemul x y] is the two-element list of the product modulo [R] and the
+    product divided by [R]. The proof rewrites with
+    [MulConverted.mul_aligned_eq], substitutes [nout] (which the enclosing
+    context fixes to 2), and simplifies the weights with [change_weight]. *)
+ 0 <= x * y < w 2 -> widemul x y = [(x * y) mod R; (x * y) / R].
+ Proof.
+ intros; cbv [widemul]. rewrite MulConverted.mul_aligned_eq by (auto; distr_length).
+ subst nout. cbn. change_weight.
+ autorewrite with zsimplify.
+ Z.rewrite_mod_small. reflexivity.
+ Qed.
Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N)
(Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R):
@@ -7698,11 +7869,11 @@ Module MontgomeryReduction.
Proof.
rewrite <-reduce_via_partial_alt_eq by nia.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
- rewrite Hlo, Hhi. subst nout.
+ rewrite Hlo, Hhi.
assert (0 <= T mod R * N' < w 2) by (solve_range).
- cbv [mul_converted_aligned]. cbn [seq map].
- autorewrite with mul_conv.
- rewrite Hw, ?Z.pow_1_r.
+
+ rewrite !widemul_correct by (rewrite ?widemul_correct; autorewrite with push_nth_default; solve_range).
+ autorewrite with push_nth_default.
autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
(* pull out value before last modular reduction *)
@@ -7713,8 +7884,7 @@ Module MontgomeryReduction.
|- context [if R * R <=? ?x then _ else _] =>
match goal with |- context [if dec (?xHigh / R = 0) then _ else _] =>
assert (x / R = xHigh) as cond_equiv end end.
- { apply Z.mul_cancel_r with (p:=R); [omega|]. cbn.
- rewrite w_0. autorewrite with zsimplify_fast.
+ { apply Z.mul_cancel_r with (p:=R); [omega|].
autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. }
rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega.
assert (0 < R * R)%Z by Z.zero_bounds.
@@ -7762,7 +7932,7 @@ Module MontgomeryReduction.
Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Definition relax_zrange_of_machine_wordsize
- := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z.
+ := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z.
Local Arguments relax_zrange_of_machine_wordsize / .
Let relax_zrange := relax_zrange_of_machine_wordsize.
@@ -7855,7 +8025,7 @@ montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z *
expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in
ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
- : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)
+ : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z))
*)
End Montgomery256.
@@ -7881,10 +8051,10 @@ Module Montgomery256PrintingNotations.
(AppIdent
(primitive 79228162514264337593543950335)
TT) (only printing, at level 9) : expr_scope.
- Notation "'RegMod' '<<' '128'" :=
+ Notation "'RegMod' '>>' '128'" :=
(AppIdent
(primitive 340282366841710300967557013911933812736)
- TT) (only printing, at level 9, format "'RegMod' '<<' '128'") : expr_scope.
+ TT) (only printing, at level 9, format "'RegMod' '>>' '128'") : expr_scope.
Notation "'Lower128{RegPinv}'" :=
(AppIdent
(primitive 79228162514264337593543950337)
@@ -7981,7 +8151,7 @@ c.Lower128($x12, $x10_lo);
c.Mul128x128($x13, Lower128{RegMod}, $x11);
c.ShiftR($x14, $x13, 128);
c.Lower128($x15, $x13);
-c.Mul128x128($x16, RegMod << 128, $x12);
+c.Mul128x128($x16, RegMod >> 128, $x12);
c.ShiftR($x17, $x16, 128);
c.Lower128($x18, $x16);
c.ShiftL($x19, $x18, 128);
@@ -7990,7 +8160,7 @@ c.Add256($x21, $x19, $x20);
c.Addc($x22, $x14, $x17);
c.ShiftL($x23, $x15, 128);
c.Add256($x24, $x23, $x21_lo);
-c.Mul128x128($x25, RegMod << 128, $x11);
+c.Mul128x128($x25, RegMod >> 128, $x11);
c.Addc($x26, $x25, $x22_lo);
c.Add256($x27, $x_lo, $x24_lo);
c.Addc($x28, $x_hi, $x26_lo);