From 684d356bcb81ca36314cd7864c62a1d97af4ea99 Mon Sep 17 00:00:00 2001 From: jadep Date: Tue, 12 Mar 2019 12:46:00 -0400 Subject: finish proofs --- src/Arithmetic.v | 456 +++++++++++++++++------------ src/PushButtonSynthesis/BarrettReduction.v | 2 + 2 files changed, 272 insertions(+), 186 deletions(-) (limited to 'src') diff --git a/src/Arithmetic.v b/src/Arithmetic.v index 65eff76a4..75e93deaf 100644 --- a/src/Arithmetic.v +++ b/src/Arithmetic.v @@ -1889,6 +1889,14 @@ Module Partition. reflexivity. Qed. + Lemma partition_0 n : partition n 0 = Positional.zeros n. + Proof. + cbv [partition]. + erewrite Positional.zeros_ext_map with (p:=seq 0 n) by distr_length. + apply map_ext; intros. + autorewrite with zsimplify; reflexivity. + Qed. + End Partition. Hint Rewrite length_partition length_recursive_partition : distr_length. Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval. @@ -3035,9 +3043,6 @@ Module BaseConversion. Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. - - (* TODO: Unspecialize from 2-limb *) - (* multi-limb version -- multiply two numbers that each have m limbs with (n*k) bits each, by converting them to numbers with (m*n) limbs of k bits each, multiplying, then converting back *) Definition widemul a b := mul_converted sw dw m m mn mn nout (aligned_carries n nout) a b. Lemma widemul_correct a b : @@ -3556,6 +3561,19 @@ Module UniformWeight. auto using uweight_recursive_partition_change_start with omega. Qed. + Lemma uweight_firstn_partition lgr (Hr : 0 < lgr) n x m (Hm : (m <= n)%nat) : + firstn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) m x. + Proof. + cbv [Partition.partition]; + repeat match goal with + | _ => progress intros + | _ => progress autorewrite with push_firstn natsimplify zsimplify_fast + | _ => rewrite Nat.min_l by lia + | _ => rewrite weight_0 by auto using uwprops + | _ => reflexivity + end. + Qed. + Lemma uweight_skipn_partition lgr (Hr : 0 < lgr) n x m : skipn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) (n - m) (x / uweight lgr m). Proof. @@ -3623,6 +3641,21 @@ Module UniformWeight. m = (n + length y)%nat -> Positional.eval (uweight lgr) m (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y. Proof using Type. intros. subst m. apply uweight_eval_app'; lia. Qed. + + Lemma uweight_partition_app lgr (Hr : 0 < lgr) n m a b : + Partition.partition (uweight lgr) n a ++ Partition.partition (uweight lgr) m b + = Partition.partition (uweight lgr) (n+m) (a mod uweight lgr n + b * uweight lgr n). + Proof. + assert (0 < uweight lgr n) by auto using uwprops. + match goal with |- _ = ?rhs => rewrite <-(firstn_skipn n rhs) end. + rewrite uweight_firstn_partition, uweight_skipn_partition by lia. + rewrite Z.div_add by lia. + rewrite (Z.div_small (_ mod _)) by auto with zarith. + f_equal. + { apply Partition.partition_eq_mod; [ auto using uwprops | ]. + push_Zmod. autorewrite with zsimplify. reflexivity. } + { f_equal; lia. } + Qed. End UniformWeight. Module WordByWordMontgomery. @@ -4885,8 +4918,10 @@ Module BarrettReduction. Context (b k M mu width : Z) (n : nat) (b_ok : 1 < b) (k_pos : 0 < k) + (bk_eq : b^k = 2^(width * Z.of_nat n)) (M_range : b ^ (k - 1) < M < b ^ k) (mu_eq : mu = b ^ (2 * k) / M) + (width_pos : 0 < width) (strong_bound : b ^ 1 * (b ^ (2 * k) mod M) <= b ^ (k + 1) - mu). Local Notation weight := (UniformWeight.uweight width). Local Notation partition := (Partition.partition weight). @@ -4903,23 +4938,19 @@ Module BarrettReduction. q3 (partition (n*2) x) (partition (n+1) q1) = partition (n+1) ((mu * q1) / b ^ (k + 1))) (r : list Z -> list Z -> list Z) (r_correct : - forall x q1 q3, - 0 <= x < b ^ (2 * k) -> - q1 = x / b ^ (k - 1) -> - q3 = (mu * q1) / b ^ (k + 1) -> - r (partition (n*2) x) (partition (n+1) q3) = partition (n*2) (x - q3 * M)) - (final_reduce : list Z -> list Z) - (final_reduce_correct : - forall r, - 0 <= r < 2 * M -> - final_reduce (partition (n*2) r) = partition n (if dec (r < M) then r else r - M)). + forall x q3, + 0 <= x < M * b ^ k -> + 0 <= q3 -> + (exists b : bool, q3 = x / M + (if b then -1 else 0)) -> + r (partition (n*2) x) (partition (n+1) q3) = partition n (x mod M)). Context (x : Z) (x_range : 0 <= x < M * b ^ k) (xt : list Z) (xt_correct : xt = partition (n*2) x). - (* TODO: make barrett Z figure out some of this itself *) Local Lemma M_pos : 0 < M. Proof. assert (0 <= b ^ (k - 1)) by Z.zero_bounds. lia. Qed. + Local Lemma M_upper : M < weight n. + Proof. rewrite UniformWeight.uweight_eq_alt'. lia. Qed. Local Lemma x_upper : x < b ^ (2 * k). Proof. assert (0 < b ^ k) by Z.zero_bounds. @@ -4934,22 +4965,34 @@ Module BarrettReduction. Definition reduce := dlet_nd q1t := q1 xt in - dlet_nd q3t := q3 xt q1t in - dlet_nd rt := r xt q3t in - final_reduce rt. + dlet_nd q3t := q3 xt q1t in + r xt q3t. + + Lemma q1_range : 0 <= x / b^(k-1) < b^(k+1). + Proof. + split; [ solve [Z.zero_bounds] | ]. + assert (0 < b ^ (k - 1)) by Z.zero_bounds. + assert (0 < b ^ k) by Z.zero_bounds. + apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ]. + eapply Z.lt_le_trans with (m:=b^k * b^k); + [ nia | autorewrite with pull_Zpow; apply Z.pow_le_mono; lia ]. + Qed. + + Lemma q3_range : 0 <= mu * (x / b ^ (k - 1)) / b ^ (k + 1). + Proof. + assert (0 < b ^ (k - 1)) by Z.zero_bounds. + subst mu; Z.zero_bounds. + Qed. Lemma reduce_correct : reduce = partition n (x mod M). Proof. - cbv [reduce Let_In]. + cbv [reduce Let_In]. pose proof q3_range. rewrite xt_correct, q1_correct, q3_correct by auto with lia. - erewrite r_correct by eauto with lia. rewrite final_reduce_correct; [ | ]. - { rewrite barrett_reduction_small_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) by - auto using Z.lt_gt with zarith. - break_innermost_match; Z.ltb_to_lt; (lia || reflexivity). } - { split; [ Z.zero_bounds; apply qn_small | apply r_small_strong]; - auto using Z.lt_gt with zarith. } + assert (exists cond : bool, ((mu * (x / b^(k-1))) / b^(k+1)) = x / M + (if cond then -1 else 0)) as Hq3. + { destruct q_nice_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) (a:=x) (n:=M) as [cond Hcond]; + eauto using Z.lt_gt with zarith. } + eauto using r_correct with lia. Qed. - End Generic. (* Non-standard implementation -- uses specialized instructions and b=2 *) @@ -4963,15 +5006,32 @@ Module BarrettReduction. (* sz = 1, width = k = 256 *) Local Notation w := (UniformWeight.uweight width). Local Notation eval := (Positional.eval w). Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M). - Context (mu_eq : mu = 2 ^ (2 * k) / M) (M_range : 2^(k-1) < M < 2^k). + Context (mu_eq : mu = 2 ^ (2 * k) / M) (muHigh_one : mu / w sz = 1) (M_range : 2^(k-1) < M < 2^k). Local Lemma wprops : @weight_properties w. Proof. apply UniformWeight.uwprops; auto with lia. Qed. Local Hint Resolve wprops. Hint Rewrite mut_correct Mt_correct : pull_partition. - Lemma w_eq_2k : w sz = 2^k. Admitted. (* TODO *) + Lemma w_eq_2k : w sz = 2^k. Proof. rewrite UniformWeight.uweight_eq_alt' by auto. congruence. Qed. Lemma mu_range : 2^k <= mu < 2^(k+1). - Admitted. (* TODO *) + Proof. + rewrite mu_eq. assert (0 < 2^(k-1)) by Z.zero_bounds. + assert (2^k < M * 2). + { replace (2^k) with (2^(k-1+1)) by (f_equal; lia). + rewrite Z.pow_add_r, Z.pow_1_r by lia. + lia. } + replace (2 ^ (2 * k)) with (2^(k+k)) by (f_equal; lia). + rewrite !Z.pow_add_r, Z.pow_1_r by lia. split. + { apply Z.div_le_lower_bound; nia. } + { apply Z.div_lt_upper_bound; nia. } + Qed. + Lemma mu_range' : 0 <= mu < 2 * w sz. + Proof. + pose proof mu_range. assert (0 < 2^k) by auto with zarith. + assert (2^(k+1) = 2 * w sz); [ | lia]. + rewrite k_eq, UniformWeight.uweight_eq_alt'. + rewrite Z.pow_add_r, Z.pow_1_r by lia. lia. + Qed. Lemma M_range' : 0 <= M < w sz. (* more convenient form, especially for mod_small *) Proof. assert (0 <= 2 ^ (k-1)) by Z.zero_bounds. pose proof w_eq_2k; lia. Qed. @@ -5000,13 +5060,13 @@ Module BarrettReduction. (* select based on the most significant bit of xHigh *) Definition muSelect xt := let xHigh := nth_default 0 xt (sz*2 - 1) in - Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut). + Positional.select (Z.cc_m (2 ^ width) xHigh) (Positional.zeros sz) (low mut). Definition cond_sub (a y : list Z) : list Z := let cond := Z.cc_l (nth_default 0 (high a) 0) in (* a[k] = least significant bit of (high a) *) - dlet_nd maybe_y := Positional.select cond (repeat 0 sz) y in - dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *) - fst diff. + dlet_nd maybe_y := Positional.select cond (Positional.zeros sz) y in + dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *) + fst diff. Definition cond_subM x := if Nat.eq_dec sz 1 @@ -5017,16 +5077,14 @@ Module BarrettReduction. Definition q3 (xt q1t : list Z) := dlet_nd muSelect := muSelect xt in (* make sure muSelect is not inlined in the output *) - dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in - shiftr (sz+1) twoq 1. + dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in + shiftr (sz+1) twoq 1. Definition r (xt q3t : list Z) := dlet_nd r2 := widemul (low q3t) Mt in - widesub xt r2. - - Definition final_reduce (rt : list Z) := + dlet_nd rt := widesub xt r2 in dlet_nd rt := cond_sub rt Mt in - cond_subM rt. + cond_subM rt. Section Proofs. @@ -5120,20 +5178,20 @@ Module BarrettReduction. Qed. Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a. - Admitted. + Proof. cbv [low]; auto using UniformWeight.uweight_firstn_partition with lia. Qed. Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz). - Admitted. + Proof. cbv [high]. rewrite UniformWeight.uweight_skipn_partition by lia. f_equal; lia. Qed. Lemma fill_correct n m a : (n <= m)%nat -> fill m (partition w n a) = partition w m (a mod w n). - Admitted. + Proof. + cbv [fill]; intros. distr_length. + rewrite <-partition_0 with (weight:=w). + rewrite UniformWeight.uweight_partition_app by lia. + f_equal; lia. + Qed. Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition. - (* TODO: move *) - Lemma partition_app n m a b : - partition w n a ++ partition w m b = partition w (n+m) (a mod w n + b * w n). - Admitted. - Lemma wideadd_correct a b : wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b). Proof. @@ -5192,7 +5250,7 @@ Module BarrettReduction. repeat match goal with | _ => progress autorewrite with pull_partition | _ => progress rewrite ?Ha, ?Ha0b1 - | _ => rewrite partition_app by auto; + | _ => rewrite UniformWeight.uweight_partition_app by lia; replace (sz+sz)%nat with (sz*2)%nat by lia | _ => rewrite Z.mod_pull_div by auto with zarith | _ => progress Z.rewrite_mod_small @@ -5200,12 +5258,6 @@ Module BarrettReduction. end. Qed. - Lemma muSelect_correct x : - 0 <= x < w (sz * 2) -> - muSelect (partition w (sz*2) x) = partition w (sz+1) (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))). - Admitted. - Hint Rewrite muSelect_correct using lia : pull_partition. - (* TODO: move *) Lemma pow_pos_le a b : 0 < a -> 0 < b -> a <= a ^ b. Proof. @@ -5215,47 +5267,102 @@ Module BarrettReduction. Qed. Hint Resolve pow_pos_le : zarith. + (* TODO: move *) + Lemma pow_pos_lt a b : 1 < a -> 1 < b -> a < a ^ b. + Proof. + intros; eapply Z.le_lt_trans with (m:=a ^ 1). + { rewrite Z.pow_1_r; reflexivity. } + { apply Z.pow_lt_mono_r; auto with zarith. } + Qed. + Hint Resolve pow_pos_lt : zarith. + + (* TODO: move *) + Lemma pow_div_base a b : a <> 0 -> 0 < b -> a ^ b / a = a ^ (b - 1). + Proof. intros; rewrite Z.pow_sub_r, Z.pow_1_r; lia. Qed. + Hint Rewrite pow_div_base using zutil_arith : pull_Zpow. + + (* TODO: move *) + Lemma pow_mul_base a b : 0 <= b -> a * a ^ b = a ^ (b + 1). + Proof. intros; rewrite <-Z.pow_succ_r, <-Z.add_1_r by lia; reflexivity. Qed. + Hint Rewrite pow_mul_base using zutil_arith : pull_Zpow. + + + (* improved! *) + Ltac zero_bounds' := + repeat match goal with + | |- ?a <> 0 => apply Z.positive_is_nonzero + | |- ?a > 0 => apply Z.lt_gt + | |- ?a >= 0 => apply Z.le_ge + end; + try match goal with + | |- 0 < ?a => Z.zero_bounds + | |- 0 <= ?a => Z.zero_bounds + end. + + Ltac zutil_arith ::= solve [ omega | Psatz.lia | auto with nocore | solve [zero_bounds'] ]. + + Hint Rewrite UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : weight_to_pow. + Hint Rewrite <-UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : pow_to_weight. + + Lemma q1_range x : + 0 <= x < w (sz * 2) -> + 0 <= x / 2 ^ (k-1) < 2 * w sz. + Proof. + intros; split; [ solve [Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ]. + assert (w (sz * 2) <= 2 ^ (k-1) * (2 * w sz)); [ | lia ]. + autorewrite with weight_to_pow pull_Zpow. + apply Z.pow_le_mono_r; lia. + Qed. + + Lemma muSelect_correct x : + 0 <= x < w (sz * 2) -> + muSelect (partition w (sz*2) x) = partition w sz (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))). + Proof. + cbv [muSelect]; intros; + repeat match goal with + | _ => progress autorewrite with pull_partition natsimplify + | _ => progress rewrite ?Z.cc_m_eq by auto with zarith + | _ => erewrite Positional.select_eq by (distr_length; eauto) + | _ => rewrite nth_default_partition by lia + | _ => progress replace (S (sz * 2 - 1)) with (sz * 2)%nat by lia + | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by auto with zarith + end. + replace (x / (w (sz * 2 - 1)) / (2 ^ width / 2)) with (x / (2 ^ (k - 1)) / w sz) by + (autorewrite with weight_to_pow pull_Zpow pull_Zdiv; do 2 f_equal; nia). + + rewrite Z.div_between_0_if with (a:=x / 2^(k-1)) by (zero_bounds'; auto using q1_range). + break_innermost_match; try lia; autorewrite with zsimplify_fast; [ | ]. + { apply partition_eq_mod; auto with zarith. } + { rewrite partition_0; reflexivity. } + Qed. + Hint Rewrite muSelect_correct using lia : pull_partition. + + Lemma mu_q1_range x (Hx : 0 <= x < w (sz * 2)) : mu * (x / 2^(k-1)) < w sz * w (sz * 2). + Proof. + pose proof mu_range'. pose proof q1_range x ltac:(lia). + replace (w (sz * 2)) with (w sz * w sz) by + (autorewrite with weight_to_pow pull_Zpow; f_equal; lia). + apply Z.lt_le_trans with (m:= 2 * w sz * (2 * w sz)); [ nia | ]. + assert (4 <= w sz); [ | nia ]. change 4 with (Z.pow 2 2). + autorewrite with weight_to_pow. apply Z.pow_le_mono_r; nia. + Qed. + Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) : q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)). Proof. - cbv [q3 Let_In]. intros. - assert (0 <= q1 < w (sz+1)) by admit. - autorewrite with pull_partition. - rewrite ?fill_correct, ?muSelect_correct by lia. - rewrite <-Hq1. - rewrite mul_high_correct; [ | | ]. - 2: { (* mu mod w (sz + 1) / w sz = 1 *) - (* highest bit of mu is 1 -- do this by rewriting mod small using bounds on mu, then take axiom for highest bit *) - admit. } - 2: { - rewrite mod_mod_weight by lia. - push_Zmod. rewrite Z.mod_pull_div by auto with zarith. - assert (0 < w sz) by auto with zarith. - assert (w (sz+1) <= w (sz+1) * w sz) by (apply Z.le_mul_diag_r; auto with zarith). - Z.rewrite_mod_small. - autorewrite with natsimplify in *. - rewrite UniformWeight.uweight_S in * by auto with zarith. - assert (0 <= q1 / w sz < 2^width). - { split; [ solve [Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; auto with nia. } - pose proof (Z.mod_pos_bound mu (w sz) ltac:(auto)). - apply Z.mod_small. nia. } - rewrite shiftr_correct; auto with zarith. - 2: admit. (* (Z.to_nat (1 / width) <= sz * 2)%nat *) - 2: admit. (* (sz + 1 <= sz * 2 - Z.to_nat (1 / width))%nat *) - 2: admit. (* 0 <= mu mod w (sz + 1) * (q1 mod w (sz + 1)) / w sz < w (sz * 2) *) - (* TODO: see if things are easier using div_mod_to_quot_rem *) - Z.rewrite_mod_small. - pose proof mu_range. - replace (2 ^ (k+1)) with (w sz * 2) in * by admit. - assert (0 < 2 ^ k) by (Z.zero_bounds; nia). - assert (2 * w sz <= w (sz + 1)). - { autorewrite with natsimplify. rewrite UniformWeight.uweight_S by auto with zarith. - apply Z.mul_le_mono_nonneg_r; auto with zarith. } - Z.rewrite_mod_small. - rewrite Z.div_div by auto with zarith. - repeat (f_equal; try ring). - Admitted. + cbv [q3 Let_In]. intros. pose proof mu_q1_range x ltac:(lia). + pose proof mu_range'. pose proof q1_range x ltac:(lia). + autorewrite with pull_partition pull_Zmod. + assert (2 * w sz < w (sz + 1)) by (autorewrite with weight_to_pow pull_Zpow; auto with zarith lia). + Z.rewrite_mod_small. rewrite <-Hq1 in *. + rewrite mul_high_correct by + (try lia; rewrite Z.div_between_0_if with (a:=q1) by lia; + break_innermost_match; autorewrite with zsimplify; reflexivity). + rewrite shiftr_correct by (rewrite ?Z.div_small, ?Z2Nat.inj_0 by lia; auto with zarith lia). + autorewrite with weight_to_pow pull_Zpow pull_Zdiv. + congruence. + Qed. Lemma cond_sub_correct a b : cond_sub (partition w (sz*2) a) (partition w sz b) @@ -5263,7 +5370,7 @@ Module BarrettReduction. then a else a - b). Proof. - cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition. + intros; cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition. rewrite nth_default_partition by lia. rewrite weight_0 by auto. autorewrite with zsimplify_fast. rewrite UniformWeight.uweight_eq_alt' with (n:=1%nat). autorewrite with push_Zof_nat zsimplify. @@ -5280,16 +5387,26 @@ Module BarrettReduction. then a else a - M). Proof. - Admitted. + cbv [cond_subM]. autorewrite with pull_partition. pose proof M_range'. + rewrite Rows.conditional_sub_partitions by + (distr_length; auto; autorewrite with push_eval; try apply partition_eq_mod; auto with zarith). + rewrite nth_default_partition, weight_0, Z.add_modulo_correct by auto with lia. + autorewrite with zsimplify_fast push_eval. Z.rewrite_mod_small. + pose proof Z.mod_pos_bound a (w 1) ltac:(auto). + break_innermost_match; Z.ltb_to_lt; + repeat match goal with + | _ => lia + | _ => reflexivity + | _ => apply partition_eq_mod; solve [auto with zarith] + | _ => rewrite partition_step, weight_0 by auto + | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.rewrite_mod_small + | _ => rewrite Z.sub_mod_l with (a:=a) + end. + Qed. Hint Rewrite cond_subM_correct : pull_partition. - (* TODO: move this with barrett Z lemmas & prove *) - Lemma q3_range x : - 0 <= x < 2 ^ (2 * k) -> - 0 <= (mu * (x / 2 ^ (k-1))) / 2 ^ (k+1) < 2^k. - Proof using Type. - Admitted. - + (* TODO: unused?*) Lemma w_eq_22k : w (sz * 2) = 2 ^ (2 * k). Proof. replace (sz * 2)%nat with (sz + sz)%nat by lia. @@ -5297,48 +5414,66 @@ Module BarrettReduction. f_equal; lia. Qed. - Lemma r_correct x q1 q3 : - 0 <= x < w (sz * 2) -> - q1 = x / 2 ^ (k - 1) -> - q3 = (mu*q1) / 2 ^ (k + 1) -> - r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w (sz*2) (x - q3 * M). + (* TODO: unused?*) + Lemma q3M_eq x q3 (b:bool) : + 0 <= x < M * 2 ^ k -> + q3 = x / M + (if b then -1 else 0) -> + q3 * M = x - x mod M - (if b then M else 0). Proof. - cbv [r Let_In]; intros. pose proof w_eq_22k. - autorewrite with pull_partition. - assert (0 <= q3 < w sz). - { subst q3 q1. rewrite w_eq_2k. apply q3_range; lia. } - pose proof M_range'. Z.rewrite_mod_small. - autorewrite with zsimplify. reflexivity. + intros. assert (0 < 2^(k-1)) by Z.zero_bounds. rewrite Z.mod_eq by lia. + subst q3. break_innermost_match; ring. Qed. - Lemma final_reduce_correct r: - 0 <= r < 2 * M -> - final_reduce (partition w (sz*2) r) = partition w sz (if dec (r < M) then r else r - M). + (* TODO: move *) + Hint Resolve Z.positive_is_nonzero Z.lt_gt : zarith. + + Lemma r_idea x q3 (b:bool) : + 0 <= x < M * 2 ^ k -> + 0 <= q3 -> + q3 = x / M + (if b then -1 else 0) -> + x - q3 mod w sz * M = x mod M + (if b then M else 0). Proof. - intros; pose proof M_range'. - cbv [final_reduce Let_In]; intros; autorewrite with pull_partition. - apply partition_eq_mod; auto with zarith. - replace (2 ^ (k + 1)) with (2 * w sz) in * - by (rewrite Z.pow_add_r by auto with zarith; autorewrite with zsimplify; rewrite <-w_eq_2k; lia). - pose proof (Z.mod_pos_bound r (w sz) ltac:(auto with zarith)). - remember ((r / w sz) mod 2) as rk. - assert (rk = 1 -> r > M) by (subst rk; Z.rewrite_mod_small; intros; - rewrite (Z.div_mod r (w sz)) by auto with zarith; nia). - replace r with (w sz * rk + r mod w sz) by - (subst rk; Z.rewrite_mod_small; rewrite <-Z.div_mod; auto with zarith). - subst rk; rewrite Zmod_odd in *. + intros. assert (0 < 2^(k-1)) by Z.zero_bounds. + assert (q3 < w sz). + { apply Z.le_lt_trans with (m:=x/M); [ subst q3; break_innermost_match; lia | ]. + autorewrite with weight_to_pow. rewrite <-k_eq. auto with zarith. } + Z.rewrite_mod_small. repeat match goal with - | _ => lia - | H : _ |- _ => rewrite Z.mod_small in H by auto with zarith lia - | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast - | _ => break_innermost_match_step + | _ => progress autorewrite with push_Zmul + | H : q3 = ?e |- _ => progress replace (q3 * M) with (e * M) by (rewrite H; reflexivity) + | _ => rewrite (Z.mul_div_eq' x M) by lia end. + break_innermost_match; Z.ltb_to_lt; lia. + Qed. + + Lemma r_correct x q3 : + 0 <= x < M * 2 ^ k -> + 0 <= q3 -> + (exists b : bool, q3 = x / M + (if b then -1 else 0)) -> + r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w sz (x mod M). + Proof. + intros; cbv [r Let_In]. pose proof M_range'. assert (0 < 2^(k-1)) by Z.zero_bounds. + autorewrite with pull_partition. Z.rewrite_mod_small. + match goal with H : exists _, q3 = _ |- _ => destruct H end. + erewrite r_idea by eassumption. + pose proof (Z.mod_pos_bound x M ltac:(lia)). + rewrite Z.div_between_0_if with (b:=w sz) by (break_innermost_match; auto with zarith). + rewrite Z.mod_small with (b:=2) by (break_innermost_match; lia). + break_innermost_match; Z.ltb_to_lt; try lia; autorewrite with zsimplify_fast; + repeat match goal with + | |- exists e, _ /\ _ /\ ?f ?x = ?f e => exists x; split; [ | split ] + | _ => rewrite Z.mod_small in * by lia + | _ => progress Z.rewrite_mod_small + | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast + | _ => lia + | _ => reflexivity + end. Qed. End Proofs. Section Def. Context (sz_eq_1 : sz = 1%nat). (* this is needed to get rid of branches in the templates; a different definition would be needed for sizes other than 1, but would be able to use the same proofs. *) - Local Hint Resolve q1_correct q3_correct r_correct final_reduce_correct. + Local Hint Resolve q1_correct q3_correct r_correct. (* muselect relies on an initially-set flag, so pull it out of q3 *) Definition fancy_reduce_muSelect_first xt := @@ -5346,8 +5481,7 @@ Module BarrettReduction. dlet_nd q1t := q1 xt in dlet_nd twoq := mul_high (fill (sz * 2) mut) (fill (sz * 2) q1t) (fill (sz * 2) muSelect) in dlet_nd q3t := shiftr (sz+1) twoq 1 in - dlet_nd rt := r xt q3t in - final_reduce rt. + r xt q3t. Lemma fancy_reduce_muSelect_first_correct x : 0 <= x < M * 2^k -> @@ -5356,38 +5490,9 @@ Module BarrettReduction. Proof. intros. pose proof w_eq_22k. erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by - (eauto with nia; intros; try rewrite q3'_correct; eauto with nia ). + (eauto with nia; intros; try rewrite q3'_correct; try rewrite <-k_eq; eauto with nia ). reflexivity. Qed. - - - (* - Definition fancy_reduce xLow xHigh := - dlet_nd muSelect : list Z := Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut) in - dlet_nd q1t := [Z.rshi (2 ^ width) xHigh xLow (k - 1); Z.rshi (2 ^ width) 0 xHigh (k - 1)] in - dlet_nd a0b0 : list Z := widemul (low mut) (low q1t) in - dlet_nd ab : list Z := wideadd [high a0b0; high q1t] [low q1t; 0] in - dlet_nd a0 := wideadd ab [muSelect;0] in - dlet_nd q3t := [Z.rshi (2 ^ width) (nth_default 0 a0 1) (nth_default 0 a0 0) 1; Z.rshi (2 ^ width) (nth_default 0 a0 2) (nth_default 0 a0 1) 1] in - dlet_nd r2 : list Z := widemul (low q3t) Mt in - dlet_nd rt : list Z := widesub [xLow; xHigh] r2 in - dlet_nd maybe_M : list Z := Positional.select (Z.cc_l (nth_default 0 (high rt) 0)) (repeat 0 sz) Mt in - dlet_nd diff : list Z * Z := Rows.sub w sz (low rt) maybe_M in - dlet_nd rt0 :=fst diff in - Z.add_modulo (nth_default 0 rt0 0) 0 M. -*) - (* This simplifies the shiftr inside q3, specializing it to the # limbs so we don't end up with maps in the final expression *) - Derive q3' - SuchThat (forall xt q1t, q3' xt q1t = q3 xt q1t) - As q3'_correct. - Proof. - intros. cbv [q3]. rewrite sz_eq_1. autorewrite with natsimplify. - cbv [shiftr shiftr' seq]. break_match; try lia; [ ]. - rewrite Proper_Let_In_nd_changebody; [ | reflexivity | repeat intro ]. - 2 : apply Proper_Let_In_nd_changebody; - [ reflexivity | repeat intro; autorewrite with push_map push_nth_default; reflexivity ]. - subst q3'; reflexivity. - Qed. Derive fancy_reduce' SuchThat ( @@ -5399,7 +5504,7 @@ Module BarrettReduction. Proof. intros. assert (k = width) as width_eq_k by nia. erewrite <-fancy_reduce_muSelect_first_correct by nia. - cbv [fancy_reduce_muSelect_first q1 q3 shiftr final_reduce cond_subM]. + cbv [fancy_reduce_muSelect_first q1 q3 shiftr r cond_subM]. break_match; try solve [exfalso; lia]. match goal with |- ?g ?x = ?rhs => let f := (match (eval pattern x in rhs) with ?f _ => f end) in @@ -5407,7 +5512,6 @@ Module BarrettReduction. end. Qed. - (* TODO: Can't deal with input by shifting and stuff! *) Derive fancy_reduce SuchThat ( forall xLow xHigh, @@ -5433,31 +5537,11 @@ Module BarrettReduction. autorewrite with zsimplify. rewrite <-Z.mod_pull_div by Z.zero_bounds. autorewrite with zsimplify. reflexivity. } - cbv [fancy_reduce' reduce q1 q3 shiftr final_reduce cond_subM]. + cbv [fancy_reduce' reduce q1 q3 shiftr r cond_subM]. autorewrite with natsimplify. (* TODO: maybe need to get rid of hd 0? *) cbv [shiftr' seq]. autorewrite with push_map push_nth_default. subst fancy_reduce; reflexivity. Qed. - (* TODO: decide which version to use - Derive fancy_reduce - SuchThat ( - forall x, - 0 <= x < M * 2 ^ k -> - 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> - fancy_reduce (partition w (sz*2) x) = partition w sz (x mod M)) - As fancy_reduce_correct. - Proof. - intros. pose proof w_eq_22k. - erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by eauto with lia. - assert (width = k) by nia. - cbv [reduce q1 q3 shiftr final_reduce cond_subM]. - break_match; try solve [exfalso; lia]. - match goal with |- ?g ?x = ?rhs => - let f := (match (eval pattern x in rhs) with ?f _ => f end) in - assert (f = g); subst fancy_reduce; reflexivity - end. - Qed. - *) End Def. End Fancy. End Fancy. diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v index c0078b117..504a26a0c 100644 --- a/src/PushButtonSynthesis/BarrettReduction.v +++ b/src/PushButtonSynthesis/BarrettReduction.v @@ -177,6 +177,7 @@ Section rbarrett_red. | progress cbv [weight] | rewrite mut_correct | rewrite Mt_correct + | rewrite UniformWeight.uweight_eq_alt' | rewrite Z.pow_mul_r by lia ]. Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *) @@ -185,6 +186,7 @@ Section rbarrett_red. Proof using M curve_good. cbv [barrett_red_correct]; intros. assert (1 < machine_wordsize) by apply use_curve_good. + pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)). rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions. prove_correctness' ltac:(fun _ => idtac) use_curve_good. { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } -- cgit v1.2.3