From cdd5ffb086eb647eabe640c81de9d8af7cd0a1dd Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 17 Jan 2019 15:07:47 -0500 Subject: Split up PushButtonSynthesis.v Closes #497 --- src/Arithmetic.v | 589 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 589 insertions(+) (limited to 'src/Arithmetic.v') diff --git a/src/Arithmetic.v b/src/Arithmetic.v index 5af73875b..4436bdf0e 100644 --- a/src/Arithmetic.v +++ b/src/Arithmetic.v @@ -4,6 +4,8 @@ Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. Require Import Coq.Sorting.Permutation. Require Import Coq.derive.Derive. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *) +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *) Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. @@ -4775,3 +4777,590 @@ Module WordByWordMontgomery. Proof. apply to_bytesmod_correct. Qed. End modops. End WordByWordMontgomery. + +Module BarrettReduction. + (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) + Section Generic. + Context {T} (rep : T -> Z -> Prop) + (k : Z) (k_pos : 0 < k) + (low : T -> Z) + (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) + (shiftr : T -> Z -> T) + (shiftr_correct : forall a x n, + rep a x -> + 0 <= n <= k -> + rep (shiftr a n) (x / 2 ^ n)) + (mul_high : T -> T -> Z -> T) + (mul_high_correct : forall a b x y x0y1, + rep a x -> + rep b y -> + 2 ^ k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> + rep (mul_high a b x0y1) (x * y / 2 ^ k)) + (mul : Z -> Z -> T) + (mul_correct : forall x y, + 0 <= x < 2^k -> + 0 <= y < 2^k -> + rep (mul x y) (x * y)) + (sub : T -> T -> T) + (sub_correct : forall a b x y, + rep a x -> + rep b y -> + 0 <= x - y < 2^k * 2^k -> + rep (sub a b) (x - y)) + (cond_sub1 : T -> Z -> Z) + (cond_sub1_correct : forall a x y, + rep a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x Z -> Z) + (cond_sub2_correct : forall x y, cond_sub2 x y = if (x + 0 <= x < 2 * M -> + cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. + Proof. + intros. + erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). + break_match; Z.ltb_to_lt; try lia; discriminate. + Qed. + + Lemma r_bounds : 0 <= x - q * M < 2 * M. + Proof. + pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. + subst q mu; split. + { Z.zero_bounds. apply qn_small; omega. } + { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } + Qed. + + Lemma reduce_correct : reduce = x mod M. + Proof. + pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. + assert (2 * M < 2^k * 2^k) by nia. + rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). + cbv [reduce Let_In]. + erewrite low_correct by eauto. Z.rewrite_mod_small. + erewrite two_conditional_subtracts by solve_rep. + rewrite !cond_sub2_correct. + subst q; reflexivity. + Qed. + End Generic. + + Section BarrettReduction. + Context (k : Z) (k_bound : 2 <= k). + Context (M muLow : Z). + Context (M_pos : 0 < M) + (muLow_eq : muLow + 2^k = 2^(2*k) / M) + (muLow_bounds : 0 <= muLow < 2^k) + (M_bound1 : 2 ^ (k - 1) < M < 2^k) + (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). + + Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). + Context (nout : nat) (Hnout : nout = 2%nat). + Let w := weight k 1. + Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. + Let props : @weight_properties w := wprops k 1 k_range. + + Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. + + Definition low (t : list Z) : Z := nth_default 0 t 0. + Definition high (t : list Z) : Z := nth_default 0 t 1. + Definition represents (t : list Z) (x : Z) := + t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. + + Lemma represents_eq t x : + represents t x -> t = [x mod 2^k; x / 2^k]. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_length t x : represents t x -> length t = 2%nat. + Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. + + Lemma represents_low t x : + represents t x -> low t = x mod 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_high t x : + represents t x -> high t = x / 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_low_range t x : + represents t x -> 0 <= x mod 2^k < 2^k. + Proof. auto with zarith. Qed. + + Lemma represents_high_range t x : + represents t x -> 0 <= x / 2^k < 2^k. + Proof. + destruct 1 as [? [? ?] ]; intros. + auto using Z.div_lt_upper_bound with zarith. + Qed. + Hint Resolve represents_length represents_low_range represents_high_range. + + Lemma represents_range t x : + represents t x -> 0 <= x < 2^k*2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_id x : + 0 <= x < 2^k * 2^k -> + represents [x mod 2^k; x / 2^k] x. + Proof. + intros; cbv [represents]; autorewrite with cancel_pair. + Z.rewrite_mod_small; tauto. + Qed. + + Local Ltac push_rep := + repeat match goal with + | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) + | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) + | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption + | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption + end. + + Definition shiftr (t : list Z) (n : Z) : list Z := + [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. + + Lemma shiftr_represents a i x : + represents a x -> + 0 <= i <= k -> + represents (shiftr a i) (x / 2 ^ i). + Proof. + cbv [shiftr]; intros; push_rep. + match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. + assert (0 < 2 ^ i) by auto with zarith. + assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. + assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by + (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). + repeat match goal with + | _ => rewrite Z.rshi_correct by auto with zarith + | _ => rewrite <-Z.div_mod''' by auto with zarith + | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.rewrite_mod_small + | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] => + rewrite (Z.div_div_comm a b c) by auto with zarith + | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] + end. + Qed. + + Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). + Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. + + Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). + (* TODO: use this definition once issue #352 is resolved *) + (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *) + Definition widesub (t1 t2 : list Z) := + let t1_0 := hd 0 t1 in + let t1_1 := hd 0 (tl t1) in + let t2_0 := hd 0 t2 in + let t2_1 := hd 0 (tl t2) in + dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in + dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in + [fst x0; fst x1]. + Definition widemul := BaseConversion.widemul_inlined k n nout. + + Lemma partition_represents x : + 0 <= x < 2^k*2^k -> + represents (Partition.partition w 2 x) x. + Proof. + intros; cbn. change_weight. + Z.rewrite_mod_small. + autorewrite with zsimplify_fast. + auto using represents_id. + Qed. + + Lemma eval_represents t x : + represents t x -> Positional.eval w 2 t = x. + Proof. + intros; rewrite (represents_eq t x) by assumption. + cbn. change_weight; push_rep. + autorewrite with zsimplify. reflexivity. + Qed. + + Ltac wide_op partitions_pf := + repeat match goal with + | _ => rewrite partitions_pf by eauto + | _ => rewrite partitions_pf by auto with zarith + | _ => erewrite eval_represents by eauto + | _ => solve [auto using partition_represents, represents_id] + end. + + Lemma wideadd_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x + y < 2^k*2^k -> + represents (wideadd t1 t2) (x + y). + Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed. + + Lemma widesub_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x - y < 2^k*2^k -> + represents (widesub t1 t2) (x - y). + Proof. + intros; cbv [widesub Let_In]. + rewrite (represents_eq t1 x) by assumption. + rewrite (represents_eq t2 y) by assumption. + cbn [hd tl]. + autorewrite with to_div_mod. + pull_Zmod. + match goal with |- represents [?m; ?d] ?x => + replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. + rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds). + f_equal. + transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { + rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega. + rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega. + f_equal. ring. } + autorewrite with zsimplify. + ring. + Qed. + (* Works with Rows.sub-based widesub definition + Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. + *) + + (* TODO: MOVE Equivlalent Keys decl to Arithmetic? *) + Declare Equivalent Keys BaseConversion.widemul BaseConversion.widemul_inlined. + Lemma widemul_represents x y : + 0 <= x < 2^k -> + 0 <= y < 2^k -> + represents (widemul x y) (x * y). + Proof. + intros; cbv [widemul]. + assert (0 <= x * y < 2^k*2^k) by auto with zarith. + wide_op BaseConversion.widemul_correct. + Qed. + + Definition mul_high (a b : list Z) a0b1 : list Z := + dlet_nd a0b0 := widemul (low a) (low b) in + dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in + wideadd ab [a0b1; 0]. + + Lemma mul_high_idea d a b a0 a1 b0 b1 : + d <> 0 -> + a = d * a1 + a0 -> + b = d * b1 + b0 -> + (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. + Proof. + intros. subst a b. autorewrite with push_Zmul. + ring_simplify_subterms. rewrite Z.pow_2_r. + rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). + repeat match goal with + | |- context [d * ?a * ?b * ?c] => + replace (d * a * b * c) with (a * b * c * d) by ring + | |- context [d * ?a * ?b] => + replace (d * a * b) with (a * b * d) by ring + end. + rewrite !Z.div_add by omega. + autorewrite with zsimplify. + rewrite (Z.mul_comm a0 b0). + ring_simplify. ring. + Qed. + + Lemma represents_trans t x y: + represents t y -> y = x -> + represents t x. + Proof. congruence. Qed. + + Lemma represents_add x y : + 0 <= x < 2 ^ k -> + 0 <= y < 2 ^ k -> + represents [x;y] (x + 2^k*y). + Proof. + intros; cbv [represents]; autorewrite with zsimplify. + repeat split; (reflexivity || nia). + Qed. + + Lemma represents_small x : + 0 <= x < 2^k -> + represents [x; 0] x. + Proof. + intros. + eapply represents_trans. + { eauto using represents_add with zarith. } + { ring. } + Qed. + + Lemma mul_high_represents a b x y a0b1 : + represents a x -> + represents b y -> + 2^k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + a0b1 = x mod 2^k * (y / 2^k) -> + represents (mul_high a b a0b1) ((x * y) / 2^k). + Proof. + cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. + assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith). + assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem_in_goal; nia). + + rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in * + by (push_rep; Z.div_mod_to_quot_rem_in_goal; lia). + + push_rep. subst a0b1. + assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega). + replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). + autorewrite with zsimplify_fast in *. + + eapply represents_trans. + { repeat (apply wideadd_represents; + [ | apply represents_small; Z.div_mod_to_quot_rem_in_goal; nia| ]). + erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ]. + { apply represents_add; try reflexivity; solve [auto with zarith]. } + { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => + split; [ solve [Z.zero_bounds] | ]; + eapply Z.le_lt_trans with (m:= x + y); nia + end. } + { omega. } } + { ring. } + Qed. + + Definition cond_sub1 (a : list Z) y : Z := + dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in + dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in + fst diff. + + Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. + Proof. + cbv [Z.cc_l]; intros. + rewrite Z.div_between_0_if by omega. + break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega. + Qed. + + Lemma cond_sub1_correct a x y : + represents a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x 0) by auto with zarith. + assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith). + + cbv [muSelect]. rewrite <-muLow_eq. + rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith. + replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). + autorewrite with pull_Zdiv push_Zpow. + rewrite (Z.mul_comm (2 ^ k / 2)). + break_match; [ ring | ]. + match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. + autorewrite with zsimplify; reflexivity. + Qed. + + Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). + Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed. + + Derive barrett_reduce + SuchThat (barrett_reduce = x mod M) + As barrett_reduce_correct. + Proof. + erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt) + by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega). + subst barrett_reduce. reflexivity. + Qed. + End Defn. + End BarrettReduction. +End BarrettReduction. + +Module MontgomeryReduction. + Local Coercion Z.of_nat : nat >-> Z. + Section MontRed'. + Context (N R N' R' : Z). + Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1) + (N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1). + + Context (Zlog2R : Z) . + Let w : nat -> Z := weight Zlog2R 1. + Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0). + Context (R_big_enough : n <= Zlog2R) + (R_two_pow : 2^Zlog2R = R). + Let w_mul : nat -> Z := weight (Zlog2R / n) 1. + Context (nout : nat) (Hnout : nout = 2%nat). + + Definition montred' (lo hi : Z) := + dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout lo N') 0 in + dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in + dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [lo;hi] t1_t2 in + dlet_nd y' := Z.zselect (snd sum_carry) 0 N in + dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in + Z.add_modulo (fst lo''_carry) 0 N. + + Local Lemma Hw : forall i, w i = R ^ Z.of_nat i. + Proof. + clear -R_big_enough R_two_pow; cbv [w weight]; intro. + autorewrite with zsimplify. + rewrite Z.pow_mul_r, R_two_pow by omega; reflexivity. + Qed. + + Declare Equivalent Keys weight w. + Local Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r, ?Z.pow_1_l in *. + Local Ltac solve_range := + repeat match goal with + | _ => progress change_weight + | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) + | |- 0 <= _ => progress Z.zero_bounds + | |- 0 <= _ * _ < _ * _ => + split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] + | _ => solve [auto] + | _ => omega + end. + + Local Lemma eval2 x y : Positional.eval w 2 [x;y] = x + R * y. + Proof. cbn. change_weight. ring. Qed. + + Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct + using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul. + + Lemma montred'_eq lo hi T (HT_range: 0 <= T < R * N) + (Hlo: lo = T mod R) (Hhi: hi = T / R): + montred' lo hi = reduce_via_partial N R N' T. + Proof. + rewrite <-reduce_via_partial_alt_eq by nia. + cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. + rewrite Hlo, Hhi. + assert (0 <= (T mod R) * N' < w 2) by (solve_range). + + autorewrite with widemul. + rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). + rewrite R_two_pow. + cbv [Partition.partition seq]. rewrite !eval2. + autorewrite with push_nth_default push_map. + autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. + change_weight. + + (* pull out value before last modular reduction *) + match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z => + let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. + + autorewrite with zsimplify. + rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *. + break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega; + repeat match goal with + | _ => progress autorewrite with zsimplify_fast + | |- context [?x mod (R * R)] => + unique pose proof (Z.mod_pos_bound x (R * R)); + try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver + | _ => omega + | _ => progress Z.rewrite_mod_small + end. + Qed. + + Lemma montred'_correct lo hi T (HT_range: 0 <= T < R * N) + (Hlo: lo = T mod R) (Hhi: hi = T / R): montred' lo hi = (T * R') mod N. + Proof. + erewrite montred'_eq by eauto. + apply Z.equiv_modulo_mod_small; auto using reduce_via_partial_correct. + replace 0 with (Z.min 0 (R-N)) by (apply Z.min_l; omega). + apply reduce_via_partial_in_range; omega. + Qed. + End MontRed'. +End MontgomeryReduction. -- cgit v1.2.3