aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2019-01-17 15:07:47 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2019-01-18 19:44:48 -0500
commitcdd5ffb086eb647eabe640c81de9d8af7cd0a1dd (patch)
tree4540df27da661c35fdc5246f1692fa124003ff6f /src/Arithmetic.v
parentb99dd6da3b6370bc225d3b501bda07c49fd29c12 (diff)
Split up PushButtonSynthesis.v
Closes #497
Diffstat (limited to 'src/Arithmetic.v')
-rw-r--r--src/Arithmetic.v589
1 files changed, 589 insertions, 0 deletions
diff --git a/src/Arithmetic.v b/src/Arithmetic.v
index 5af73875b..4436bdf0e 100644
--- a/src/Arithmetic.v
+++ b/src/Arithmetic.v
@@ -4,6 +4,8 @@ Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz.
Require Import Coq.Sorting.Mergesort Coq.Structures.Orders.
Require Import Coq.Sorting.Permutation.
Require Import Coq.derive.Derive.
+Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *)
+Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *)
Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable.
Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn.
Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil.
@@ -4775,3 +4777,590 @@ Module WordByWordMontgomery.
Proof. apply to_bytesmod_correct. Qed.
End modops.
End WordByWordMontgomery.
+
+Module BarrettReduction.
+ (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *)
+ Section Generic.
+ Context {T} (rep : T -> Z -> Prop) (* rep a x relates an abstract value a : T to an integer x; every operation spec below is stated against it *)
+ (k : Z) (k_pos : 0 < k)
+ (low : T -> Z) (* spec: yields x mod 2^k for any x related to the argument *)
+ (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k)
+ (shiftr : T -> Z -> T) (* spec: division by 2^n, only required for shift amounts 0 <= n <= k *)
+ (shiftr_correct : forall a x n,
+ rep a x ->
+ 0 <= n <= k ->
+ rep (shiftr a n) (x / 2 ^ n))
+ (mul_high : T -> T -> Z -> T) (* spec: x * y / 2^k; the Z argument must equal the precomputed product x mod 2^k times y / 2^k *)
+ (mul_high_correct : forall a b x y x0y1,
+ rep a x ->
+ rep b y ->
+ 2 ^ k <= x < 2^(k+1) ->
+ 0 <= y < 2^(k+1) ->
+ x0y1 = x mod 2 ^ k * (y / 2 ^ k) ->
+ rep (mul_high a b x0y1) (x * y / 2 ^ k))
+ (mul : Z -> Z -> T) (* spec: both inputs are below 2^k; the result is related to the full product x * y *)
+ (mul_correct : forall x y,
+ 0 <= x < 2^k ->
+ 0 <= y < 2^k ->
+ rep (mul x y) (x * y))
+ (sub : T -> T -> T) (* spec: only required when the difference x - y is non-negative and below 2^k * 2^k *)
+ (sub_correct : forall a b x y,
+ rep a x ->
+ rep b y ->
+ 0 <= x - y < 2^k * 2^k ->
+ rep (sub a b) (x - y))
+ (cond_sub1 : T -> Z -> Z) (* spec: returns x when x < 2^k, else x - y; note the test is against 2^k, not against y *)
+ (cond_sub1_correct : forall a x y,
+ rep a x ->
+ 0 <= x < 2 * y ->
+ 0 <= y < 2 ^ k ->
+ cond_sub1 a y = if (x <? 2 ^ k) then x else x - y)
+ (cond_sub2 : Z -> Z -> Z) (* spec: one conditional subtraction on plain integers, x if x < y else x - y *)
+ (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y).
+ Context (xt mut : T) (M muSelect: Z). (* xt and mut are related by rep to x and mu (x_rep and mu_rep below); M is the modulus of reduce_correct *)
+
+ Let mu := 2 ^ (2 * k) / M. (* precomputed quotient of 2^(2k) by M *)
+ Context x (mu_rep : rep mut mu) (x_rep : rep xt x).
+ Context (M_nz : 0 < M)
+ (x_range : 0 <= x < M * 2 ^ k) (* the input x must be non-negative and below M * 2^k *)
+ (M_range : 2 ^ (k - 1) < M < 2 ^ k) (* M lies strictly between 2^(k-1) and 2^k *)
+ (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu)
+ (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)). (* muSelect is the precomputed product handed to mul_high as its third argument in qt *)
+
+ Definition qt := (* quotient estimate; q_correct shows it is related by rep to mu * (x / 2^(k-1)) / 2^(k+1) *)
+ dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *)
+ dlet_nd q1 := shiftr xt (k - 1) in (* related to x / 2^(k-1) by shiftr_correct *)
+ dlet_nd twoq := mul_high mut q1 muSelect in (* related to mu * (x / 2^(k-1)) / 2^k by mul_high_correct *)
+ shiftr twoq 1. (* one further division by 2 *)
+ Definition reduce := (* the full reduction pipeline; reduce_correct shows the result equals x mod M *)
+ dlet_nd qt := qt in
+ dlet_nd r2 := mul (low qt) M in (* q * M; low qt equals q because q_bounds gives q < 2^k *)
+ dlet_nd r := sub xt r2 in (* x - q * M; r_bounds shows it is non-negative and below 2 * M *)
+ let q3 := cond_sub1 r M in
+ cond_sub2 q3 M. (* two conditional subtractions of M in total; see two_conditional_subtracts *)
+
+ Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). (* a consequence of M < 2^k from M_range *)
+ Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed.
+
+ Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1). (* splits the exponent 2k as (k-1) + (k+1) *)
+ Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed.
+
+ Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1). (* exactly the range mul_high_correct demands of its first operand *)
+ Proof.
+ pose proof looser_bound.
+ subst mu. split.
+ { apply Z.div_le_lower_bound; omega. }
+ { apply Z.div_lt_upper_bound; try omega.
+ rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. }
+ Qed.
+
+ Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). (* the shifted input is within the range mul_high_correct demands of its second operand *)
+ Proof.
+ pose proof looser_bound.
+ split; [ solve [Z.zero_bounds] | ].
+ apply Z.div_lt_upper_bound; auto with zarith.
+ rewrite <-pow_2k_eq. omega.
+ Qed.
+ Hint Resolve shiftr_x_bounds. (* makes the bound available to auto and eauto in later proofs *)
+
+ Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. (* attempts rep goals by chaining the operation specs from the section context *)
+
+ Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). (* integer value of the quotient estimate computed by qt *)
+
+ Lemma q_correct : rep qt q .
+ Proof.
+ pose proof mu_bounds. cbv [qt]; subst q.
+ rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. (* turn the division by 2^(k+1) into a division by 2^k followed by one by 2^1, matching mul_high then shiftr *)
+ solve_rep.
+ Qed.
+ Hint Resolve q_correct.
+
+ Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. (* proved by going through the intermediate bound 2^(k-1) *)
+ Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed.
+ Hint Resolve x_mod_small.
+
+ Lemma q_bounds : 0 <= q < 2 ^ k. (* q is below 2^k, so q mod 2^k = q; reduce_correct relies on this when rewriting low qt *)
+ Proof.
+ pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds.
+ split; subst q; [ solve [Z.zero_bounds] | ].
+ edestruct q_nice_strong with (n:=M) as [? Hqnice];
+ try rewrite Hqnice; auto; try omega; [ ]. (* NOTE(review): q_nice_strong is not defined in this view; presumably an imported Barrett reduction lemma -- confirm *)
+ apply Z.le_lt_trans with (m:= x / M). (* bound q by x / M, then bound x / M using x_range *)
+ { break_match; omega. }
+ { apply Z.div_lt_upper_bound; omega. }
+ Qed.
+
+ Lemma two_conditional_subtracts : (* for values in the range 0 <= x < 2 * M, cond_sub1 followed by cond_sub2 agrees with cond_sub2 applied twice *)
+ forall a x,
+ rep a x ->
+ 0 <= x < 2 * M ->
+ cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M.
+ Proof.
+ intros.
+ erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). (* replace both operations by their if-then-else specs *)
+ break_match; Z.ltb_to_lt; try lia; discriminate. (* case split on every comparison and close each case arithmetically *)
+ Qed.
+
+ Lemma r_bounds : 0 <= x - q * M < 2 * M. (* the remainder estimate is non-negative and below 2 * M *)
+ Proof.
+ pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small.
+ subst q mu; split.
+ { Z.zero_bounds. apply qn_small; omega. } (* NOTE(review): qn_small is not defined in this view -- confirm its source *)
+ { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } (* NOTE(review): r_small_strong is not defined in this view -- confirm its source *)
+ Qed.
+
+ Lemma reduce_correct : reduce = x mod M. (* main result of Section Generic *)
+ Proof.
+ pose proof looser_bound. pose proof r_bounds. pose proof q_bounds.
+ assert (2 * M < 2^k * 2^k) by nia.
+ rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). (* NOTE(review): barrett_reduction_small is not defined in this view -- confirm its statement *)
+ cbv [reduce Let_In].
+ erewrite low_correct by eauto. Z.rewrite_mod_small. (* low qt becomes q mod 2^k, then q *)
+ erewrite two_conditional_subtracts by solve_rep.
+ rewrite !cond_sub2_correct.
+ subst q; reflexivity.
+ Qed.
+ End Generic.
+
+ Section BarrettReduction.
+ Context (k : Z) (k_bound : 2 <= k).
+ Context (M muLow : Z).
+ Context (M_pos : 0 < M)
+ (muLow_eq : muLow + 2^k = 2^(2*k) / M) (* muLow is the quotient 2^(2k) / M with its 2^k term removed *)
+ (muLow_bounds : 0 <= muLow < 2^k)
+ (M_bound1 : 2 ^ (k - 1) < M < 2^k)
+ (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). (* same shape as M_good in Section Generic, with mu written as muLow + 2^k *)
+
+ Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k).
+ Context (nout : nat) (Hnout : nout = 2%nat). (* n and nout are passed through to BaseConversion.widemul_inlined in widemul *)
+ Let w := weight k 1. (* weight function used by Rows.add, Partition.partition and Positional.eval below *)
+ Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed.
+ Let props : @weight_properties w := wprops k 1 k_range.
+
+ Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval.
+
+ Definition low (t : list Z) : Z := nth_default 0 t 0. (* entry at index 0 of t, with default 0 *)
+ Definition high (t : list Z) : Z := nth_default 0 t 1. (* entry at index 1 of t, with default 0 *)
+ Definition represents (t : list Z) (x : Z) := (* t is exactly the pair [x mod 2^k; x / 2^k] and x is non-negative and below 2^k * 2^k *)
+ t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k.
+
+ Lemma represents_eq t x : (* first conjunct of represents *)
+ represents t x -> t = [x mod 2^k; x / 2^k].
+ Proof. cbv [represents]; tauto. Qed.
+
+ Lemma represents_length t x : represents t x -> length t = 2%nat. (* a represented list always has exactly two entries *)
+ Proof. cbv [represents]; intuition. subst t; reflexivity. Qed.
+
+ Lemma represents_low t x : (* low reads back the remainder modulo 2^k *)
+ represents t x -> low t = x mod 2^k.
+ Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
+
+ Lemma represents_high t x : (* high reads back the quotient by 2^k *)
+ represents t x -> high t = x / 2^k.
+ Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
+
+ Lemma represents_low_range t x : (* range of the low entry *)
+ represents t x -> 0 <= x mod 2^k < 2^k.
+ Proof. auto with zarith. Qed.
+
+ Lemma represents_high_range t x : (* range of the high entry; uses the bound x < 2^k * 2^k from represents *)
+ represents t x -> 0 <= x / 2^k < 2^k.
+ Proof.
+ destruct 1 as [? [? ?] ]; intros.
+ auto using Z.div_lt_upper_bound with zarith.
+ Qed.
+ Hint Resolve represents_length represents_low_range represents_high_range.
+
+ Lemma represents_range t x : (* second conjunct of represents *)
+ represents t x -> 0 <= x < 2^k*2^k.
+ Proof. cbv [represents]; tauto. Qed.
+
+ Lemma represents_id x : (* any in-range x is represented by its own remainder/quotient pair *)
+ 0 <= x < 2^k * 2^k ->
+ represents [x mod 2^k; x / 2^k] x.
+ Proof.
+ intros; cbv [represents]; autorewrite with cancel_pair.
+ Z.rewrite_mod_small; tauto.
+ Qed.
+
+ Local Ltac push_rep := (* for each represents hypothesis: add the two entry-range facts once, then rewrite low and high of the list into x mod 2^k and x / 2^k everywhere *)
+ repeat match goal with
+ | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H)
+ | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H)
+ | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption
+ | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption
+ end.
+
+ Definition shiftr (t : list Z) (n : Z) : list Z := (* two Z.rshi calls: the first entry is built from (high t, low t), the second from (0, high t) *)
+ [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n].
+
+ Lemma shiftr_represents a i x : (* concrete counterpart of shiftr_correct in Section Generic, for shift amounts 0 <= i <= k *)
+ represents a x ->
+ 0 <= i <= k ->
+ represents (shiftr a i) (x / 2 ^ i).
+ Proof.
+ cbv [shiftr]; intros; push_rep.
+ match goal with H : _ |- _ => pose proof (represents_range _ _ H) end.
+ assert (0 < 2 ^ i) by auto with zarith.
+ assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia.
+ assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by
+ (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith).
+ repeat match goal with
+ | _ => rewrite Z.rshi_correct by auto with zarith
+ | _ => rewrite <-Z.div_mod''' by auto with zarith
+ | _ => progress autorewrite with zsimplify_fast
+ | _ => progress Z.rewrite_mod_small
+ | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] =>
+ rewrite (Z.div_div_comm a b c) by auto with zarith
+ | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia]
+ end.
+ Qed.
+
+ Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). (* assumed closed form of the weights: powers of 2^k *)
+ Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. (* rewrite weights into powers of 2^k, then simplify exponents 0, 1 and 2 *)
+
+ Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). (* two-entry addition; only the first component returned by Rows.add is kept *)
+ (* TODO: use this definition once issue #352 is resolved *)
+ (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *)
+ Definition widesub (t1 t2 : list Z) := (* hand-written two-entry subtraction, used instead of the Rows.sub version above *)
+ let t1_0 := hd 0 t1 in
+ let t1_1 := hd 0 (tl t1) in
+ let t2_0 := hd 0 t2 in
+ let t2_1 := hd 0 (tl t2) in
+ dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in
+ dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in (* snd x0, the second output of the low subtraction, is fed into the high subtraction *)
+ [fst x0; fst x1]. (* snd x1 is discarded *)
+ Definition widemul := BaseConversion.widemul_inlined k n nout.
+
+ Lemma partition_represents x : (* the two-entry partition of an in-range x represents x *)
+ 0 <= x < 2^k*2^k ->
+ represents (Partition.partition w 2 x) x.
+ Proof.
+ intros; cbn. change_weight.
+ Z.rewrite_mod_small.
+ autorewrite with zsimplify_fast.
+ auto using represents_id.
+ Qed.
+
+ Lemma eval_represents t x : (* positional evaluation of a representing list gives back x *)
+ represents t x -> Positional.eval w 2 t = x.
+ Proof.
+ intros; rewrite (represents_eq t x) by assumption.
+ cbn. change_weight; push_rep.
+ autorewrite with zsimplify. reflexivity.
+ Qed.
+
+ Ltac wide_op partitions_pf := (* shared proof script for the wide operations: rewrite with the given partition lemma, replace evaluations by the represented integers, then finish with partition_represents or represents_id *)
+ repeat match goal with
+ | _ => rewrite partitions_pf by eauto
+ | _ => rewrite partitions_pf by auto with zarith
+ | _ => erewrite eval_represents by eauto
+ | _ => solve [auto using partition_represents, represents_id]
+ end.
+
+ Lemma wideadd_represents t1 t2 x y : (* wideadd is correct whenever the sum stays below 2^k * 2^k *)
+ represents t1 x ->
+ represents t2 y ->
+ 0 <= x + y < 2^k*2^k ->
+ represents (wideadd t1 t2) (x + y).
+ Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed.
+
+ Lemma widesub_represents t1 t2 x y : (* concrete counterpart of sub_correct in Section Generic *)
+ represents t1 x ->
+ represents t2 y ->
+ 0 <= x - y < 2^k*2^k ->
+ represents (widesub t1 t2) (x - y).
+ Proof.
+ intros; cbv [widesub Let_In].
+ rewrite (represents_eq t1 x) by assumption.
+ rewrite (represents_eq t2 y) by assumption.
+ cbn [hd tl].
+ autorewrite with to_div_mod.
+ pull_Zmod.
+ match goal with |- represents [?m; ?d] ?x =>
+ replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. (* reduce the goal to showing that the high entry equals the quotient of the difference by 2^k *)
+ rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds).
+ f_equal.
+ transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { (* expand x and y into quotient and remainder by 2^k *)
+ rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega.
+ rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega.
+ f_equal. ring. }
+ autorewrite with zsimplify.
+ ring.
+ Qed.
+ (* Works with Rows.sub-based widesub definition
+ Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed.
+ *)
+
+ (* TODO: MOVE Equivalent Keys decl to Arithmetic? *)
+ Declare Equivalent Keys BaseConversion.widemul BaseConversion.widemul_inlined.
+ Lemma widemul_represents x y : (* concrete counterpart of mul_correct in Section Generic *)
+ 0 <= x < 2^k ->
+ 0 <= y < 2^k ->
+ represents (widemul x y) (x * y).
+ Proof.
+ intros; cbv [widemul].
+ assert (0 <= x * y < 2^k*2^k) by auto with zarith. (* the product of two values below 2^k is in range for represents *)
+ wide_op BaseConversion.widemul_correct.
+ Qed.
+
+ Definition mul_high (a b : list Z) a0b1 : list Z := (* sums three partial terms; mul_high_represents shows the result represents x * y / 2^k under its range hypotheses *)
+ dlet_nd a0b0 := widemul (low a) (low b) in
+ dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in (* only b appears in the terms that would involve high a: mul_high_represents replaces x / 2^k by 1 *)
+ wideadd ab [a0b1; 0]. (* add the caller-supplied cross term *)
+
+ Lemma mul_high_idea d a b a0 a1 b0 b1 : (* with a = d * a1 + a0 and b = d * b1 + b0, only the a0 * b0 term of the product keeps a division by d *)
+ d <> 0 ->
+ a = d * a1 + a0 ->
+ b = d * b1 + b0 ->
+ (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1.
+ Proof.
+ intros. subst a b. autorewrite with push_Zmul.
+ ring_simplify_subterms. rewrite Z.pow_2_r.
+ rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
+ repeat match goal with
+ | |- context [d * ?a * ?b * ?c] =>
+ replace (d * a * b * c) with (a * b * c * d) by ring
+ | |- context [d * ?a * ?b] =>
+ replace (d * a * b) with (a * b * d) by ring
+ end. (* move the factor d to the right of each product so that Z.div_add applies *)
+ rewrite !Z.div_add by omega.
+ autorewrite with zsimplify.
+ rewrite (Z.mul_comm a0 b0).
+ ring_simplify. ring.
+ Qed.
+
+ Lemma represents_trans t x y: (* replace the represented integer by an equal one *)
+ represents t y -> y = x ->
+ represents t x.
+ Proof. congruence. Qed.
+
+ Lemma represents_add x y : (* two entries that are each below 2^k represent x + 2^k * y *)
+ 0 <= x < 2 ^ k ->
+ 0 <= y < 2 ^ k ->
+ represents [x;y] (x + 2^k*y).
+ Proof.
+ intros; cbv [represents]; autorewrite with zsimplify.
+ repeat split; (reflexivity || nia).
+ Qed.
+
+ Lemma represents_small x : (* special case of represents_add with the second entry equal to 0 *)
+ 0 <= x < 2^k ->
+ represents [x; 0] x.
+ Proof.
+ intros.
+ eapply represents_trans.
+ { eauto using represents_add with zarith. }
+ { ring. }
+ Qed.
+
+ Lemma mul_high_represents a b x y a0b1 : (* concrete counterpart of mul_high_correct in Section Generic *)
+ represents a x ->
+ represents b y ->
+ 2^k <= x < 2^(k+1) ->
+ 0 <= y < 2^(k+1) ->
+ a0b1 = x mod 2^k * (y / 2^k) ->
+ represents (mul_high a b a0b1) ((x * y) / 2^k).
+ Proof.
+ cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros.
+ assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith).
+ assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem_in_goal; nia).
+
+ rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in *
+ by (push_rep; Z.div_mod_to_quot_rem_in_goal; lia). (* decompose the product along the low and high entries of a and b *)
+
+ push_rep. subst a0b1.
+ assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega).
+ replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). (* the range hypothesis on x forces its high entry to be 1 *)
+ autorewrite with zsimplify_fast in *.
+
+ eapply represents_trans.
+ { repeat (apply wideadd_represents;
+ [ | apply represents_small; Z.div_mod_to_quot_rem_in_goal; nia| ]).
+ erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ].
+ { apply represents_add; try reflexivity; solve [auto with zarith]. }
+ { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z =>
+ split; [ solve [Z.zero_bounds] | ];
+ eapply Z.le_lt_trans with (m:= x + y); nia
+ end. }
+ { omega. } }
+ { ring. }
+ Qed.
+
+ Definition cond_sub1 (a : list Z) y : Z := (* concrete cond_sub1: subtracts either 0 or y from low a, selected by Z.cc_l (high a) *)
+ dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in
+ dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in
+ fst diff. (* snd diff is discarded *)
+
+ Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. (* when x < 2 * s, Z.cc_l of the quotient is 0 exactly when x < s *)
+ Proof.
+ cbv [Z.cc_l]; intros.
+ rewrite Z.div_between_0_if by omega.
+ break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega.
+ Qed.
+
+ Lemma cond_sub1_correct a x y : (* same statement as cond_sub1_correct in Section Generic, with rep instantiated to represents *)
+ represents a x ->
+ 0 <= x < 2 * y ->
+ 0 <= y < 2 ^ k ->
+ cond_sub1 a y = if (x <? 2 ^ k) then x else x - y.
+ Proof.
+ intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep.
+ break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega;
+ autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith.
+ Qed.
+
+ Definition cond_sub2 x y := Z.add_modulo x 0 y. (* concrete cond_sub2, expressed as Z.add_modulo with a zero addend *)
+ Lemma cond_sub2_correct x y : (* same statement as cond_sub2_correct in Section Generic *)
+ cond_sub2 x y = if (x <? y) then x else x - y.
+ Proof.
+ cbv [cond_sub2]. rewrite Z.add_modulo_correct.
+ autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega.
+ Qed.
+
+ Section Defn.
+ Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). (* the input is given as two entries; the high entry is below M *)
+ Let xt := [xLow; xHigh].
+ Let x := xLow + 2^k * xHigh. (* integer value of the two-entry input *)
+
+ Lemma x_rep : represents xt x.
+ Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed.
+
+ Lemma x_bounds : 0 <= x < M * 2 ^ k. (* matches the x_range hypothesis of Section Generic *)
+ Proof. subst x; nia. Qed.
+
+ Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. (* either 0 or muLow, selected by Z.cc_m (2^k) xHigh; muSelect_correct relates it to the product required by Section Generic *)
+
+ Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound.
+ Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents
+ cond_sub1_correct cond_sub2_correct represents_low represents_add.
+
+ Lemma muSelect_correct : (* instance of the muSelect_correct hypothesis of Section Generic, with mu unfolded *)
+ muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k).
+ Proof.
+ (* assertions to help arith tactics *)
+ pose proof x_bounds.
+ assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia).
+ assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem_in_goal; auto with nia).
+ assert (0 < 2 ^ k / 2) by Z.zero_bounds.
+ assert (2 ^ (k - 1) <> 0) by auto with zarith.
+ assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith).
+
+ cbv [muSelect]. rewrite <-muLow_eq.
+ rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith.
+ replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia).
+ autorewrite with pull_Zdiv push_Zpow.
+ rewrite (Z.mul_comm (2 ^ k / 2)).
+ break_match; [ ring | ].
+ match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. (* a value in the range 0 <= _ < 2 that is not 0 must be 1 *)
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+
+ Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). (* the constant 2^(2k) / M is represented by the list [muLow; 1], via muLow_eq *)
+ Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed.
+
+ Derive barrett_reduce (* the term is obtained from the proof: the generic reduce instantiated with the concrete operations of this section *)
+ SuchThat (barrett_reduce = x mod M)
+ As barrett_reduce_correct.
+ Proof.
+ erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt)
+ by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega).
+ subst barrett_reduce. reflexivity.
+ Qed.
+ End Defn.
+ End BarrettReduction.
+End BarrettReduction.
+
+Module MontgomeryReduction.
+ Local Coercion Z.of_nat : nat >-> Z.
+ Section MontRed'.
+ Context (N R N' R' : Z).
+ Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1)
+ (N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1). (* N * N' is equivalent to -1 modulo R, and R * R' is equivalent to 1 modulo N *)
+
+ Context (Zlog2R : Z) .
+ Let w : nat -> Z := weight Zlog2R 1. (* weight function; the local lemma Hw shows w i = R ^ i *)
+ Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0).
+ Context (R_big_enough : n <= Zlog2R)
+ (R_two_pow : 2^Zlog2R = R). (* R is the power of two with exponent Zlog2R *)
+ Let w_mul : nat -> Z := weight (Zlog2R / n) 1.
+ Context (nout : nat) (Hnout : nout = 2%nat).
+
+ Definition montred' (lo hi : Z) := (* reduction of the two-entry input [lo; hi]; montred'_correct shows the result is (T * R') mod N when lo = T mod R and hi = T / R *)
+ dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout lo N') 0 in (* entry at index 0 of the wide product of lo and N' *)
+ dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in (* wide product of N and y *)
+ dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [lo;hi] t1_t2 in (* adds the input to that product; the pair holds the sum entries and, presumably, the carry out *)
+ dlet_nd y' := Z.zselect (snd sum_carry) 0 N in (* either 0 or N, selected by the second component of the sum *)
+ dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in (* subtract y' from the entry at index 1 of the sum *)
+ Z.add_modulo (fst lo''_carry) 0 N. (* final Z.add_modulo with a zero addend and modulus N *)
+
+ Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. (* closed form of the weights, obtained from R_two_pow *)
+ Proof.
+ clear -R_big_enough R_two_pow; cbv [w weight]; intro.
+ autorewrite with zsimplify.
+ rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity.
+ Qed.
+
+ Declare Equivalent Keys weight w.
+ Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *. (* rewrite weights into powers of R everywhere and simplify trivial powers *)
+ Local Ltac solve_range := (* range-goal solver used for the side conditions of the widemul rewrite rules below *)
+ repeat match goal with
+ | _ => progress change_weight
+ | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega))
+ | |- 0 <= _ => progress Z.zero_bounds
+ | |- 0 <= _ * _ < _ * _ =>
+ split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ]
+ | _ => solve [auto]
+ | _ => omega
+ end.
+
+ Local Lemma eval2 x y : Positional.eval w 2 [x;y] = x + R * y. (* value of a two-entry list under the weights w *)
+ Proof. cbn. change_weight. ring. Qed.
+
+ Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct
+ using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul.
+
+ Lemma montred'_eq lo hi T (HT_range: 0 <= T < R * N) (* montred' on the two halves of T agrees with reduce_via_partial N R N' T *)
+ (Hlo: lo = T mod R) (Hhi: hi = T / R):
+ montred' lo hi = reduce_via_partial N R N' T.
+ Proof.
+ rewrite <-reduce_via_partial_alt_eq by nia. (* NOTE(review): reduce_via_partial and its _alt form are not defined in this view; presumably from the MontgomeryReduction Definition and Proofs imports -- confirm *)
+ cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
+ rewrite Hlo, Hhi.
+ assert (0 <= (T mod R) * N' < w 2) by (solve_range).
+
+ autorewrite with widemul.
+ rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega).
+ rewrite R_two_pow.
+ cbv [Partition.partition seq]. rewrite !eval2.
+ autorewrite with push_nth_default push_map.
+ autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
+ change_weight.
+
+ (* pull out value before last modular reduction *)
+ match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z =>
+ let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end.
+
+ autorewrite with zsimplify.
+ rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *.
+ break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega;
+ repeat match goal with
+ | _ => progress autorewrite with zsimplify_fast
+ | |- context [?x mod (R * R)] =>
+ unique pose proof (Z.mod_pos_bound x (R * R));
+ try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver
+ | _ => omega
+ | _ => progress Z.rewrite_mod_small
+ end.
+ Qed.
+
+ Lemma montred'_correct lo hi T (HT_range: 0 <= T < R * N) (* main result of the section: montred' on the two halves of T yields (T * R') mod N *)
+ (Hlo: lo = T mod R) (Hhi: hi = T / R): montred' lo hi = (T * R') mod N.
+ Proof.
+ erewrite montred'_eq by eauto. (* reduce to the statement about reduce_via_partial *)
+ apply Z.equiv_modulo_mod_small; auto using reduce_via_partial_correct.
+ replace 0 with (Z.min 0 (R-N)) by (apply Z.min_l; omega). (* rewrite the lower bound 0 as Z.min 0 (R - N) before applying reduce_via_partial_in_range *)
+ apply reduce_via_partial_in_range; omega.
+ Qed.
+ End MontRed'.
+End MontgomeryReduction.