From 71820cce3ba80acf0a09d7506c49ba2dd6e32d95 Mon Sep 17 00:00:00 2001 From: jadep Date: Thu, 14 Mar 2019 12:07:28 -0400 Subject: split up Arithmetic (imports etc. not yet fixed, does not build) --- src/Arithmetic/BarrettReduction.v | 609 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 609 insertions(+) create mode 100644 src/Arithmetic/BarrettReduction.v (limited to 'src/Arithmetic/BarrettReduction.v') diff --git a/src/Arithmetic/BarrettReduction.v b/src/Arithmetic/BarrettReduction.v new file mode 100644 index 000000000..c5cc1ecde --- /dev/null +++ b/src/Arithmetic/BarrettReduction.v @@ -0,0 +1,609 @@ +(* TODO: prune these *) +Require Import Crypto.Algebra.Nsatz. +Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. +Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. +Require Import Coq.Sorting.Permutation. +Require Import Coq.derive.Derive. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *) +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *) +Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. +Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. +Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Arithmetic.Partition. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.RunTacticAsConstr. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Sum. +Require Import Crypto.Util.Bool. 
+Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Tactics.PeelLe. +Require Import Crypto.Util.ZUtil.Tactics.LinearSubstitute. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.ZUtil.Modulo.PullPush. +Require Import Crypto.Util.ZUtil.Opp. +Require Import Crypto.Util.ZUtil.Log2. +Require Import Crypto.Util.ZUtil.Le. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Sorting. +Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.CPSNotations. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.SetEvars. +Import Coq.Lists.List ListNotations. Local Open Scope Z_scope. + +Section Generic. 
+  Context (b k M mu width : Z) (n : nat)
+          (b_ok : 1 < b)
+          (k_pos : 0 < k)
+          (bk_eq : b^k = 2^(width * Z.of_nat n))
+          (M_range : b ^ (k - 1) < M < b ^ k)
+          (mu_eq : mu = b ^ (2 * k) / M)
+          (width_pos : 0 < width)
+          (strong_bound : b ^ 1 * (b ^ (2 * k) mod M) <= b ^ (k + 1) - mu). (* Section parameters: base [b], exponent [k], modulus [M] with
+     [b^(k-1) < M < b^k], the precomputed quotient [mu = b^(2k) / M], and a limb layout of
+     [n] limbs of [width] bits with [b^k = 2^(width * n)].  [strong_bound] is presumably the
+     side condition required by [q_nice_strong] in [reduce_correct] -- confirm against
+     BarrettReduction.Generalized. *)
+  Local Notation weight := (UniformWeight.uweight width).
+  Local Notation partition := (Partition.partition weight).
+  Context (q1 : list Z -> list Z)
+          (q1_correct :
+             forall x,
+               0 <= x < b ^ (2 * k) ->
+               q1 (partition (n*2)%nat x) = partition (n+1)%nat (x / b ^ (k - 1)))
+          (q3 : list Z -> list Z -> list Z)
+          (q3_correct :
+             forall x q1,
+               0 <= x < b ^ (2 * k) ->
+               q1 = x / b ^ (k - 1) ->
+               q3 (partition (n*2) x) (partition (n+1) q1) = partition (n+1) ((mu * q1) / b ^ (k + 1)))
+          (r : list Z -> list Z -> list Z)
+          (r_correct :
+             forall x q3,
+               0 <= x < M * b ^ k ->
+               0 <= q3 ->
+               (exists b : bool, q3 = x / M + (if b then -1 else 0)) ->
+               r (partition (n*2) x) (partition (n+1) q3) = partition n (x mod M)). (* [q1], [q3] and [r] are abstract limb-level
+     implementations.  Each [_correct] hypothesis states what the function returns on
+     [partition]ed inputs: [q1] yields [x / b^(k-1)], [q3] yields [(mu * q1) / b^(k+1)], and
+     [r] yields [x mod M] when given a [q3] equal to [x / M] or [x / M - 1]. *)
+  (* The value being reduced, [x], and its [n * 2]-limb representation [xt]. *)
+  Context (x : Z) (x_range : 0 <= x < M * b ^ k)
+          (xt : list Z) (xt_correct : xt = partition (n*2) x).
+  (* Basic bounds on [M] and [x]; [M_pos], [x_upper] and [xmod_lt_M] are registered as
+     [Hint Resolve] below. *)
+  Local Lemma M_pos : 0 < M.
+  Proof. assert (0 <= b ^ (k - 1)) by Z.zero_bounds. lia. Qed.
+  Local Lemma M_upper : M < weight n.
+  Proof. rewrite UniformWeight.uweight_eq_alt'. lia. Qed.
+  Local Lemma x_upper : x < b ^ (2 * k).
+  Proof.
+    assert (0 < b ^ k) by Z.zero_bounds.
+    apply Z.lt_le_trans with (m:= M * b^k); [ lia | ].
+    transitivity (b^k * b^k); [ nia | ].
+    rewrite <-Z.pow_2_r, <-Z.pow_mul_r by lia.
+    rewrite (Z.mul_comm k 2); reflexivity.
+  Qed.
+  Local Lemma xmod_lt_M : x mod b ^ (k - 1) <= M. (* NOTE: despite the [lt] in the name, the statement is the non-strict bound [<=]. *)
+  Proof. pose proof (Z.mod_pos_bound x (b ^ (k - 1)) ltac:(Z.zero_bounds)). lia. Qed.
+  Local Hint Resolve M_pos x_upper xmod_lt_M.
+  (* Barrett reduction pipeline: apply [q1] to [xt], then [q3] to [xt] and that result, then
+     [r] to [xt] and the [q3] result.  Intermediates are bound with [dlet_nd], which the
+     proofs below unfold via [Let_In]. *)
+  Definition reduce :=
+    dlet_nd q1t := q1 xt in
+    dlet_nd q3t := q3 xt q1t in
+    r xt q3t.
+  (* The first quotient estimate [x / b^(k-1)] is nonnegative and below [b^(k+1)]. *)
+  Lemma q1_range : 0 <= x / b^(k-1) < b^(k+1).
+  Proof.
+    split; [ solve [Z.zero_bounds] | ].
+    assert (0 < b ^ (k - 1)) by Z.zero_bounds.
+    assert (0 < b ^ k) by Z.zero_bounds.
+    apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ].
+    eapply Z.lt_le_trans with (m:=b^k * b^k);
+      [ nia | autorewrite with pull_Zpow; apply Z.pow_le_mono; lia ].
+  Qed.
+  (* The second quotient estimate is nonnegative (the precondition [0 <= q3] of [r_correct]). *)
+  Lemma q3_range : 0 <= mu * (x / b ^ (k - 1)) / b ^ (k + 1).
+  Proof.
+    assert (0 < b ^ (k - 1)) by Z.zero_bounds.
+    subst mu; Z.zero_bounds.
+  Qed.
+  (* Main generic theorem: [reduce] returns the [n]-limb partition of [x mod M].  The proof
+     rewrites with the [q1]/[q3] specifications, uses [q_nice_strong] to show that the [q3]
+     value is [x / M] or [x / M - 1], and finishes with [r_correct]. *)
+  Lemma reduce_correct : reduce = partition n (x mod M).
+  Proof.
+    cbv [reduce Let_In]. pose proof q3_range.
+    rewrite xt_correct, q1_correct, q3_correct by auto with lia.
+    assert (exists cond : bool, ((mu * (x / b^(k-1))) / b^(k+1)) = x / M + (if cond then -1 else 0)) as Hq3.
+    { destruct q_nice_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) (a:=x) (n:=M) as [cond Hcond];
+        eauto using Z.lt_gt with zarith. }
+    eauto using r_correct with lia.
+  Qed.
+End Generic.
+
+(* Non-standard implementation -- uses specialized instructions and b=2 *)
+Module Fancy.
+  Section Fancy.
+    Context (M mu width k : Z)
+            (sz : nat) (sz_nz : sz <> 0%nat)
+            (width_ok : 1 < width)
+            (k_pos : 0 < k) (* this can be inferred from other arguments but is easier to put here for tactics *)
+            (k_eq : k = width * Z.of_nat sz).
+    (* sz = 1, width = k = 256 *)
+    Local Notation w := (UniformWeight.uweight width). Local Notation eval := (Positional.eval w).
+    Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M). (* [mut] is [mu] in [sz+1] limbs and [Mt] is [M] in [sz] limbs *)
+    Context (mu_eq : mu = 2 ^ (2 * k) / M) (muHigh_one : mu / w sz = 1) (M_range : 2^(k-1) < M < 2^k). (* [muHigh_one]: the part of [mu] above the low [sz] limbs equals 1 *)
+
+    Local Lemma wprops : @weight_properties w. Proof. apply UniformWeight.uwprops; auto with lia. Qed.
+    Local Hint Resolve wprops.
+    Hint Rewrite mut_correct Mt_correct : pull_partition.
+    (* Bounds relating [w sz], [2^k], [mu] and [M]. *)
+    Lemma w_eq_2k : w sz = 2^k. Proof. rewrite UniformWeight.uweight_eq_alt' by auto. congruence. Qed.
+    Lemma mu_range : 2^k <= mu < 2^(k+1).
+    Proof.
+      rewrite mu_eq. assert (0 < 2^(k-1)) by Z.zero_bounds.
+      assert (2^k < M * 2).
+      { replace (2^k) with (2^(k-1+1)) by (f_equal; lia).
+        rewrite Z.pow_add_r, Z.pow_1_r by lia.
+        lia. }
+      replace (2 ^ (2 * k)) with (2^(k+k)) by (f_equal; lia).
+      rewrite !Z.pow_add_r, Z.pow_1_r by lia. split.
+      { apply Z.div_le_lower_bound; nia. }
+      { apply Z.div_lt_upper_bound; nia. }
+    Qed.
+    Lemma mu_range' : 0 <= mu < 2 * w sz. (* [mu_range] restated in terms of [w sz] *)
+    Proof.
+      pose proof mu_range. assert (0 < 2^k) by auto with zarith.
+      assert (2^(k+1) = 2 * w sz); [ | lia].
+      rewrite k_eq, UniformWeight.uweight_eq_alt'.
+      rewrite Z.pow_add_r, Z.pow_1_r by lia. lia.
+    Qed.
+    Lemma M_range' : 0 <= M < w sz. (* more convenient form, especially for mod_small *)
+    Proof. assert (0 <= 2 ^ (k-1)) by Z.zero_bounds. pose proof w_eq_2k; lia. Qed.
+    (* [shiftr' m t n] produces [m] limbs; limb [i] is [Z.rshi (2^width)] applied to input limbs
+       [S i] and [i] of [t] (missing limbs default to 0) with shift amount [n].  Its
+       specification, [shiftr'_correct], requires [0 <= n < width]. *)
+    Definition shiftr' (m : nat) (t : list Z) (n : Z) : list Z :=
+      map (fun i => Z.rshi (2^width) (nth_default 0 t (S i)) (nth_default 0 t i) n) (seq 0 m).
+    (* [shiftr m t n] extends [shiftr'] to shift amounts of [width] or more by first dropping
+       [n / width] whole limbs and then shifting by [n mod width]. *)
+    Definition shiftr (m : nat) (t : list Z) (n : Z) : list Z :=
+      (* if width <= n, drop limbs first *)
+      if dec (width <= n)
+      then shiftr' m (skipn (Z.to_nat (n / width)) t) (n mod width)
+      else shiftr' m t n.
+    (* Double-width ([sz * 2]-limb) addition and subtraction, keeping only the first component
+       returned by [Rows.add] / [Rows.sub], and the multiplication used by this module. *)
+    Definition wideadd t1 t2 := fst (Rows.add w (sz*2) t1 t2).
+    Definition widesub t1 t2 := fst (Rows.sub w (sz*2) t1 t2).
+    Definition widemul := BaseConversion.widemul_inlined width sz 2.
+    (* widemul_inlined takes the following argument order : (width of limbs in input) (# limbs in input) (# parts to split each limb into before multiplying) *)
+    (* [fill n a] appends [Positional.zeros (n - length a)] to [a]; [low] and [high] take the
+       first [sz] limbs and the limbs after them.  [mul_high a b a0b1] combines
+       [widemul (low a) (low b)], [high b], [low b] and the caller-supplied term [a0b1]; its
+       specification is [mul_high_correct], which assumes [a / w sz = 1]. *)
+    Definition fill (n : nat) (a : list Z) := a ++ Positional.zeros (n - length a).
+    Definition low : list Z -> list Z := firstn sz.
+    Definition high : list Z -> list Z := skipn sz.
+    Definition mul_high (a b : list Z) a0b1 : list Z :=
+      dlet_nd a0b0 := widemul (low a) (low b) in
+      dlet_nd ab := wideadd (high a0b0 ++ high b) (fill (sz*2) (low b)) in
+      wideadd ab a0b1.
+
+    (* select based on the most significant bit of xHigh *)
+    Definition muSelect xt :=
+      let xHigh := nth_default 0 xt (sz*2 - 1) in
+      Positional.select (Z.cc_m (2 ^ width) xHigh) (Positional.zeros sz) (low mut).
+    (* [cond_sub a y]: subtracts from the low [sz] limbs of [a] either [y] or zeros, chosen by
+       [Z.cc_l] of the first limb of [high a].  See [cond_sub_correct] for the exact result. *)
+    Definition cond_sub (a y : list Z) : list Z :=
+      let cond := Z.cc_l (nth_default 0 (high a) 0) in (* a[k] = least significant bit of (high a) *)
+      dlet_nd maybe_y := Positional.select cond (Positional.zeros sz) y in
+      dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *)
+      fst diff.
+    (* [cond_subM x]: conditional subtraction of [M].  When [sz = 1] it is a single
+       [Z.add_modulo] on limb 0; otherwise it is [Rows.conditional_sub] against [Mt].  See
+       [cond_subM_correct] for the exact result. *)
+    Definition cond_subM x :=
+      if Nat.eq_dec sz 1
+      then [Z.add_modulo (nth_default 0 x 0) 0 M] (* use the special instruction if we can *)
+      else Rows.conditional_sub w sz x Mt.
+    (* First quotient estimate: [xt] shifted right by [k - 1] bits, returned in [sz+1] limbs. *)
+    Definition q1 (xt : list Z) := shiftr (sz+1) xt (k - 1).
+    (* Second quotient estimate: [mul_high] of [mut] and [q1t] (both padded to [sz * 2] limbs)
+       with the [muSelect] term, then shifted right by one bit into [sz+1] limbs. *)
+    Definition q3 (xt q1t : list Z) :=
+      dlet_nd muSelect := muSelect xt in (* make sure muSelect is not inlined in the output *)
+      dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in
+      shiftr (sz+1) twoq 1.
+    (* Remainder step: [xt] minus the wide product of [low q3t] and [Mt], followed by
+       [cond_sub] against [Mt] and then [cond_subM]. *)
+    Definition r (xt q3t : list Z) :=
+      dlet_nd r2 := widemul (low q3t) Mt in
+      dlet_nd rt := widesub xt r2 in
+      dlet_nd rt := cond_sub rt Mt in
+      cond_subM rt.
+    (* Correctness lemmas: each relates a limb-level operation to arithmetic on the integer
+       its arguments represent via [partition w]. *)
+    Section Proofs.
+      Lemma shiftr'_correct m n :
+        forall t tn,
+          (m <= tn)%nat -> 0 <= t < w tn -> 0 <= n < width ->
+          shiftr' m (partition w tn t) n = partition w m (t / 2 ^ n).
+      Proof. (* by induction on the number of output limbs [m] *)
+        cbv [shiftr']. induction m; intros; [ reflexivity | ].
+        rewrite !partition_step, seq_snoc.
+        autorewrite with distr_length natsimplify push_map push_nth_default.
+        rewrite IHm, Z.rshi_correct, UniformWeight.uweight_S by auto with zarith.
+        rewrite <-Z.mod_pull_div by auto with zarith.
+        destruct (Nat.eq_dec (S m) tn); [subst tn | ]; rewrite !nth_default_partition by omega.
+        { rewrite nth_default_out_of_bounds by distr_length.
+          autorewrite with zsimplify. Z.rewrite_mod_small.
+          rewrite Z.div_div_comm by auto with zarith; reflexivity. }
+        { repeat match goal with
+                 | _ => rewrite UniformWeight.uweight_pull_mod by auto with zarith
+                 | _ => rewrite Z.mod_mod_small by auto with zarith
+                 | _ => rewrite <-Znumtheory.Zmod_div_mod by (Z.zero_bounds; auto with zarith)
+                 | _ => rewrite UniformWeight.uweight_eq_alt with (n:=1%nat) by auto with zarith
+                 | |- context [(t / w (S m)) mod 2^width * 2^width] =>
+                   replace (t / w (S m)) with (t / w m / 2^width) by
+                       (rewrite UniformWeight.uweight_S, Z.div_div by auto with zarith; f_equal; lia);
+                     rewrite Z.mod_pull_div with (b:=2^width) by auto with zarith;
+                     rewrite Z.mul_div_eq' by auto with zarith
+                 | _ => progress autorewrite with natsimplify zsimplify_fast zsimplify
+                 end.
+          replace (2^width*2^width) with (2^width*2^(width-n)*2^n) by (autorewrite with pull_Zpow; f_equal; lia).
+          rewrite <-Z.mod_pull_div, <-Znumtheory.Zmod_div_mod by (Z.zero_bounds; auto with zarith).
+          rewrite Z.div_div_comm by Z.zero_bounds. reflexivity. }
+      Qed.
+      Lemma shiftr_correct m n :
+        forall t tn,
+          (Z.to_nat (n / width) <= tn)%nat ->
+          (m <= tn - Z.to_nat (n / width))%nat -> 0 <= t < w tn -> 0 <= n ->
+          shiftr m (partition w tn t) n = partition w m (t / 2 ^ n).
+      Proof. (* the [n < width] case is [shiftr'_correct]; otherwise whole limbs are dropped first *)
+        cbv [shiftr]; intros.
+        break_innermost_match; [ | solve [auto using shiftr'_correct with zarith] ].
+        pose proof (Z.mod_pos_bound n width ltac:(omega)).
+        assert (t / 2 ^ (n - n mod width) < w (tn - Z.to_nat (n / width))).
+        { apply Z.div_lt_upper_bound; [solve [Z.zero_bounds] | ].
+          rewrite UniformWeight.uweight_eq_alt' in *.
+          rewrite <-Z.pow_add_r, Nat2Z.inj_sub, Z2Nat.id, <-Z.mul_div_eq by auto with zarith.
+          autorewrite with push_Zmul zsimplify. auto with zarith. }
+        repeat match goal with
+               | _ => progress rewrite ?UniformWeight.uweight_skipn_partition, ?UniformWeight.uweight_eq_alt' by auto with lia
+               | _ => rewrite Z2Nat.id by Z.zero_bounds
+               | _ => rewrite Z.mul_div_eq_full by auto with zarith
+               | _ => rewrite shiftr'_correct by auto with zarith
+               | _ => progress rewrite ?Z.div_div, <-?Z.pow_add_r by auto with zarith
+               end.
+        autorewrite with zsimplify. reflexivity.
+      Qed.
+      Hint Rewrite shiftr_correct using (solve [auto with lia]) : pull_partition.
+
+      (* 2 ^ (k + 1) bits fit in sz + 1 limbs because we know 2^k bits fit in sz and 1 <= width *)
+      Lemma q1_correct x :
+        0 <= x < w (sz * 2) ->
+        q1 (partition w (sz*2)%nat x) = partition w (sz+1)%nat (x / 2 ^ (k - 1)).
+      Proof.
+        cbv [q1]; intros. assert (1 <= Z.of_nat sz) by (destruct sz; lia).
+        assert (Z.to_nat ((k-1) / width) < sz)%nat. {
+          subst k. rewrite <-Z.add_opp_r. autorewrite with zsimplify.
+          apply Nat2Z.inj_lt. rewrite Z2Nat.id by lia. lia. }
+        assert (0 <= k - 1) by nia.
+        autorewrite with pull_partition. reflexivity.
+      Qed.
+      (* [low], [high] and [fill] on partitions; all three are added to [pull_partition]. *)
+      Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a.
+      Proof. cbv [low]; auto using UniformWeight.uweight_firstn_partition with lia. Qed.
+      Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz).
+      Proof. cbv [high]. rewrite UniformWeight.uweight_skipn_partition by lia. f_equal; lia. Qed.
+      Lemma fill_correct n m a :
+        (n <= m)%nat ->
+        fill m (partition w n a) = partition w m (a mod w n).
+      Proof.
+        cbv [fill]; intros. distr_length.
+        rewrite <-partition_0 with (weight:=w).
+        rewrite UniformWeight.uweight_partition_app by lia.
+        f_equal; lia.
+      Qed.
+      Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition.
+      (* The wide operations compute the corresponding integer operation, re-partitioned into
+         [sz * 2] limbs; [widemul] first reduces both operands modulo [w sz]. *)
+      Lemma wideadd_correct a b :
+        wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b).
+      Proof.
+        cbv [wideadd]. rewrite Rows.add_partitions by (distr_length; auto).
+        autorewrite with push_eval.
+        apply partition_eq_mod; auto with zarith.
+      Qed.
+      Lemma widesub_correct a b :
+        widesub (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a - b).
+      Proof.
+        cbv [widesub]. rewrite Rows.sub_partitions by (distr_length; auto).
+        autorewrite with push_eval.
+        apply partition_eq_mod; auto with zarith.
+      Qed.
+      Lemma widemul_correct a b :
+        widemul (partition w sz a) (partition w sz b) = partition w (sz*2) ((a mod w sz) * (b mod w sz)).
+      Proof.
+        cbv [widemul]. rewrite BaseConversion.widemul_inlined_correct; (distr_length; auto).
+        autorewrite with push_eval. reflexivity.
+      Qed.
+      Hint Rewrite widemul_correct widesub_correct wideadd_correct using lia : pull_partition.
+      (* Pure integer identity behind [mul_high]: splitting [a] and [b] at [d] expresses
+         [(a * b) / d] through the four partial products. *)
+      Lemma mul_high_idea d a b a0 a1 b0 b1 :
+        d <> 0 ->
+        a = d * a1 + a0 ->
+        b = d * b1 + b0 ->
+        (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1.
+      Proof.
+        intros. subst a b. autorewrite with push_Zmul.
+        ring_simplify_subterms. rewrite Z.pow_2_r.
+        rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
+        repeat match goal with
+               | |- context [d * ?a * ?b * ?c] =>
+                 replace (d * a * b * c) with (a * b * c * d) by ring
+               | |- context [d * ?a * ?b] =>
+                 replace (d * a * b) with (a * b * d) by ring
+               end.
+        rewrite !Z.div_add by omega.
+        autorewrite with zsimplify.
+        rewrite (Z.mul_comm a0 b0).
+        ring_simplify. ring.
+      Qed.
+      (* [mul_high] computes [a * b / w sz], provided the high part of [a] is 1 and [a0b1] is
+         the cross term [a mod w sz * (b / w sz)]. *)
+      Lemma mul_high_correct a b
+            (Ha : a / w sz = 1)
+            a0b1 (Ha0b1 : a0b1 = a mod w sz * (b / w sz)) :
+        mul_high (partition w (sz*2) a) (partition w (sz*2) b) (partition w (sz*2) a0b1) =
+        partition w (sz*2) (a * b / w sz).
+      Proof.
+        cbv [mul_high Let_In].
+        erewrite mul_high_idea by auto using Z.div_mod with zarith.
+        repeat match goal with
+               | _ => progress autorewrite with pull_partition
+               | _ => progress rewrite ?Ha, ?Ha0b1
+               | _ => rewrite UniformWeight.uweight_partition_app by lia;
+                        replace (sz+sz)%nat with (sz*2)%nat by lia
+               | _ => rewrite Z.mod_pull_div by auto with zarith
+               | _ => progress Z.rewrite_mod_small
+               | _ => f_equal; ring
+               end.
+      Qed.
+      (* Rewrite databases built from [uweight_S] and [uweight_eq_alt'], used below as
+         [weight_to_pow] to restate goals about [w] with explicit powers. *)
+      Hint Rewrite UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : weight_to_pow.
+      Hint Rewrite <-UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : pow_to_weight.
+      (* The first quotient estimate is nonnegative and below [2 * w sz]. *)
+      Lemma q1_range x :
+        0 <= x < w (sz * 2) ->
+        0 <= x / 2 ^ (k-1) < 2 * w sz.
+      Proof.
+        intros; split; [ solve [Z.zero_bounds] | ].
+        apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ].
+        assert (w (sz * 2) <= 2 ^ (k-1) * (2 * w sz)); [ | lia ].
+        autorewrite with weight_to_pow pull_Zpow.
+        apply Z.pow_le_mono_r; lia.
+      Qed.
+
+      (* use zero_bounds in zutil_arith *)
+      Local Ltac zutil_arith ::= solve [ omega | Psatz.lia | auto with nocore | solve [Z.zero_bounds] ].
+      (* [muSelect] yields [mu mod w sz] multiplied by [x / 2^(k-1) / w sz]; by [q1_range] that
+         factor is 0 or 1, so the result is either zero or [mu mod w sz]. *)
+      Lemma muSelect_correct x :
+        0 <= x < w (sz * 2) ->
+        muSelect (partition w (sz*2) x) = partition w sz (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))).
+      Proof.
+        cbv [muSelect]; intros;
+          repeat match goal with
+                 | _ => progress autorewrite with pull_partition natsimplify
+                 | _ => progress rewrite ?Z.cc_m_eq by auto with zarith
+                 | _ => erewrite Positional.select_eq by (distr_length; eauto)
+                 | _ => rewrite nth_default_partition by lia
+                 | _ => progress replace (S (sz * 2 - 1)) with (sz * 2)%nat by lia
+                 | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by auto with zarith
+                 end.
+        replace (x / (w (sz * 2 - 1)) / (2 ^ width / 2)) with (x / (2 ^ (k - 1)) / w sz) by
+            (autorewrite with weight_to_pow pull_Zpow pull_Zdiv; do 2 f_equal; nia).
+        rewrite Z.div_between_0_if with (a:=x / 2^(k-1)) by (Z.zero_bounds; auto using q1_range).
+        break_innermost_match; try lia; autorewrite with zsimplify_fast; [ | ].
+        { apply partition_eq_mod; auto with zarith. }
+        { rewrite partition_0; reflexivity. }
+      Qed.
+      Hint Rewrite muSelect_correct using lia : pull_partition.
+      (* Upper bound on the product [mu * q1]; used in [q3_correct]. *)
+      Lemma mu_q1_range x (Hx : 0 <= x < w (sz * 2)) : mu * (x / 2^(k-1)) < w sz * w (sz * 2).
+      Proof.
+        pose proof mu_range'. pose proof q1_range x ltac:(lia).
+        replace (w (sz * 2)) with (w sz * w sz) by
+            (autorewrite with weight_to_pow pull_Zpow; f_equal; lia).
+        apply Z.lt_le_trans with (m:= 2 * w sz * (2 * w sz)); [ nia | ].
+        assert (4 <= w sz); [ | nia ]. change 4 with (Z.pow 2 2).
+        autorewrite with weight_to_pow. apply Z.pow_le_mono_r; nia.
+      Qed.
+      (* [q3] computes [(mu * q1) / 2^(k+1)] in [sz+1] limbs when [q1 = x / 2^(k-1)]. *)
+      Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) :
+        q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)).
+      Proof.
+        cbv [q3 Let_In]. intros. pose proof mu_q1_range x ltac:(lia).
+        pose proof mu_range'. pose proof q1_range x ltac:(lia).
+        autorewrite with pull_partition pull_Zmod.
+        assert (2 * w sz < w (sz + 1)) by (autorewrite with weight_to_pow pull_Zpow; auto with zarith lia).
+        Z.rewrite_mod_small. rewrite <-Hq1 in *.
+        rewrite mul_high_correct by
+            (try lia; rewrite Z.div_between_0_if with (a:=q1) by lia;
+             break_innermost_match; autorewrite with zsimplify; reflexivity).
+        rewrite shiftr_correct by (rewrite ?Z.div_small, ?Z2Nat.inj_0 by lia; auto with zarith lia).
+        autorewrite with weight_to_pow pull_Zpow pull_Zdiv.
+        congruence.
+      Qed.
+      (* [cond_sub] returns [a] when bit [w sz] of [a] is clear and [a - b] otherwise, in both
+         cases partitioned into [sz] limbs. *)
+      Lemma cond_sub_correct a b :
+        cond_sub (partition w (sz*2) a) (partition w sz b)
+        = partition w sz (if dec ((a / w sz) mod 2 = 0)
+                          then a
+                          else a - b).
+      Proof.
+        intros; cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition.
+        rewrite nth_default_partition by lia.
+        rewrite weight_0 by auto. autorewrite with zsimplify_fast.
+        rewrite UniformWeight.uweight_eq_alt' with (n:=1%nat). autorewrite with push_Zof_nat zsimplify.
+        rewrite <-Znumtheory.Zmod_div_mod by auto using Zpow_facts.Zpower_divide with zarith.
+        rewrite Positional.select_eq with (n:=sz) by (distr_length; apply w).
+        rewrite Rows.sub_partitions by (break_innermost_match; distr_length; auto).
+        break_innermost_match; autorewrite with push_eval zsimplify_fast;
+          apply partition_eq_mod; auto with zarith.
+      Qed.
+      Hint Rewrite cond_sub_correct : pull_partition.
+      Lemma cond_subM_correct a :
+        cond_subM (partition w sz a)
+        = partition w sz (if dec (a mod w sz < M)
+                          then a
+                          else a - M).
+      Proof. (* covers both branches of [cond_subM]: the [sz = 1] [Z.add_modulo] case and the general one *)
+        cbv [cond_subM]. autorewrite with pull_partition. pose proof M_range'.
+        rewrite Rows.conditional_sub_partitions by
+            (distr_length; auto; autorewrite with push_eval; try apply partition_eq_mod; auto with zarith).
+        rewrite nth_default_partition, weight_0, Z.add_modulo_correct by auto with lia.
+        autorewrite with zsimplify_fast push_eval. Z.rewrite_mod_small.
+        pose proof Z.mod_pos_bound a (w 1) ltac:(auto).
+        break_innermost_match; Z.ltb_to_lt;
+          repeat match goal with
+                 | _ => lia
+                 | _ => reflexivity
+                 | _ => apply partition_eq_mod; solve [auto with zarith]
+                 | _ => rewrite partition_step, weight_0 by auto
+                 | _ => progress autorewrite with zsimplify_fast
+                 | _ => progress Z.rewrite_mod_small
+                 | _ => rewrite Z.sub_mod_l with (a:=a)
+                 end.
+      Qed.
+      Hint Rewrite cond_subM_correct : pull_partition.
+      (* The double-width weight equals [2^(2k)]. *)
+      Lemma w_eq_22k : w (sz * 2) = 2 ^ (2 * k).
+      Proof.
+        replace (sz * 2)%nat with (sz + sz)%nat by lia.
+        rewrite UniformWeight.uweight_sum_indices, w_eq_2k, <-Z.pow_add_r by lia.
+        f_equal; lia.
+      Qed.
+      (* Integer-level core of [r_correct]: when [q3] is [x / M] or [x / M - 1], subtracting
+         [(q3 mod w sz) * M] from [x] leaves [x mod M] or [x mod M + M] respectively. *)
+      Lemma r_idea x q3 (b:bool) :
+        0 <= x < M * 2 ^ k ->
+        0 <= q3 ->
+        q3 = x / M + (if b then -1 else 0) ->
+        x - q3 mod w sz * M = x mod M + (if b then M else 0).
+      Proof.
+        intros. assert (0 < 2^(k-1)) by Z.zero_bounds.
+        assert (q3 < w sz).
+        { apply Z.le_lt_trans with (m:=x/M); [ subst q3; break_innermost_match; lia | ].
+          autorewrite with weight_to_pow. rewrite <-k_eq. auto with zarith. }
+        Z.rewrite_mod_small.
+        repeat match goal with
+               | _ => progress autorewrite with push_Zmul
+               | H : q3 = ?e |- _ => progress replace (q3 * M) with (e * M) by (rewrite H; reflexivity)
+               | _ => rewrite (Z.mul_div_eq' x M) by lia
+               end.
+        break_innermost_match; Z.ltb_to_lt; lia.
+      Qed.
+      (* [r] returns the [sz]-limb partition of [x mod M]; same shape as the [r_correct]
+         hypothesis of the generic section with [b = 2]. *)
+      Lemma r_correct x q3 :
+        0 <= x < M * 2 ^ k ->
+        0 <= q3 ->
+        (exists b : bool, q3 = x / M + (if b then -1 else 0)) ->
+        r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w sz (x mod M).
+      Proof.
+        intros; cbv [r Let_In]. pose proof M_range'. assert (0 < 2^(k-1)) by Z.zero_bounds.
+        autorewrite with pull_partition. Z.rewrite_mod_small.
+        match goal with H : exists _, q3 = _ |- _ => destruct H end.
+        erewrite r_idea by eassumption.
+        pose proof (Z.mod_pos_bound x M ltac:(lia)).
+        rewrite Z.div_between_0_if with (b:=w sz) by (break_innermost_match; auto with zarith).
+        rewrite Z.mod_small with (b:=2) by (break_innermost_match; lia).
+        break_innermost_match; Z.ltb_to_lt; try lia; autorewrite with zsimplify_fast;
+          repeat match goal with
+                 | |- exists e, _ /\ _ /\ ?f ?x = ?f e => exists x; split; [ | split ]
+                 | _ => rewrite Z.mod_small in * by lia
+                 | _ => progress Z.rewrite_mod_small
+                 | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast
+                 | _ => lia
+                 | _ => reflexivity
+                 end.
+      Qed.
+    End Proofs.
+    (* Concrete single-limb ([sz = 1]) instantiation of the reduction. *)
+    Section Def.
+      Context (sz_eq_1 : sz = 1%nat). (* this is needed to get rid of branches in the templates; a different definition would be needed for sizes other than 1, but would be able to use the same proofs. *)
+      Local Hint Resolve q1_correct q3_correct r_correct.
+
+      (* muselect relies on an initially-set flag, so pull it out of q3 *)
+      Definition fancy_reduce_muSelect_first xt :=
+        dlet_nd muSelect := muSelect xt in
+        dlet_nd q1t := q1 xt in
+        dlet_nd twoq := mul_high (fill (sz * 2) mut) (fill (sz * 2) q1t) (fill (sz * 2) muSelect) in
+        dlet_nd q3t := shiftr (sz+1) twoq 1 in
+        r xt q3t.
+      (* [fancy_reduce_muSelect_first] computes [x mod M] in [sz] limbs, obtained by rewriting
+         with the generic [reduce_correct] instantiated at [b := 2].
+         NOTE(review): [q3'_correct] is not defined in the visible part of this file; since the
+         rewrite is wrapped in [try] it is presumably a no-op left over from an earlier
+         version -- confirm and remove. *)
+      Lemma fancy_reduce_muSelect_first_correct x :
+        0 <= x < M * 2^k ->
+        2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+        fancy_reduce_muSelect_first (partition w (sz*2) x) = partition w sz (x mod M).
+      Proof.
+        intros. pose proof w_eq_22k.
+        erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by
+            (eauto with nia; intros; try rewrite q3'_correct; try rewrite <-k_eq; eauto with nia ).
+        reflexivity.
+      Qed.
+      (* [fancy_reduce'] is derived from its specification rather than written by hand: the
+         proof rewrites the goal to [fancy_reduce_muSelect_first], unfolds the pipeline with
+         [cbv], eliminates the [match]es with [break_match], and then instantiates
+         [fancy_reduce'] with the function obtained by abstracting the input out of the
+         resulting term. *)
+      Derive fancy_reduce'
+             SuchThat (
+               forall x,
+                 0 <= x < M * 2^k ->
+                 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+                 fancy_reduce' (partition w (sz*2) x) = partition w sz (x mod M))
+             As fancy_reduce'_correct.
+      Proof.
+        intros. assert (k = width) as width_eq_k by nia.
+        erewrite <-fancy_reduce_muSelect_first_correct by nia.
+        cbv [fancy_reduce_muSelect_first q1 q3 shiftr r cond_subM].
+        break_match; try solve [exfalso; lia].
+        match goal with |- ?g ?x = ?rhs =>
+                        let f := (match (eval pattern x in rhs) with ?f _ => f end) in
+                        assert (f = g); subst fancy_reduce'; reflexivity
+        end.
+      Qed.
+      (* Two-limb entry point: runs [fancy_reduce'] on the list [xLow; xHigh] and returns the
+         head of the result (0 if the result is empty). *)
+      Definition fancy_reduce xLow xHigh := hd 0 (fancy_reduce' [xLow;xHigh]).
+      (* With [sz = 1], the two-limb partition of [xLow + 2^k * xHigh] is exactly
+         [xLow; xHigh] for in-range limbs. *)
+      Lemma partition_2 xLow xHigh :
+        0 <= xLow < 2 ^ k ->
+        0 <= xHigh < M ->
+        partition w 2 (xLow + 2^k * xHigh) = [xLow;xHigh].
+      Proof.
+        replace k with width in M_range |- * by nia; intros. cbv [partition map seq].
+        rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia.
+        autorewrite with zsimplify.
+        rewrite <-Z.mod_pull_div by Z.zero_bounds.
+        autorewrite with zsimplify. reflexivity.
+      Qed.
+      (* Top-level correctness: [fancy_reduce xLow xHigh] equals [(xLow + 2^k * xHigh) mod M],
+         under the same extra bound on [mu] that the [fancy_reduce'] specification assumes. *)
+      Lemma fancy_reduce_correct xLow xHigh :
+        0 <= xLow < 2 ^ k ->
+        0 <= xHigh < M ->
+        2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+        fancy_reduce xLow xHigh = (xLow + 2^k * xHigh) mod M.
+      Proof.
+        assert (M < 2^width) by (replace width with k by nia; lia).
+        assert (0 < 2 ^ (k - 1)) by Z.zero_bounds.
+        pose proof (Z.mod_pos_bound (xLow + 2^k * xHigh) M ltac:(lia)).
+        intros. cbv [fancy_reduce]. rewrite <-partition_2 by lia.
+        replace 2%nat with (sz*2)%nat by lia.
+        rewrite fancy_reduce'_correct by nia.
+        rewrite sz_eq_1; cbv [partition map seq hd].
+        rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia.
+        autorewrite with zsimplify. reflexivity.
+      Qed.
+    End Def.
+  End Fancy.
+End Fancy.
\ No newline at end of file
-- cgit v1.2.3