From 3e4edb9a9b8cc15bdc02b9005e0b94561645b77b Mon Sep 17 00:00:00 2001 From: jadep Date: Thu, 7 Mar 2019 15:18:42 -0500 Subject: Get new Barrett proofs to generate Fancy code as before --- src/Arithmetic.v | 1315 ++++++++++---------- src/Arithmetic/BarrettReduction/Generalized.v | 9 + src/Fancy/Barrett256.v | 21 +- src/PushButtonSynthesis/BarrettReduction.v | 97 +- .../BarrettReductionReificationCache.v | 6 +- src/Rewriter.v | 20 + 6 files changed, 762 insertions(+), 706 deletions(-) (limited to 'src') diff --git a/src/Arithmetic.v b/src/Arithmetic.v index 1a25532f3..65eff76a4 100644 --- a/src/Arithmetic.v +++ b/src/Arithmetic.v @@ -56,6 +56,8 @@ Require Import Crypto.Util.Equality. Require Import Crypto.Util.Tactics.SetEvars. Import Coq.Lists.List ListNotations. Local Open Scope Z_scope. +Hint Rewrite Nat.add_1_r : natsimplify. (* TODO : put in a better location *) + Module Associational. Definition eval (p:list (Z*Z)) : Z := fold_right (fun x y => x + y) 0%Z (map (fun t => fst t * snd t) p). @@ -1831,6 +1833,15 @@ Module Partition. partition n x = partition n y. Proof. apply partition_Proper. Qed. + Lemma nth_default_partition d n x i : + (i < n)%nat -> + nth_default d (partition n x) i = x mod weight (S i) / weight i. + Proof. + cbv [partition]; intros. + rewrite map_nth_default with (x:=0%nat) by distr_length. + autorewrite with push_nth_default natsimplify. reflexivity. + Qed. + Fixpoint recursive_partition n i x := match n with | O => [] @@ -1877,6 +1888,7 @@ Module Partition. rewrite Nat.min_l by omega. reflexivity. Qed. + End Partition. Hint Rewrite length_partition length_recursive_partition : distr_length. Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval. @@ -2498,7 +2510,7 @@ Module Rows. Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := fold_right (fun next_row (state : list Z * Z)=> - let out_carry := sum_rows next_row (fst state) in + let out_carry := sum_rows (fst state) next_row in (fst out_carry, snd state + snd out_carry)) start_state inp. (* In order for the output to have the right length and bounds, @@ -2508,10 +2520,10 @@ Module Rows. flatten' (hd default inp, 0) (hd default (tl inp) :: tl (tl inp)). Lemma flatten'_cons state r inp : - flatten' state (r :: inp) = (fst (sum_rows r (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows r (fst (flatten' state inp)))). + flatten' state (r :: inp) = (fst (sum_rows (fst (flatten' state inp)) r), snd (flatten' state inp) + snd (sum_rows (fst (flatten' state inp)) r)). Proof using Type. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed. Lemma flatten'_snoc state r inp : - flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp. + flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows (fst state) r), snd state + snd (sum_rows (fst state) r)) inp. Proof using Type. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed. Lemma flatten'_nil state : flatten' state [] = state. Proof using Type. reflexivity. Qed. Hint Rewrite flatten'_cons flatten'_snoc flatten'_nil : push_flatten. @@ -3004,52 +3016,57 @@ Module BaseConversion. (* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *) Section widemul. Context (log2base : Z) (log2base_pos : 0 < log2base). - Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base) - (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *) + Context (m n : nat) (m_nz : m <> 0%nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base). Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1. Let sw : nat -> Z := weight log2base 1. + Let mn := (m * n)%nat. + Let nout := (m * 2)%nat. + Local Lemma mn_nonzero : mn <> 0%nat. Proof. subst mn. apply Nat.neq_mul_0. auto. Qed. + Local Hint Resolve mn_nonzero. + Local Lemma nout_nonzero : nout <> 0%nat. Proof. subst nout. apply Nat.neq_mul_0. auto. Qed. + Local Hint Resolve nout_nonzero. Local Lemma base_bounds : 0 < 1 <= log2base. Proof using log2base_pos. clear -log2base_pos; auto with zarith. Qed. Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof using n_nz n_le_log2base. clear -n_nz n_le_log2base; auto with zarith. Qed. Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds. Let swprops : @weight_properties sw := wprops log2base 1 base_bounds. + Local Notation deval := (Positional.eval dw). + Local Notation seval := (Positional.eval sw). Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg. - Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b]. + + (* TODO: Unspecialize from 2-limb *) + (* multi-limb version -- multiply two numbers that each have m limbs with (n*k) bits each, by converting them to numbers with (m*n) limbs of k bits each, multiplying, then converting back *) + Definition widemul a b := mul_converted sw dw m m mn mn nout (aligned_carries n nout) a b. Lemma widemul_correct a b : - 0 <= a * b < 2^log2base * 2^log2base -> - widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]. - Proof using dwprops swprops nout_2. - cbv [widemul]; intros. - rewrite mul_converted_partitions by auto with zarith. - subst nout. - unfold sw in *; cbv [weight]; cbn. - autorewrite with zsimplify. - rewrite Z.pow_mul_r, Z.pow_2_r by omega. - Z.rewrite_mod_small. reflexivity. - Qed. + length a = m -> + length b = m -> + widemul a b = Partition.partition sw nout (seval m a * seval m b). + Proof. apply mul_converted_partitions; auto with zarith. Qed. Derive widemul_inlined SuchThat (forall a b, - 0 <= a * b < 2^log2base * 2^log2base -> - widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) + length a = m -> + length b = m -> + widemul_inlined a b = Partition.partition sw nout (seval m a * seval m b)) As widemul_inlined_correct. Proof. intros. rewrite <-widemul_correct by auto. cbv beta iota delta [widemul mul_converted]. - rewrite <-to_associational_inlined_correct with (p:=[a]). - rewrite <-to_associational_inlined_correct with (p:=[b]). + rewrite <-to_associational_inlined_correct with (p:=a). + rewrite <-to_associational_inlined_correct with (p:=b). rewrite <-from_associational_inlined_correct. subst widemul_inlined; reflexivity. Qed. Derive widemul_inlined_reverse SuchThat (forall a b, - 0 <= a * b < 2^log2base * 2^log2base -> - widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base]) + length a = m -> + length b = m -> + widemul_inlined_reverse a b = Partition.partition sw nout (seval m a * seval m b)) As widemul_inlined_reverse_correct. Proof. intros. @@ -3060,10 +3077,10 @@ Module BaseConversion. [ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *) end. { subst widemul_inlined_reverse; reflexivity. } - { rewrite from_associational_inlined_correct by (subst nout; auto). + { rewrite from_associational_inlined_correct by auto. cbv [from_associational]. rewrite !Rows.flatten_correct by eauto using Rows.length_from_associational. - rewrite !Rows.eval_from_associational by (subst nout; auto). + rewrite !Rows.eval_from_associational by auto. f_equal. rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto. reflexivity. } @@ -3504,7 +3521,11 @@ Module UniformWeight. { transitivity (2 ^ 1); [ reflexivity | ]. apply Z.pow_le_mono_r; omega. } Qed. - + Lemma uweight_sum_indices lgr (Hr : 0 <= lgr) i j : uweight lgr (i + j) = uweight lgr i * uweight lgr j. + Proof. + rewrite !uweight_eq_alt by lia. + rewrite Nat2Z.inj_add; auto using Z.pow_add_r with zarith. + Qed. Lemma uweight_1 lgr : uweight lgr 1 = 2^lgr. Proof using Type. cbv [uweight weight]. @@ -3535,6 +3556,20 @@ Module UniformWeight. auto using uweight_recursive_partition_change_start with omega. Qed. + Lemma uweight_skipn_partition lgr (Hr : 0 < lgr) n x m : + skipn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) (n - m) (x / uweight lgr m). + Proof. + cbv [Partition.partition]; + repeat match goal with + | _ => progress intros + | _ => progress autorewrite with push_skipn natsimplify zsimplify_fast + | _ => rewrite skipn_seq by auto + | _ => rewrite weight_0 by auto using uwprops + | _ => rewrite Partition.recursive_partition_equiv' by auto using uwprops + | _ => auto using uweight_recursive_partition_change_start with zarith + end. + Qed. + Lemma uweight_partition_unique lgr (Hr : 0 < lgr) n ls : length ls = n -> (forall x, List.In x ls -> 0 <= x <= 2^lgr - 1) -> ls = Partition.partition (uweight lgr) n (Positional.eval (uweight lgr) n ls). @@ -3565,6 +3600,29 @@ Module UniformWeight. | [ |- ?x :: _ = ?x :: _ ] => apply f_equal end ]. Qed. + + Lemma uweight_eval_app' lgr (Hr : 0 <= lgr) n x y : + n = length x -> + Positional.eval (uweight lgr) (n + length y) (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y. + Proof using Type. + induction y using rev_ind; + repeat match goal with + | _ => progress intros + | _ => progress distr_length + | _ => progress autorewrite with push_eval zsimplify natsimplify + | _ => rewrite Nat.add_succ_r + | H : ?x = 0%nat |- _ => subst x + | _ => progress rewrite ?app_nil_r, ?app_assoc + | _ => reflexivity + end. + rewrite IHy by auto. rewrite uweight_sum_indices; lia. + Qed. + + Lemma uweight_eval_app lgr (Hr : 0 <= lgr) n m x y : + n = length x -> + m = (n + length y)%nat -> + Positional.eval (uweight lgr) m (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y. + Proof using Type. intros. subst m. apply uweight_eval_app'; lia. Qed. End UniformWeight. Module WordByWordMontgomery. @@ -4821,627 +4879,588 @@ Module WordByWordMontgomery. End WordByWordMontgomery. Hint Rewrite Z2Nat.inj_succ using omega : push_Zto_nat. (* TODO: put in PullPush *) -Hint Rewrite Nat.add_1_r : natsimplify. (* TODO : put in a better location *) Module BarrettReduction. - Locate barrett_reduction_small. - Check barrett_reduction_small. - Locate Associational. - Section Generic. - Context (k : Z) (k_pos : 0 < k) (b : Z) (b_ok : 2 < b). - Local Notation weight := (UniformWeight.uweight b). - Local Notation rep := (list Z). - - Definition eval (x : list Z) := Positional.eval weight (length x) x. - Definition valid (x : list Z) := - x = Partition.partition weight (length x) (eval x). - - (* TODO: make a lemma sort of like this (actually, probably doesn't have to be specialized to uniform weights) to allow splitting [app] in a positional evaluation *) - Lemma UniformWeight_eval_app : - forall b y n x, - length (x ++ y) = n -> - Positional.eval (UniformWeight.uweight b) n (x ++ y) - = Positional.eval (UniformWeight.uweight b) (length x) x - + UniformWeight.uweight b (length x) * Positional.eval (UniformWeight.uweight b) (n - length x) y. - Proof. - induction y using rev_ind; - repeat match goal with - | _ => progress (intros; subst) - | _ => progress distr_length; autorewrite with natsimplify in * - | |- context [(?a + ?b - ?a)%nat] => replace (a + b - a)%nat with b by lia - | _ => progress rewrite ?app_nil_r, ?app_assoc, ?Nat.add_succ_r - | _ => erewrite IHy by distr_length - | _ => progress autorewrite with natsimplify push_eval zsimplify_fast - | _ => reflexivity - end. - ring_simplify. - (* TODO: need to say here that uweight (i + j) = uweight i * uweight j *) - Admitted. - - Definition low : rep -> rep := firstn (Z.to_nat k). - Lemma low_correct : forall x, valid x -> eval (low x) = eval x mod b ^ k. - Proof. - cbv [eval low]; intros. - match goal with |- context [firstn ?n ?l] => - rewrite <-(firstn_skipn n l); - replace (firstn n (firstn n l ++ skipn n l)) with (firstn n l) - by (rewrite firstn_skipn; reflexivity) - end. - erewrite UniformWeight_eval_app by distr_length. - distr_length. - apply Nat.min_case_strong; intros. - { (* weight k = b ^ k, so second term goes away by mod *) - admit. } - { autorewrite with push_skipn push_eval zsimplify. - SearchAbout Partition.partition. - (* annoying, how do I prove valid -> mod_small? *) - (* might need to use recursive_paritition? *) - (* - if length x <= k: - skipn k x = [] - second term goes away by eval_nil - *) - distr_length. - 2:distr_length. - repeat match goal with - | |- context [Nat.min ?x (S (length ?p))] => - destruct (lt_dec (length p) x); - [ rewrite (Nat.min_r x) by lia | rewrite (Nat.min_l x) by lia ] - end. - - Print UniformWeight. - - apply natlike_ind with (x:=k); try omega; destruct x using rev_ind. - { autorewrite with push_eval zsimplify. reflexivity. } - { autorewrite with push_eval zsimplify. reflexivity. } - { intros; autorewrite with push_eval zsimplify distr_length push_firstn. reflexivity. } - { intros. clear IHx. - autorewrite with push_Zto_nat distr_length natsimplify push_firstn in *. - repeat match goal with - | |- context [Nat.min ?x (S (length ?p))] => - destruct (lt_dec (length p) x); - [ rewrite (Nat.min_r x) by lia | rewrite (Nat.min_l x) by lia ] - | |- context [firstn (?a - ?b)%nat [_] ] => - destruct (le_dec a b); [ rewrite (not_le_minus_0 a b) by lia - | rewrite (firstn_all2 (n:=(a-b)%nat)) by (distr_length; lia) ]; - try lia - end. - { rewrite firstn_all2 by lia. - autorewrite with push_eval in *. - rewrite Z.mod_small; [ reflexivity | ]. - (* since length x0 < S x1, it must be that length x0 <= x1 and therefore - [weight (length x0) <= b ^x1] *) - (* since length x0 < S x1, it must be that length x0 <= x1 and therefore - [weight (length x0) <= b ^x1] *) - - { admit. } - { SearchAbout (_ - _ = 0)%nat. - SearchAbout (S _ - _)%nat. - rewrite not_le_minus_0. - 2:lia. - - - match goal with |- context [firstn ?n ?l] => - rewrite <-(firstn_skipn n l); - replace (firstn n (firstn n l ++ skipn n l)) with (firstn n l) by (rewrite firstn_skipn; reflexivity) - end. - distr_length. - autorewrite with push_Zto_nat push_firstn push_skipn. - Print Rewrite HintDb push_skipn. - SearchAbout Nat.min S. - repeat apply Nat.min_case_strong; intros; try lia. - { SearchAbout Positional.eval app. - - match goal with - | |- context [(?a + (?b - ?a))%nat] => replace (a + (b - a))%nat with b by lia - end. - - { - Focus 9. - app - distr_length. - SearchAbout firstn nil. - - rewrite <-firstn - autorewrite with distr_length. - SearchAbout Nat.min. - apply Nat.min_case_strong; intros. - 2:autorewrite with push_firstn. - SearchAbout Positional.eval. - SearchAbout firstn. - Admitted. - Print Positional. - - - (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) + Import Partition. Section Generic. - Context {T} (rep : T -> Z -> Prop) - (k : Z) (k_pos : 0 < k) - (low : T -> Z) - (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) - (shiftr : T -> Z -> T) - (shiftr_correct : forall a x n, - rep a x -> - 0 <= n <= k -> - rep (shiftr a n) (x / 2 ^ n)) - (mul_high : T -> T -> Z -> T) - (mul_high_correct : forall a b x y x0y1, - rep a x -> - rep b y -> - 2 ^ k <= x < 2^(k+1) -> - 0 <= y < 2^(k+1) -> - x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> - rep (mul_high a b x0y1) (x * y / 2 ^ k)) - (mul : Z -> Z -> T) - (mul_correct : forall x y, - 0 <= x < 2^k -> - 0 <= y < 2^k -> - rep (mul x y) (x * y)) - (sub : T -> T -> T) - (sub_correct : forall a b x y, - rep a x -> - rep b y -> - 0 <= x - y < 2^k * 2^k -> - rep (sub a b) (x - y)) - (cond_sub1 : T -> Z -> Z) - (cond_sub1_correct : forall a x y, - rep a x -> - 0 <= x < 2 * y -> - 0 <= y < 2 ^ k -> - cond_sub1 a y = if (x Z -> Z) - (cond_sub2_correct : forall x y, cond_sub2 x y = if (x list Z) + (q1_correct : + forall x, + 0 <= x < b ^ (2 * k) -> + q1 (partition (n*2)%nat x) = partition (n+1)%nat (x / b ^ (k - 1))) + (q3 : list Z -> list Z -> list Z) + (q3_correct : + forall x q1, + 0 <= x < b ^ (2 * k) -> + q1 = x / b ^ (k - 1) -> + q3 (partition (n*2) x) (partition (n+1) q1) = partition (n+1) ((mu * q1) / b ^ (k + 1))) + (r : list Z -> list Z -> list Z) + (r_correct : + forall x q1 q3, + 0 <= x < b ^ (2 * k) -> + q1 = x / b ^ (k - 1) -> + q3 = (mu * q1) / b ^ (k + 1) -> + r (partition (n*2) x) (partition (n+1) q3) = partition (n*2) (x - q3 * M)) + (final_reduce : list Z -> list Z) + (final_reduce_correct : + forall r, + 0 <= r < 2 * M -> + final_reduce (partition (n*2) r) = partition n (if dec (r < M) then r else r - M)). + + Context (x : Z) (x_range : 0 <= x < M * b ^ k) + (xt : list Z) (xt_correct : xt = partition (n*2) x). + + (* TODO: make barrett Z figure out some of this itself *) + Local Lemma M_pos : 0 < M. + Proof. assert (0 <= b ^ (k - 1)) by Z.zero_bounds. lia. Qed. + Local Lemma x_upper : x < b ^ (2 * k). Proof. - pose proof looser_bound. - subst mu. split. - { apply Z.div_le_lower_bound; omega. } - { apply Z.div_lt_upper_bound; try omega. - rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. } + assert (0 < b ^ k) by Z.zero_bounds. + apply Z.lt_le_trans with (m:= M * b^k); [ lia | ]. + transitivity (b^k * b^k); [ nia | ]. + rewrite <-Z.pow_2_r, <-Z.pow_mul_r by lia. + rewrite (Z.mul_comm k 2); reflexivity. Qed. + Local Lemma xmod_lt_M : x mod b ^ (k - 1) <= M. + Proof. pose proof (Z.mod_pos_bound x (b ^ (k - 1)) ltac:(Z.zero_bounds)). lia. Qed. + Local Hint Resolve M_pos x_upper xmod_lt_M. - Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). - Proof. - pose proof looser_bound. - split; [ solve [Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; auto with zarith. - rewrite <-pow_2k_eq. omega. - Qed. - Hint Resolve shiftr_x_bounds. - - Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. - - Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). + Definition reduce := + dlet_nd q1t := q1 xt in + dlet_nd q3t := q3 xt q1t in + dlet_nd rt := r xt q3t in + final_reduce rt. - Lemma q_correct : rep qt q . + Lemma reduce_correct : reduce = partition n (x mod M). Proof. - pose proof mu_bounds. cbv [qt]; subst q. - rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. - solve_rep. - Qed. - Hint Resolve q_correct. - - Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. - Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed. - Hint Resolve x_mod_small. - - Lemma q_bounds : 0 <= q < 2 ^ k. - Proof. - pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds. - split; subst q; [ solve [Z.zero_bounds] | ]. - edestruct q_nice_strong with (n:=M) as [? Hqnice]; - try rewrite Hqnice; auto; try omega; [ ]. - apply Z.le_lt_trans with (m:= x / M). - { break_match; omega. } - { apply Z.div_lt_upper_bound; omega. } - Qed. - - Lemma two_conditional_subtracts : - forall a x, - rep a x -> - 0 <= x < 2 * M -> - cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. - Proof. - intros. - erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). - break_match; Z.ltb_to_lt; try lia; discriminate. - Qed. - - Lemma r_bounds : 0 <= x - q * M < 2 * M. - Proof. - pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. - subst q mu; split. - { Z.zero_bounds. apply qn_small; omega. } - { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } - Qed. - - Lemma reduce_correct : reduce = x mod M. - Proof. - pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. - assert (2 * M < 2^k * 2^k) by nia. - rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). cbv [reduce Let_In]. - erewrite low_correct by eauto. Z.rewrite_mod_small. - erewrite two_conditional_subtracts by solve_rep. - rewrite !cond_sub2_correct. - subst q; reflexivity. - Qed. - End Generic. - - Section BarrettReduction. - Context (k : Z) (k_bound : 2 <= k). - Context (M muLow : Z). - Context (M_pos : 0 < M) - (muLow_eq : muLow + 2^k = 2^(2*k) / M) - (muLow_bounds : 0 <= muLow < 2^k) - (M_bound1 : 2 ^ (k - 1) < M < 2^k) - (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). - - Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). - Context (nout : nat) (Hnout : nout = 2%nat). - Let w := weight k 1. - Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. - Let props : @weight_properties w := wprops k 1 k_range. - - Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. - - Definition low (t : list Z) : Z := nth_default 0 t 0. - Definition high (t : list Z) : Z := nth_default 0 t 1. - Definition represents (t : list Z) (x : Z) := - t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. - - Lemma represents_eq t x : - represents t x -> t = [x mod 2^k; x / 2^k]. - Proof. cbv [represents]; tauto. Qed. - - Lemma represents_length t x : represents t x -> length t = 2%nat. - Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. - - Lemma represents_low t x : - represents t x -> low t = x mod 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. - - Lemma represents_high t x : - represents t x -> high t = x / 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. - - Lemma represents_low_range t x : - represents t x -> 0 <= x mod 2^k < 2^k. - Proof. auto with zarith. Qed. - - Lemma represents_high_range t x : - represents t x -> 0 <= x / 2^k < 2^k. - Proof. - destruct 1 as [? [? ?] ]; intros. - auto using Z.div_lt_upper_bound with zarith. - Qed. - Hint Resolve represents_length represents_low_range represents_high_range. - - Lemma represents_range t x : - represents t x -> 0 <= x < 2^k*2^k. - Proof. cbv [represents]; tauto. Qed. - - Lemma represents_id x : - 0 <= x < 2^k * 2^k -> - represents [x mod 2^k; x / 2^k] x. - Proof. - intros; cbv [represents]; autorewrite with cancel_pair. - Z.rewrite_mod_small; tauto. - Qed. - - Local Ltac push_rep := - repeat match goal with - | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) - | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) - | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption - | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption - end. - - Definition shiftr (t : list Z) (n : Z) : list Z := - [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. - - Lemma shiftr_represents a i x : - represents a x -> - 0 <= i <= k -> - represents (shiftr a i) (x / 2 ^ i). - Proof. - cbv [shiftr]; intros; push_rep. - match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. - assert (0 < 2 ^ i) by auto with zarith. - assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. - assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by - (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). - repeat match goal with - | _ => rewrite Z.rshi_correct by auto with zarith - | _ => rewrite <-Z.div_mod''' by auto with zarith - | _ => progress autorewrite with zsimplify_fast - | _ => progress Z.rewrite_mod_small - | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] => - rewrite (Z.div_div_comm a b c) by auto with zarith - | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] - end. + rewrite xt_correct, q1_correct, q3_correct by auto with lia. + erewrite r_correct by eauto with lia. rewrite final_reduce_correct; [ | ]. + { rewrite barrett_reduction_small_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) by + auto using Z.lt_gt with zarith. + break_innermost_match; Z.ltb_to_lt; (lia || reflexivity). } + { split; [ Z.zero_bounds; apply qn_small | apply r_small_strong]; + auto using Z.lt_gt with zarith. } Qed. - Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). - Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. - - Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2). - (* TODO: use this definition once issue #352 is resolved *) - (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *) - Definition widesub (t1 t2 : list Z) := - let t1_0 := hd 0 t1 in - let t1_1 := hd 0 (tl t1) in - let t2_0 := hd 0 t2 in - let t2_1 := hd 0 (tl t2) in - dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in - dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in - [fst x0; fst x1]. - Definition widemul := BaseConversion.widemul_inlined k n nout. - - Lemma partition_represents x : - 0 <= x < 2^k*2^k -> - represents (Partition.partition w 2 x) x. - Proof. - intros; cbn. change_weight. - Z.rewrite_mod_small. - autorewrite with zsimplify_fast. - auto using represents_id. - Qed. - - Lemma eval_represents t x : - represents t x -> Positional.eval w 2 t = x. - Proof. - intros; rewrite (represents_eq t x) by assumption. - cbn. change_weight; push_rep. - autorewrite with zsimplify. reflexivity. - Qed. - - Ltac wide_op partitions_pf := - repeat match goal with - | _ => rewrite partitions_pf by eauto - | _ => rewrite partitions_pf by auto with zarith - | _ => erewrite eval_represents by eauto - | _ => solve [auto using partition_represents, represents_id] - end. - - Lemma wideadd_represents t1 t2 x y : - represents t1 x -> - represents t2 y -> - 0 <= x + y < 2^k*2^k -> - represents (wideadd t1 t2) (x + y). - Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed. - - Lemma widesub_represents t1 t2 x y : - represents t1 x -> - represents t2 y -> - 0 <= x - y < 2^k*2^k -> - represents (widesub t1 t2) (x - y). - Proof. - intros; cbv [widesub Let_In]. - rewrite (represents_eq t1 x) by assumption. - rewrite (represents_eq t2 y) by assumption. - cbn [hd tl]. - autorewrite with to_div_mod. - pull_Zmod. - match goal with |- represents [?m; ?d] ?x => - replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end. - rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds). - f_equal. - transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). { - rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega. - rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega. - f_equal. ring. } - autorewrite with zsimplify. - ring. - Qed. - (* Works with Rows.sub-based widesub definition - Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. - *) - - (* TODO: MOVE Equivlalent Keys decl to Arithmetic? *) - Declare Equivalent Keys BaseConversion.widemul BaseConversion.widemul_inlined. - Lemma widemul_represents x y : - 0 <= x < 2^k -> - 0 <= y < 2^k -> - represents (widemul x y) (x * y). - Proof. - intros; cbv [widemul]. - assert (0 <= x * y < 2^k*2^k) by auto with zarith. - wide_op BaseConversion.widemul_correct. - Qed. - - Definition mul_high (a b : list Z) a0b1 : list Z := - dlet_nd a0b0 := widemul (low a) (low b) in - dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in - wideadd ab [a0b1; 0]. - - Lemma mul_high_idea d a b a0 a1 b0 b1 : - d <> 0 -> - a = d * a1 + a0 -> - b = d * b1 + b0 -> - (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. - Proof. - intros. subst a b. autorewrite with push_Zmul. - ring_simplify_subterms. rewrite Z.pow_2_r. - rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). - repeat match goal with - | |- context [d * ?a * ?b * ?c] => - replace (d * a * b * c) with (a * b * c * d) by ring - | |- context [d * ?a * ?b] => - replace (d * a * b) with (a * b * d) by ring - end. - rewrite !Z.div_add by omega. - autorewrite with zsimplify. - rewrite (Z.mul_comm a0 b0). - ring_simplify. ring. - Qed. + End Generic. - Lemma represents_trans t x y: - represents t y -> y = x -> - represents t x. - Proof. congruence. Qed. + (* Non-standard implementation -- uses specialized instructions and b=2 *) + Module Fancy. + Section Fancy. + Context (M mu width k : Z) + (sz : nat) (sz_nz : sz <> 0%nat) + (width_ok : 1 < width) + (k_pos : 0 < k) (* this can be inferred from other arguments but is easier to put here for tactics *) + (k_eq : k = width * Z.of_nat sz). + (* sz = 1, width = k = 256 *) + Local Notation w := (UniformWeight.uweight width). Local Notation eval := (Positional.eval w). + Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M). + Context (mu_eq : mu = 2 ^ (2 * k) / M) (M_range : 2^(k-1) < M < 2^k). + + Local Lemma wprops : @weight_properties w. Proof. apply UniformWeight.uwprops; auto with lia. Qed. + Local Hint Resolve wprops. + Hint Rewrite mut_correct Mt_correct : pull_partition. + + Lemma w_eq_2k : w sz = 2^k. Admitted. (* TODO *) + Lemma mu_range : 2^k <= mu < 2^(k+1). + Admitted. (* TODO *) + Lemma M_range' : 0 <= M < w sz. (* more convenient form, especially for mod_small *) + Proof. assert (0 <= 2 ^ (k-1)) by Z.zero_bounds. pose proof w_eq_2k; lia. Qed. + + Definition shiftr' (m : nat) (t : list Z) (n : Z) : list Z := + map (fun i => Z.rshi (2^width) (nth_default 0 t (S i)) (nth_default 0 t i) n) (seq 0 m). + + Definition shiftr (m : nat) (t : list Z) (n : Z) : list Z := + (* if width <= n, drop limbs first *) + if dec (width <= n) + then shiftr' m (skipn (Z.to_nat (n / width)) t) (n mod width) + else shiftr' m t n. + + Definition wideadd t1 t2 := fst (Rows.add w (sz*2) t1 t2). + Definition widesub t1 t2 := fst (Rows.sub w (sz*2) t1 t2). + Definition widemul := BaseConversion.widemul_inlined width sz 2. + (* widemul_inlined takes the following argument order : (width of limbs in input) (# limbs in input) (# parts to split each limb into before multiplying) *) + + Definition fill (n : nat) (a : list Z) := a ++ Positional.zeros (n - length a). + Definition low : list Z -> list Z := firstn sz. + Definition high : list Z -> list Z := skipn sz. + Definition mul_high (a b : list Z) a0b1 : list Z := + dlet_nd a0b0 := widemul (low a) (low b) in + dlet_nd ab := wideadd (high a0b0 ++ high b) (fill (sz*2) (low b)) in + wideadd ab a0b1. + + (* select based on the most significant bit of xHigh *) + Definition muSelect xt := + let xHigh := nth_default 0 xt (sz*2 - 1) in + Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut). + + Definition cond_sub (a y : list Z) : list Z := + let cond := Z.cc_l (nth_default 0 (high a) 0) in (* a[k] = least significant bit of (high a) *) + dlet_nd maybe_y := Positional.select cond (repeat 0 sz) y in + dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *) + fst diff. + + Definition cond_subM x := + if Nat.eq_dec sz 1 + then [Z.add_modulo (nth_default 0 x 0) 0 M] (* use the special instruction if we can *) + else Rows.conditional_sub w sz x Mt. + + Definition q1 (xt : list Z) := shiftr (sz+1) xt (k - 1). + + Definition q3 (xt q1t : list Z) := + dlet_nd muSelect := muSelect xt in (* make sure muSelect is not inlined in the output *) + dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in + shiftr (sz+1) twoq 1. + + Definition r (xt q3t : list Z) := + dlet_nd r2 := widemul (low q3t) Mt in + widesub xt r2. + + Definition final_reduce (rt : list Z) := + dlet_nd rt := cond_sub rt Mt in + cond_subM rt. + + Section Proofs. + + (* TODO: move *) + Lemma divides_pow_le b n m : 0 <= n <= m -> (b ^ n | b ^ m). + Proof. + intros. replace m with (n + (m - n)) by ring. + rewrite Z.pow_add_r by lia. + apply Z.divide_factor_l. + Qed. - Lemma represents_add x y : - 0 <= x < 2 ^ k -> - 0 <= y < 2 ^ k -> - represents [x;y] (x + 2^k*y). - Proof. - intros; cbv [represents]; autorewrite with zsimplify. - repeat split; (reflexivity || nia). - Qed. + (* TODO: move to UW or Weight *) + Lemma mod_mod_weight a i j : + (i <= j)%nat -> (a mod (w j)) mod (w i) = a mod (w i). + Proof. + intros. rewrite <-Znumtheory.Zmod_div_mod; auto; [ ]. + rewrite !UniformWeight.uweight_eq_alt'. + apply divides_pow_le. nia. + Qed. - Lemma represents_small x : - 0 <= x < 2^k -> - represents [x; 0] x. - Proof. - intros. - eapply represents_trans. - { eauto using represents_add with zarith. } - { ring. } - Qed. + Lemma shiftr'_correct m n : + forall t tn, + (m <= tn)%nat -> 0 <= t < w tn -> 0 <= n < width -> + shiftr' m (partition w tn t) n = partition w m (t / 2 ^ n). + Proof. + cbv [shiftr']. induction m; intros; [ reflexivity | ]. + rewrite !partition_step, seq_snoc. + autorewrite with distr_length natsimplify push_map push_nth_default. + rewrite IHm by auto with zarith. + rewrite Z.rshi_correct, UniformWeight.uweight_S by auto with zarith. + rewrite <-Z.mod_pull_div by auto with zarith. + destruct (Nat.eq_dec (S m) tn); [subst tn | ]; rewrite !nth_default_partition by omega. + { rewrite nth_default_out_of_bounds by distr_length. + autorewrite with zsimplify. Z.rewrite_mod_small. + rewrite Z.div_div_comm by auto with zarith; reflexivity. } + { (* oh god the worst *) + rewrite !UniformWeight.uweight_S by auto with zarith. + rewrite <-!Z.mod_pull_div by auto with zarith. + rewrite Z.mod_pull_div with (b:=2^n) by auto with zarith. + rewrite <-Z.div_div by auto with zarith. + rewrite (Z.div_div_comm _ (2^width) (w m)) by auto with zarith. + rewrite Z.mod_pull_div with (b:=2^width) by auto with zarith. + rewrite Z.mul_div_eq'; auto using Z.lt_gt with zarith. + rewrite Z.rem_mul_r with (b:=2^width) (c:=2^width) by auto with zarith. + autorewrite with zsimplify. + rewrite <-Z.rem_mul_r by auto with zarith. + rewrite !Z.mod_pull_div by auto with zarith. + rewrite <-Znumtheory.Zmod_div_mod by + (try solve [Z.zero_bounds]; + rewrite <-!Z.pow_add_r by auto with zarith; + auto using Z.mul_divide_mono_r, divides_pow_le with zarith). + rewrite Z.div_div_comm by auto with zarith. + repeat (f_equal; try ring). } + Qed. + Lemma shiftr_correct m n : + forall t tn, + (Z.to_nat (n / width) <= tn)%nat -> + (m <= tn - Z.to_nat (n / width))%nat -> 0 <= t < w tn -> 0 <= n -> + shiftr m (partition w tn t) n = partition w m (t / 2 ^ n). + Proof. + cbv [shiftr]; intros. + break_innermost_match; [ | solve [auto using shiftr'_correct with zarith] ]. + pose proof (Z.mod_pos_bound n width ltac:(omega)). + assert (t / 2 ^ (n - n mod width) < w (tn - Z.to_nat (n / width))). + { apply Z.div_lt_upper_bound; [solve [Z.zero_bounds] | ]. + rewrite UniformWeight.uweight_eq_alt' in *. + rewrite <-Z.pow_add_r, Nat2Z.inj_sub, Z2Nat.id, <-Z.mul_div_eq by auto with zarith. + autorewrite with push_Zmul zsimplify. auto with zarith. } + repeat match goal with + | _ => progress rewrite ?UniformWeight.uweight_skipn_partition, ?UniformWeight.uweight_eq_alt' by auto with lia + | _ => rewrite Z2Nat.id by Z.zero_bounds + | _ => rewrite Z.mul_div_eq_full by auto with zarith + | _ => rewrite shiftr'_correct by auto with zarith + | _ => progress rewrite ?Z.div_div, <-?Z.pow_add_r by auto with zarith + end. + autorewrite with zsimplify. reflexivity. + Qed. + Hint Rewrite shiftr_correct using (solve [auto with lia]) : pull_partition. + + (* 2 ^ (k + 1) bits fit in sz + 1 limbs because we know 2^k bits fit in sz and 1 <= width *) + Lemma q1_correct x : + 0 <= x < w (sz * 2) -> + q1 (partition w (sz*2)%nat x) = partition w (sz+1)%nat (x / 2 ^ (k - 1)). + Proof. + cbv [q1]; intros. assert (1 <= Z.of_nat sz) by (destruct sz; lia). + assert (Z.to_nat ((k-1) / width) < sz)%nat. { + subst k. rewrite <-Z.add_opp_r. autorewrite with zsimplify. + apply Nat2Z.inj_lt. rewrite Z2Nat.id by lia. lia. } + assert (0 <= k - 1) by nia. + autorewrite with pull_partition. reflexivity. + Qed. - Lemma mul_high_represents a b x y a0b1 : - represents a x -> - represents b y -> - 2^k <= x < 2^(k+1) -> - 0 <= y < 2^(k+1) -> - a0b1 = x mod 2^k * (y / 2^k) -> - represents (mul_high a b a0b1) ((x * y) / 2^k). - Proof. - cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. - assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith). - assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem_in_goal; nia). - - rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in * - by (push_rep; Z.div_mod_to_quot_rem_in_goal; lia). - - push_rep. subst a0b1. - assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega). - replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). - autorewrite with zsimplify_fast in *. - - eapply represents_trans. - { repeat (apply wideadd_represents; - [ | apply represents_small; Z.div_mod_to_quot_rem_in_goal; nia| ]). - erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ]. - { apply represents_add; try reflexivity; solve [auto with zarith]. } - { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => - split; [ solve [Z.zero_bounds] | ]; - eapply Z.le_lt_trans with (m:= x + y); nia - end. } - { omega. } } - { ring. } - Qed. + Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a. + Admitted. + Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz). + Admitted. + Lemma fill_correct n m a : + (n <= m)%nat -> + fill m (partition w n a) = partition w m (a mod w n). + Admitted. + Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition. + + (* TODO: move *) + Lemma partition_app n m a b : + partition w n a ++ partition w m b = partition w (n+m) (a mod w n + b * w n). + Admitted. + + Lemma wideadd_correct a b : + wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b). + Proof. + cbv [wideadd]. rewrite Rows.add_partitions by (distr_length; auto). + autorewrite with push_eval. + apply partition_eq_mod; auto with zarith. + Qed. + Lemma widesub_correct a b : + widesub (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a - b). + Proof. + cbv [widesub]. rewrite Rows.sub_partitions by (distr_length; auto). + autorewrite with push_eval. + apply partition_eq_mod; auto with zarith. + Qed. + Lemma widemul_correct a b : + widemul (partition w sz a) (partition w sz b) = partition w (sz*2) ((a mod w sz) * (b mod w sz)). + Proof. + cbv [widemul]. rewrite BaseConversion.widemul_inlined_correct; (distr_length; auto). + autorewrite with push_eval. reflexivity. + Qed. + Hint Rewrite widemul_correct widesub_correct wideadd_correct using lia : pull_partition. + + Lemma mul_high_idea d a b a0 a1 b0 b1 : + d <> 0 -> + a = d * a1 + a0 -> + b = d * b1 + b0 -> + (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1. + Proof. + intros. subst a b. autorewrite with push_Zmul. + ring_simplify_subterms. rewrite Z.pow_2_r. + rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). + repeat match goal with + | |- context [d * ?a * ?b * ?c] => + replace (d * a * b * c) with (a * b * c * d) by ring + | |- context [d * ?a * ?b] => + replace (d * a * b) with (a * b * d) by ring + end. + rewrite !Z.div_add by omega. + autorewrite with zsimplify. + rewrite (Z.mul_comm a0 b0). + ring_simplify. ring. + Qed. - Definition cond_sub1 (a : list Z) y : Z := - dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in - dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in - fst diff. + (* TODO: move these? *) + Local Hint Resolve Z.lt_gt : zarith. + Hint Rewrite Z.div_add' using solve [auto with zarith] : zsimplify. + + Lemma mul_high_correct a b + (Ha : a / w sz = 1) + a0b1 (Ha0b1 : a0b1 = a mod w sz * (b / w sz)) : + mul_high (partition w (sz*2) a) (partition w (sz*2) b) (partition w (sz*2) a0b1) = + partition w (sz*2) (a * b / w sz). + Proof. + cbv [mul_high Let_In]. + erewrite mul_high_idea by auto using Z.div_mod with zarith. + repeat match goal with + | _ => progress autorewrite with pull_partition + | _ => progress rewrite ?Ha, ?Ha0b1 + | _ => rewrite partition_app by auto; + replace (sz+sz)%nat with (sz*2)%nat by lia + | _ => rewrite Z.mod_pull_div by auto with zarith + | _ => progress Z.rewrite_mod_small + | _ => f_equal; ring + end. + Qed. - Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. - Proof. - cbv [Z.cc_l]; intros. - rewrite Z.div_between_0_if by omega. - break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega. - Qed. + Lemma muSelect_correct x : + 0 <= x < w (sz * 2) -> + muSelect (partition w (sz*2) x) = partition w (sz+1) (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))). + Admitted. + Hint Rewrite muSelect_correct using lia : pull_partition. + + (* TODO: move *) + Lemma pow_pos_le a b : 0 < a -> 0 < b -> a <= a ^ b. + Proof. + intros; transitivity (a ^ 1). + { rewrite Z.pow_1_r; reflexivity. } + { apply Z.pow_le_mono; auto with zarith. } + Qed. + Hint Resolve pow_pos_le : zarith. + + Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) : + q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)). + Proof. + cbv [q3 Let_In]. intros. + assert (0 <= q1 < w (sz+1)) by admit. + autorewrite with pull_partition. + rewrite ?fill_correct, ?muSelect_correct by lia. + rewrite <-Hq1. + rewrite mul_high_correct; [ | | ]. + 2: { (* mu mod w (sz + 1) / w sz = 1 *) + (* highest bit of mu is 1 -- do this by rewriting mod small using bounds on mu, then take axiom for highest bit *) + admit. } + 2: { + rewrite mod_mod_weight by lia. + push_Zmod. rewrite Z.mod_pull_div by auto with zarith. + assert (0 < w sz) by auto with zarith. + assert (w (sz+1) <= w (sz+1) * w sz) by (apply Z.le_mul_diag_r; auto with zarith). + Z.rewrite_mod_small. + autorewrite with natsimplify in *. + rewrite UniformWeight.uweight_S in * by auto with zarith. + assert (0 <= q1 / w sz < 2^width). + { split; [ solve [Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; auto with nia. } + pose proof (Z.mod_pos_bound mu (w sz) ltac:(auto)). + apply Z.mod_small. nia. } + rewrite shiftr_correct; auto with zarith. + 2: admit. (* (Z.to_nat (1 / width) <= sz * 2)%nat *) + 2: admit. (* (sz + 1 <= sz * 2 - Z.to_nat (1 / width))%nat *) + 2: admit. (* 0 <= mu mod w (sz + 1) * (q1 mod w (sz + 1)) / w sz < w (sz * 2) *) + (* TODO: see if things are easier using div_mod_to_quot_rem *) + Z.rewrite_mod_small. + pose proof mu_range. + replace (2 ^ (k+1)) with (w sz * 2) in * by admit. + assert (0 < 2 ^ k) by (Z.zero_bounds; nia). + assert (2 * w sz <= w (sz + 1)). + { autorewrite with natsimplify. rewrite UniformWeight.uweight_S by auto with zarith. + apply Z.mul_le_mono_nonneg_r; auto with zarith. } + Z.rewrite_mod_small. + rewrite Z.div_div by auto with zarith. + repeat (f_equal; try ring). + Admitted. + + Lemma cond_sub_correct a b : + cond_sub (partition w (sz*2) a) (partition w sz b) + = partition w sz (if dec ((a / w sz) mod 2 = 0) + then a + else a - b). + Proof. + cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition. + rewrite nth_default_partition by lia. + rewrite weight_0 by auto. autorewrite with zsimplify_fast. + rewrite UniformWeight.uweight_eq_alt' with (n:=1%nat). autorewrite with push_Zof_nat zsimplify. + rewrite <-Znumtheory.Zmod_div_mod by auto using Zpow_facts.Zpower_divide with zarith. + rewrite Positional.select_eq with (n:=sz) by (distr_length; apply w). + rewrite Rows.sub_partitions by (break_innermost_match; distr_length; auto). + break_innermost_match; autorewrite with push_eval zsimplify_fast; + apply partition_eq_mod; auto with zarith. + Qed. + Hint Rewrite cond_sub_correct : pull_partition. + Lemma cond_subM_correct a : + cond_subM (partition w sz a) + = partition w sz (if dec (a mod w sz < M) + then a + else a - M). + Proof. + Admitted. + Hint Rewrite cond_subM_correct : pull_partition. + + (* TODO: move this with barrett Z lemmas & prove *) + Lemma q3_range x : + 0 <= x < 2 ^ (2 * k) -> + 0 <= (mu * (x / 2 ^ (k-1))) / 2 ^ (k+1) < 2^k. + Proof using Type. + Admitted. - Lemma cond_sub1_correct a x y : - represents a x -> - 0 <= x < 2 * y -> - 0 <= y < 2 ^ k -> - cond_sub1 a y = if (x + q1 = x / 2 ^ (k - 1) -> + q3 = (mu*q1) / 2 ^ (k + 1) -> + r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w (sz*2) (x - q3 * M). + Proof. + cbv [r Let_In]; intros. pose proof w_eq_22k. + autorewrite with pull_partition. + assert (0 <= q3 < w sz). + { subst q3 q1. rewrite w_eq_2k. apply q3_range; lia. } + pose proof M_range'. Z.rewrite_mod_small. + autorewrite with zsimplify. reflexivity. + Qed. - Section Defn. - Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). - Let xt := [xLow; xHigh]. - Let x := xLow + 2^k * xHigh. - - Lemma x_rep : represents xt x. - Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed. - - Lemma x_bounds : 0 <= x < M * 2 ^ k. - Proof. subst x; nia. Qed. - - Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. - - Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound. - Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents - cond_sub1_correct cond_sub2_correct represents_low represents_add. - - Lemma muSelect_correct : - muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k). - Proof. - (* assertions to help arith tactics *) - pose proof x_bounds. - assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia). - assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem_in_goal; auto with nia). - assert (0 < 2 ^ k / 2) by Z.zero_bounds. - assert (2 ^ (k - 1) <> 0) by auto with zarith. - assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith). - - cbv [muSelect]. rewrite <-muLow_eq. - rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith. - replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). - autorewrite with pull_Zdiv push_Zpow. - rewrite (Z.mul_comm (2 ^ k / 2)). - break_match; [ ring | ]. - match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. - autorewrite with zsimplify; reflexivity. - Qed. + Lemma final_reduce_correct r: + 0 <= r < 2 * M -> + final_reduce (partition w (sz*2) r) = partition w sz (if dec (r < M) then r else r - M). + Proof. + intros; pose proof M_range'. + cbv [final_reduce Let_In]; intros; autorewrite with pull_partition. + apply partition_eq_mod; auto with zarith. + replace (2 ^ (k + 1)) with (2 * w sz) in * + by (rewrite Z.pow_add_r by auto with zarith; autorewrite with zsimplify; rewrite <-w_eq_2k; lia). + pose proof (Z.mod_pos_bound r (w sz) ltac:(auto with zarith)). + remember ((r / w sz) mod 2) as rk. + assert (rk = 1 -> r > M) by (subst rk; Z.rewrite_mod_small; intros; + rewrite (Z.div_mod r (w sz)) by auto with zarith; nia). + replace r with (w sz * rk + r mod w sz) by + (subst rk; Z.rewrite_mod_small; rewrite <-Z.div_mod; auto with zarith). + subst rk; rewrite Zmod_odd in *. + repeat match goal with + | _ => lia + | H : _ |- _ => rewrite Z.mod_small in H by auto with zarith lia + | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast + | _ => break_innermost_match_step + end. + Qed. + End Proofs. + + Section Def. + Context (sz_eq_1 : sz = 1%nat). (* this is needed to get rid of branches in the templates; a different definition would be needed for sizes other than 1, but would be able to use the same proofs. *) + Local Hint Resolve q1_correct q3_correct r_correct final_reduce_correct. + + (* muselect relies on an initially-set flag, so pull it out of q3 *) + Definition fancy_reduce_muSelect_first xt := + dlet_nd muSelect := muSelect xt in + dlet_nd q1t := q1 xt in + dlet_nd twoq := mul_high (fill (sz * 2) mut) (fill (sz * 2) q1t) (fill (sz * 2) muSelect) in + dlet_nd q3t := shiftr (sz+1) twoq 1 in + dlet_nd rt := r xt q3t in + final_reduce rt. + + Lemma fancy_reduce_muSelect_first_correct x : + 0 <= x < M * 2^k -> + 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> + fancy_reduce_muSelect_first (partition w (sz*2) x) = partition w sz (x mod M). + Proof. + intros. pose proof w_eq_22k. + erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by + (eauto with nia; intros; try rewrite q3'_correct; eauto with nia ). + reflexivity. + Qed. + + + (* + Definition fancy_reduce xLow xHigh := + dlet_nd muSelect : list Z := Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut) in + dlet_nd q1t := [Z.rshi (2 ^ width) xHigh xLow (k - 1); Z.rshi (2 ^ width) 0 xHigh (k - 1)] in + dlet_nd a0b0 : list Z := widemul (low mut) (low q1t) in + dlet_nd ab : list Z := wideadd [high a0b0; high q1t] [low q1t; 0] in + dlet_nd a0 := wideadd ab [muSelect;0] in + dlet_nd q3t := [Z.rshi (2 ^ width) (nth_default 0 a0 1) (nth_default 0 a0 0) 1; Z.rshi (2 ^ width) (nth_default 0 a0 2) (nth_default 0 a0 1) 1] in + dlet_nd r2 : list Z := widemul (low q3t) Mt in + dlet_nd rt : list Z := widesub [xLow; xHigh] r2 in + dlet_nd maybe_M : list Z := Positional.select (Z.cc_l (nth_default 0 (high rt) 0)) (repeat 0 sz) Mt in + dlet_nd diff : list Z * Z := Rows.sub w sz (low rt) maybe_M in + dlet_nd rt0 :=fst diff in + Z.add_modulo (nth_default 0 rt0 0) 0 M. +*) + (* This simplifies the shiftr inside q3, specializing it to the # limbs so we don't end up with maps in the final expression *) + Derive q3' + SuchThat (forall xt q1t, q3' xt q1t = q3 xt q1t) + As q3'_correct. + Proof. + intros. cbv [q3]. rewrite sz_eq_1. autorewrite with natsimplify. + cbv [shiftr shiftr' seq]. break_match; try lia; [ ]. + rewrite Proper_Let_In_nd_changebody; [ | reflexivity | repeat intro ]. + 2 : apply Proper_Let_In_nd_changebody; + [ reflexivity | repeat intro; autorewrite with push_map push_nth_default; reflexivity ]. + subst q3'; reflexivity. + Qed. - Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M). - Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed. + Derive fancy_reduce' + SuchThat ( + forall x, + 0 <= x < M * 2^k -> + 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> + fancy_reduce' (partition w (sz*2) x) = partition w sz (x mod M)) + As fancy_reduce'_correct. + Proof. + intros. assert (k = width) as width_eq_k by nia. + erewrite <-fancy_reduce_muSelect_first_correct by nia. + cbv [fancy_reduce_muSelect_first q1 q3 shiftr final_reduce cond_subM]. + break_match; try solve [exfalso; lia]. + match goal with |- ?g ?x = ?rhs => + let f := (match (eval pattern x in rhs) with ?f _ => f end) in + assert (f = g); subst fancy_reduce'; reflexivity + end. + Qed. - Derive barrett_reduce - SuchThat (barrett_reduce = x mod M) - As barrett_reduce_correct. - Proof. - erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt) - by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega). - subst barrett_reduce. reflexivity. - Qed. - End Defn. - End BarrettReduction. + (* TODO: Can't deal with input by shifting and stuff! *) + Derive fancy_reduce + SuchThat ( + forall xLow xHigh, + 0 <= xLow < 2 ^ k -> + 0 <= xHigh < M -> + 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> + fancy_reduce xLow xHigh = (xLow + 2^k * xHigh) mod M) + As fancy_reduce_correct. + Proof. + intros. pose proof w_eq_22k. assert (k = width) as width_eq_k by nia. + replace ((xLow + 2^k * xHigh) mod M) with (hd 0 (partition w sz ((xLow + 2^k * xHigh) mod M))). + 2 : { rewrite sz_eq_1. rewrite width_eq_k in *. cbv [Partition.partition map seq hd]. + rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia. + autorewrite with zsimplify. + assert (0 < 2 ^ (width - 1)) by Z.zero_bounds. + pose proof (Z.mod_pos_bound (xLow + 2^width * xHigh) M ltac:(lia)). + autorewrite with zsimplify. reflexivity. } + rewrite <-fancy_reduce'_correct by nia. + replace (partition w (sz*2) (xLow + 2^k * xHigh)) with [xLow; xHigh]. + 2 : { replace (sz * 2)%nat with 2%nat by (subst sz; lia). + rewrite width_eq_k in *. cbv [Partition.partition map seq]. + rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia. + autorewrite with zsimplify. + rewrite <-Z.mod_pull_div by Z.zero_bounds. + autorewrite with zsimplify. reflexivity. } + cbv [fancy_reduce' reduce q1 q3 shiftr final_reduce cond_subM]. + autorewrite with natsimplify. (* TODO: maybe need to get rid of hd 0? *) + cbv [shiftr' seq]. autorewrite with push_map push_nth_default. + subst fancy_reduce; reflexivity. + Qed. + (* TODO: decide which version to use + Derive fancy_reduce + SuchThat ( + forall x, + 0 <= x < M * 2 ^ k -> + 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu -> + fancy_reduce (partition w (sz*2) x) = partition w sz (x mod M)) + As fancy_reduce_correct. + Proof. + intros. pose proof w_eq_22k. + erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by eauto with lia. + assert (width = k) by nia. + cbv [reduce q1 q3 shiftr final_reduce cond_subM]. + break_match; try solve [exfalso; lia]. + match goal with |- ?g ?x = ?rhs => + let f := (match (eval pattern x in rhs) with ?f _ => f end) in + assert (f = g); subst fancy_reduce; reflexivity + end. + Qed. + *) + End Def. + End Fancy. + End Fancy. End BarrettReduction. Module MontgomeryReduction. @@ -5454,14 +5473,13 @@ Module MontgomeryReduction. Context (Zlog2R : Z) . Let w : nat -> Z := weight Zlog2R 1. Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0). - Context (R_big_enough : n <= Zlog2R) + Context (R_big_enough : 2 <= Zlog2R) (R_two_pow : 2^Zlog2R = R). Let w_mul : nat -> Z := weight (Zlog2R / n) 1. - Context (nout : nat) (Hnout : nout = 2%nat). Definition montred' (lo hi : Z) := - dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout lo N') 0 in - dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in + dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R 1 2 [lo] [N']) 0 in + dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R 1 2 [N] [y]) in dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [lo;hi] t1_t2 in dlet_nd y' := Z.zselect (snd sum_carry) 0 N in dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in @@ -5489,10 +5507,15 @@ Module MontgomeryReduction. Local Lemma eval2 x y : Positional.eval w 2 [x;y] = x + R * y. Proof. cbn. change_weight. ring. Qed. + Local Lemma eval1 x : Positional.eval w 1 [x] = x. + Proof. cbn. change_weight. ring. Qed. Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul. + (* TODO: move *) + Hint Rewrite Nat.mul_1_l : natsimplify. + Lemma montred'_eq lo hi T (HT_range: 0 <= T < R * N) (Hlo: lo = T mod R) (Hhi: hi = T / R): montred' lo hi = reduce_via_partial N R N' T. @@ -5501,30 +5524,34 @@ Module MontgomeryReduction. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. assert (0 <= (T mod R) * N' < w 2) by (solve_range). - autorewrite with widemul. rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega). - rewrite R_two_pow. - cbv [Partition.partition seq]. rewrite !eval2. - autorewrite with push_nth_default push_map. - autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. + (* rewrite R_two_pow. *) + cbv [Partition.partition seq]. + repeat match goal with + | _ => progress rewrite ?eval1, ?eval2 + | _ => progress rewrite ?Z.zselect_correct, ?Z.add_modulo_correct + | _ => progress autorewrite with natsimplify push_nth_default push_map to_div_mod + end. change_weight. (* pull out value before last modular reduction *) match goal with |- (if (?n <=? ?x)%Z then ?x - ?n else ?x) = (if (?n <=? ?y) then ?y - ?n else ?y)%Z => let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end. + autorewrite with zsimplify. + Z.rewrite_mod_small. autorewrite with zsimplify. rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *. - break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega; - repeat match goal with - | _ => progress autorewrite with zsimplify_fast - | |- context [?x mod (R * R)] => - unique pose proof (Z.mod_pos_bound x (R * R)); - try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver - | _ => omega - | _ => progress Z.rewrite_mod_small - end. + match goal with + |- context [(?x - (if dec (?a / ?b = 0) then 0 else ?y)) mod ?m + = if (?b <=? ?a) then (?x - ?y) mod ?m else ?x ] => + assert (a / b = 0 <-> a < b) by + (rewrite Z.div_between_0_if by (Z.div_mod_to_quot_rem; nia); + break_match; Z.ltb_to_lt; lia) + end. + break_match; Z.ltb_to_lt; try reflexivity; try lia; [ ]. + autorewrite with zsimplify_fast. Z.rewrite_mod_small. reflexivity. Qed. Lemma montred'_correct lo hi T (HT_range: 0 <= T < R * N) diff --git a/src/Arithmetic/BarrettReduction/Generalized.v b/src/Arithmetic/BarrettReduction/Generalized.v index c2885bc77..93aa9452f 100644 --- a/src/Arithmetic/BarrettReduction/Generalized.v +++ b/src/Arithmetic/BarrettReduction/Generalized.v @@ -218,6 +218,15 @@ Section barrett. autorewrite with push_Zmul zsimplify zstrip_div. auto with lia. Qed. + + Theorem barrett_reduction_small_strong + : a mod n = if r - 0 <= xHigh < M -> - BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. - Proof. - intros. - apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; - try omega; - try match goal with - | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega - end; lazy; try split; congruence. - Qed. - Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := of_Expr 6%positive (make_consts [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)]) @@ -136,7 +122,10 @@ Module Barrett256. | _ => econstructor end. } { cbn. cbv [muLow M]. - repeat (econstructor; [ solve [valid_expr_subgoal] | intros ]). + repeat (match goal with + | _ => eapply valid_LetInZZ + | _ => eapply valid_LetInZ + end; [ solve [valid_expr_subgoal] | intros ]). econstructor. valid_expr_subgoal. } { reflexivity. } Qed. diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v index 265958c09..c0078b117 100644 --- a/src/PushButtonSynthesis/BarrettReduction.v +++ b/src/PushButtonSynthesis/BarrettReduction.v @@ -37,8 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *) Section rbarrett_red. - Context (k M : Z) (n nout : nat) - (machine_wordsize : Z). + Context (M machine_wordsize : Z). Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Let flag_range := r[0 ~> 1]%zrange. @@ -70,6 +69,31 @@ Section rbarrett_red. Qed. Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *) + Lemma mut_correct : + 0 < machine_wordsize -> + Partition.partition (UniformWeight.uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1]. + Proof. + intros; cbn. subst muLow. + assert (0 < 2^machine_wordsize) by ZeroBounds.Z.zero_bounds. + pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)). + rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia. + autorewrite with zsimplify. + Modulo.push_Zmod. autorewrite with zsimplify. Modulo.pull_Zmod. + rewrite <-Modulo.Z.mod_pull_div by lia. + autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small. + reflexivity. + Qed. + Lemma Mt_correct : + 0 < machine_wordsize -> + 2^(machine_wordsize - 1) < M < 2^machine_wordsize -> + Partition.partition (UniformWeight.uweight machine_wordsize) 1 M = [M]. + Proof. + intros; cbn. assert (0 < 2^(machine_wordsize-1)) by ZeroBounds.Z.zero_bounds. + rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia. + autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small. + reflexivity. + Qed. + (** Note: If you change the name or type signature of this function, you will need to update the code in CLI.v *) Definition check_args {T} (res : Pipeline.ErrorT T) @@ -78,19 +102,11 @@ Section rbarrett_red. (fun '(b, e) k => if b:bool then Error e else k) res [ - ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k); - ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0); - ((negb (0 0%nat - /\ Z.of_nat n <= k - /\ k = machine_wordsize - /\ n = 2%nat - /\ nout = 2%nat. + : 1 < machine_wordsize + /\ 2 ^ (machine_wordsize - 1) <= M < 2 ^ machine_wordsize + /\ muLow + 2 ^ machine_wordsize = (2 ^ 2) ^ machine_wordsize / M + /\ 2 ^ (machine_wordsize - 1) < M < 2 ^ machine_wordsize + /\ 2 * ((2 ^ 2) ^ machine_wordsize mod M) <= 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize). Proof using curve_good. clear -curve_good. cbv [check_args fold_right] in curve_good. break_innermost_match_hyps; try discriminate. rewrite Bool.negb_false_iff in *. Z.ltb_to_lt. - rewrite NPeano.Nat.eqb_neq in *. intros. repeat apply conj. { use_curve_good_t. } @@ -134,12 +143,6 @@ Section rbarrett_red. { use_curve_good_t. } { use_curve_good_t. } { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } Qed. Definition barrett_red @@ -148,7 +151,12 @@ Section rbarrett_red. fancy_args (* fancy *) possible_values (reified_barrett_red_gen - @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify M @ GallinaReify.Reify muLow @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat) + @ GallinaReify.Reify M + @ GallinaReify.Reify machine_wordsize + @ GallinaReify.Reify machine_wordsize + @ GallinaReify.Reify 1%nat + @ GallinaReify.Reify [muLow;1] + @ GallinaReify.Reify [M]) (bound, (bound, tt)) bound. @@ -162,24 +170,27 @@ Section rbarrett_red. Local Ltac solve_barrett_red_preconditions := repeat first [ lia | assumption + | match goal with |- ?x = ?x => reflexivity end | apply use_curve_good | progress autorewrite with zsimplify | progress intros | progress cbv [weight] + | rewrite mut_correct + | rewrite Mt_correct | rewrite Z.pow_mul_r by lia ]. Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *) Lemma barrett_red_correct res (Hres : barrett_red = Success res) - : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res). - Proof using k M curve_good. + : barrett_red_correct machine_wordsize M (expr.Interp (@ident.gen_interp cast_oor) res). + Proof using M curve_good. cbv [barrett_red_correct]; intros. - assert (2 <= k) by apply use_curve_good. - rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions. + assert (1 < machine_wordsize) by apply use_curve_good. + rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions. prove_correctness' ltac:(fun _ => idtac) use_curve_good. - { congruence. } - { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. - subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } - { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. - subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbn. econstructor. } + { cbn. econstructor. } + { cbn. econstructor. } Qed. End rbarrett_red. diff --git a/src/PushButtonSynthesis/BarrettReductionReificationCache.v b/src/PushButtonSynthesis/BarrettReductionReificationCache.v index 4c538087e..265ada2d2 100644 --- a/src/PushButtonSynthesis/BarrettReductionReificationCache.v +++ b/src/PushButtonSynthesis/BarrettReductionReificationCache.v @@ -14,15 +14,15 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU Module Export BarrettReduction. (* all the list operations from for_reification.ident *) Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ]. - Strategy -10 [barrett_reduce reduce]. + Strategy -10 [Fancy.fancy_reduce reduce]. Derive reified_barrett_red_gen - SuchThat (is_reification_of reified_barrett_red_gen barrett_reduce) + SuchThat (is_reification_of reified_barrett_red_gen Fancy.fancy_reduce) As reified_barrett_red_gen_correct. Proof. Time cache_reify (). Time Qed. Module Export ReifyHints. - Hint Extern 1 (_ = _) => apply_cached_reification barrett_reduce (proj1 reified_barrett_red_gen_correct) : reify_cache_gen. + Hint Extern 1 (_ = _) => apply_cached_reification Fancy.fancy_reduce (proj1 reified_barrett_red_gen_correct) : reify_cache_gen. Hint Immediate (proj2 reified_barrett_red_gen_correct) : wf_gen_cache. Hint Rewrite (proj1 reified_barrett_red_gen_correct) : interp_gen_cache. End ReifyHints. diff --git a/src/Rewriter.v b/src/Rewriter.v index 1aefc35e8..01d7614e3 100644 --- a/src/Rewriter.v +++ b/src/Rewriter.v @@ -2437,6 +2437,26 @@ Module Compilers. ; (forall c x, Z.abs c <= Z.abs max_const_val -> 'c * x = x * 'c) + + (* transform +- to + *) + ; (forall s y x, + Z.add_get_carry_full s x (- y) + = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb)) + ; (forall s y x, + Z.add_get_carry_full s (- y) x + = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb)) + ; (forall s y x, + Z.add_with_get_carry_full s 0 x (- y) + = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb)) + ; (forall s y x, + Z.add_with_get_carry_full s 0 (- y) x + = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb)) + ; (forall s c y x, + Z.add_with_get_carry_full s (- c) (- y) x + = dlet vb := Z.sub_with_get_borrow_full s c x y in (fst vb, - snd vb)) + ; (forall s c y x, + Z.add_with_get_carry_full s (- c) x (- y) + = dlet vb := Z.sub_with_get_borrow_full s c x y in (fst vb, - snd vb)) ] ; reify [ (* [do_again], so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *) -- cgit v1.2.3