diff options
author | Jason Gross <jagro@google.com> | 2018-06-29 17:16:15 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-07-03 19:28:55 -0400 |
commit | 5cc357b06d2e1abfb7fd5cfd492191d7e038c168 (patch) | |
tree | 20ede2ab07491068dd2b0beda67ac92b74ec4655 /src | |
parent | 87d8ee6e62e3b868597dee978cb762e9e370433a (diff) |
WIP better square
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 177 |
1 files changed, 165 insertions, 12 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 72d488d80..b2ec86a57 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -1,5 +1,8 @@ (* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *) +Require Import Crypto.Algebra.Nsatz. Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. +Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. +Require Import Coq.Sorting.Permutation. Require Import Coq.derive.Derive. Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. @@ -28,6 +31,7 @@ Require Import Crypto.Util.Tactics.SplitInContext. Require Import Crypto.Util.Tactics.SubstEvars. Require Import Crypto.Util.Notations. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Sorting. Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. @@ -35,6 +39,7 @@ Require Import Crypto.Util.ZUtil Crypto.Util.ZUtil.Hints.Core. Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. Require Import Crypto.Util.ZUtil.Hints.PullPush. Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.Prod. Require Import Crypto.Util.CPSNotations. Require Import Crypto.Util.Equality. Import ListNotations. Local Open Scope Z_scope. @@ -102,18 +107,36 @@ Module Associational. (* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *) trivial. Defined. + Lemma eval_partition f (p:list (Z*Z)) : + eval (snd (partition f p)) + eval (fst (partition f p)) = eval p. + Proof. induction p; cbn [partition]; eta_expand; break_match; cbn [fst snd]; push; nsatz. Qed. + Hint Rewrite eval_partition : push_eval. + + Lemma eval_partition' f (p:list (Z*Z)) : + eval (fst (partition f p)) + eval (snd (partition f p)) = eval p. + Proof. rewrite Z.add_comm, eval_partition; reflexivity. Qed. + Hint Rewrite eval_partition' : push_eval. + + Lemma eval_fst_partition f p : eval (fst (partition f p)) = eval p - eval (snd (partition f p)). + Proof. rewrite <- (eval_partition f p); nsatz. Qed. + Lemma eval_snd_partition f p : eval (snd (partition f p)) = eval p - eval (fst (partition f p)). + Proof. rewrite <- (eval_partition f p); nsatz. Qed. + Definition split (s:Z) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z) := let hi_lo := partition (fun t => fst t mod s =? 0) p in (snd hi_lo, map (fun t => (fst t / s, snd t)) (fst hi_lo)). - Lemma eval_split s p (s_nz:s<>0) : - eval (fst (split s p)) + s * eval (snd (split s p)) = eval p. - Proof. cbv [Let_In split]; induction p; + Lemma eval_snd_split s p (s_nz:s<>0) : + s * eval (snd (split s p)) = eval (fst (partition (fun t => fst t mod s =? 0) p)). + Proof. cbv [split Let_In]; induction p; repeat match goal with | |- context[?a/?b] => unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial)) | _ => progress push | _ => progress break_match | _ => progress nsatz end. Qed. + Lemma eval_split s p (s_nz:s<>0) : + eval (fst (split s p)) + s * eval (snd (split s p)) = eval p. + Proof. rewrite eval_snd_split, eval_fst_partition by assumption; cbv [split Let_In]; cbn [fst snd]; omega. Qed. Lemma reduction_rule a b s c (modulus_nz:s-c<>0) : (a + s * b) mod (s - c) = (a + c * b) mod (s - c). @@ -129,6 +152,142 @@ Module Associational. rewrite <-reduction_rule, eval_split; trivial. Qed. Hint Rewrite eval_reduce : push_eval. + Definition splitQ (s:Q) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z) + := let hi_lo := partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p in + (snd hi_lo, map (fun t => ((fst t * Zpos (Qden s)) / Qnum s, snd t)) (fst hi_lo)). + Lemma eval_snd_splitQ s p (s_nz:Qnum s<>0) : + Qnum s * eval (snd (splitQ s p)) = eval (fst (partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p)) * Zpos (Qden s). + Proof. + (* Work around https://github.com/mit-plv/fiat-crypto/issues/381 ([nsatz] can't handle [Zpos]) *) + cbv [splitQ Let_In]; cbn [fst snd]; zify; generalize dependent (Zpos (Qden s)); generalize dependent (Qnum s); clear s; intros. + induction p; + repeat match goal with + | |- context[?a/?b] => + unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial)) + | _ => progress push + | _ => progress break_match + | _ => progress nsatz end. Qed. + Lemma eval_splitQ s p (s_nz:Qnum s<>0) : + eval (fst (splitQ s p)) + (Qnum s * eval (snd (splitQ s p))) / Zpos (Qden s) = eval p. + Proof. rewrite eval_snd_splitQ, eval_fst_partition by assumption; cbv [splitQ Let_In]; cbn [fst snd]; Z.div_mod_to_quot_rem; nia. Qed. + Lemma eval_splitQ_mul s p (s_nz:Qnum s<>0) : + eval (fst (splitQ s p)) * Zpos (Qden s) + (Qnum s * eval (snd (splitQ s p))) = eval p * Zpos (Qden s). + Proof. rewrite eval_snd_splitQ, eval_fst_partition by assumption; cbv [splitQ Let_In]; cbn [fst snd]; nia. Qed. + + Lemma eval_rev p : eval (rev p) = eval p. + Proof. induction p; cbn [rev]; push; lia. Qed. + Hint Rewrite eval_rev : push_eval. + + Lemma eval_permutation (p q : list (Z * Z)) : Permutation p q -> eval p = eval q. + Proof. induction 1; push; nsatz. Qed. + + Module RevWeightOrder <: TotalLeBool. + Definition t := (Z * Z)%type. + Definition leb (x y : t) := Z.leb (fst y) (fst x). + Infix "<=?" := leb. + Local Coercion is_true : bool >-> Sortclass. + Theorem leb_total : forall a1 a2, a1 <=? a2 \/ a2 <=? a1. + Proof. + cbv [is_true leb]; intros x y; rewrite !Z.leb_le; pose proof (Z.le_ge_cases (fst x) (fst y)). + omega. + Qed. + Global Instance leb_Transitive : Transitive leb. + Proof. repeat intro; unfold is_true, leb in *; Z.ltb_to_lt; omega. Qed. + End RevWeightOrder. + + Module RevWeightSort := Mergesort.Sort RevWeightOrder. + + Lemma eval_sort p : eval (RevWeightSort.sort p) = eval p. + Proof. symmetry; apply eval_permutation, RevWeightSort.Permuted_sort. Qed. + Hint Rewrite eval_sort : push_eval. + + (* rough template (we actually have to do things a bit differently to account for duplicate weights): +[ dlet fi_c := c * fi in + let (fj_high, fj_low) := split fj at s/fi.weight in + dlet fi_2 := 2 * fi in + dlet fi_2_c := 2 * fi_c in + (if fi.weight^2 >= s then fi_c * fi else fi * fi) + ++ fi_2_c * fj_high + ++ fi_2 * fj_low + | fi <- f , fj := (f weight less than i) ] + *) + (** N.B. We take advantage of dead code elimination to allow us to + let-bind partial products that we don't end up using *) + Definition reduce_square (s:Z) (c:list (Z*Z)) (p:list (Z*Z)) : list (Z*Z) := + let two := [(1,2)] (* (weight, value) *) in + let p := RevWeightSort.sort p in (* now the highest weight limbs come first *) + list_rect + _ + nil + (fun t ts acc + => (let hi_low_ts := partition (fun t' => (fst t * fst t') mod s =? 0) ts in + if ((snd t) * (snd t) <=? s) (* reduce doesn't apply; no multiplication by [c] *) + then (dlet two_t2 := 2 * snd t in + (fst t * fst t, snd t * snd t) + :: (map (fun t' + => (fst t * fst t', two_t2 * snd t')) + ts)) + else (dlet c_t := mul [t] c in + dlet two_t := mul [t] two in + dlet two_c_t := mul c_t two in + (if ((snd t) * (snd t)) mod s =? 0 then [] else mul [t] [t]) + ++ mul two_t (snd hi_low_ts) + ++ (map + (fun t => (fst t / s, snd t)) + ((if ((snd t) * (snd t)) mod s =? 0 then mul c_t [t] else []) + ++ mul two_c_t (fst hi_low_ts))))) + ++ acc) + p. + Lemma eval_reduce_square s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0) + : eval (reduce_square s c p) mod (s - eval c) + = (eval p * eval p) mod (s - eval c). + Proof. + (*rewrite <- (eval_sort p). + cbv [reduce_square Let_In]; generalize (RevWeightSort.sort p) (RevWeightSort.StronglySorted_sort p _); clear p; intros p Hpsort. + induction Hpsort; cbn [list_rect]; push; [ nsatz | ]. + rewrite Z.mul_add_distr_r, !Z.mul_add_distr_l, !Z.add_assoc. + apply Z.add_mod_Proper; cbv [Z.equiv_modulo]; [ clear IHHpsort | exact IHHpsort ]. + break_innermost_match_step; Z.ltb_to_lt; push; [ f_equal; nsatz | ]. + + SearchAbout partition. + + rewrite <- !Z.add_assoc; apply Z.add_mod_Proper; cbv [Z.equiv_modulo]. + + { break_match; Z.ltb_to_lt; push; try (f_equal; nsatz). + + repeat match goal with H : context[Z.modulo] |- _ => revert H end. + Z.div_mod_to_quot_rem; nsatz. + break_match; push; try (f_equal; nsatz); Z.ltb_to_lt; autorewrite with zsimplify_const. + + + { + Focus 2. + rewrite ?(Z.mul_comm (fst _) (snd _)), <- !Z.mul_assoc. + Z.rewrite_mod_small. + push_Zmod; + rewrite <- Z.mul_mod_full; Z.rewrite_mod_small; pull_Zmod. + + pull_Zmod. + SearchAbout ((_ mod _) * (_ mod _)). + + SearchAbout + + cbn [mul map]. + 2:apply IHHpsort. + rewrite Z + break_innermost_match; Z.ltb_to_lt; push. + + cbn [filter]. + SearchAbout filter cons. + apply Z.add_mod_Proper; cbv [Z.equiv_modulo]. + Focus 2. + 2: + SearchAbout Proper Z.equiv_modulo. + SearchAbout ((_ + _) mod _). + eta_expand; push. + nsatz.*)Admitted. + Hint Rewrite eval_reduce_square : push_eval. + Definition bind_snd (p : list (Z*Z)) := map (fun t => dlet_nd t2 := snd t in (fst t, t2)) p. @@ -138,9 +297,6 @@ Module Associational. push; [|rewrite IHp]; reflexivity. Qed. - Lemma eval_rev p : eval (rev p) = eval p. - Proof. induction p; cbn [rev]; push; lia. Qed. - Section Carries. Definition carryterm (w fw:Z) (t:Z * Z) := if (Z.eqb (fst t) w) @@ -308,17 +464,14 @@ Module Positional. Section Positional. Definition squaremod (n:nat) (a:list Z) : list Z := let a_a := to_associational n a in - let aa_a := Associational.square a_a in - let aam_a := Associational.reduce s c aa_a in - from_associational n aam_a. + let aa_a := Associational.reduce_square s c a_a in + from_associational n aa_a. Lemma eval_squaremod n (f:list Z) (Hf : length f = n) : eval n (squaremod n f) mod (s - Associational.eval c) = (eval n f * eval n f) mod (s - Associational.eval c). Proof. cbv [squaremod]; push; trivial. - destruct f; simpl in *; [ right; subst n | left; try omega.. ]. - clear; cbv -[Associational.reduce]. - induction c as [|?? IHc]; simpl; trivial. Qed. + destruct f; simpl in *; [ right; subst n; reflexivity | left; try omega.. ]. Qed. End mulmod. Hint Rewrite @eval_mulmod @eval_squaremod : push_eval. |