aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2018-06-29 17:16:15 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-07-03 19:28:55 -0400
commit5cc357b06d2e1abfb7fd5cfd492191d7e038c168 (patch)
tree20ede2ab07491068dd2b0beda67ac92b74ec4655 /src
parent87d8ee6e62e3b868597dee978cb762e9e370433a (diff)
WIP better square
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v177
1 files changed, 165 insertions, 12 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index 72d488d80..b2ec86a57 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -1,5 +1,8 @@
(* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *)
+Require Import Crypto.Algebra.Nsatz.
Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz.
+Require Import Coq.Sorting.Mergesort Coq.Structures.Orders.
+Require Import Coq.Sorting.Permutation.
Require Import Coq.derive.Derive.
Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable.
Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn.
@@ -28,6 +31,7 @@ Require Import Crypto.Util.Tactics.SplitInContext.
Require Import Crypto.Util.Tactics.SubstEvars.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.Sorting.
Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi.
Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo.
Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit.
@@ -35,6 +39,7 @@ Require Import Crypto.Util.ZUtil Crypto.Util.ZUtil.Hints.Core.
Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div.
Require Import Crypto.Util.ZUtil.Hints.PullPush.
Require Import Crypto.Util.ZUtil.EquivModulo.
+Require Import Crypto.Util.Prod.
Require Import Crypto.Util.CPSNotations.
Require Import Crypto.Util.Equality.
Import ListNotations. Local Open Scope Z_scope.
@@ -102,18 +107,36 @@ Module Associational.
(* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *)
trivial. Defined.
+ Lemma eval_partition f (p:list (Z*Z)) :
+ eval (snd (partition f p)) + eval (fst (partition f p)) = eval p.
+ Proof. induction p; cbn [partition]; eta_expand; break_match; cbn [fst snd]; push; nsatz. Qed.
+ Hint Rewrite eval_partition : push_eval.
+
+ Lemma eval_partition' f (p:list (Z*Z)) :
+ eval (fst (partition f p)) + eval (snd (partition f p)) = eval p.
+ Proof. rewrite Z.add_comm, eval_partition; reflexivity. Qed.
+ Hint Rewrite eval_partition' : push_eval.
+
+ Lemma eval_fst_partition f p : eval (fst (partition f p)) = eval p - eval (snd (partition f p)).
+ Proof. rewrite <- (eval_partition f p); nsatz. Qed.
+ Lemma eval_snd_partition f p : eval (snd (partition f p)) = eval p - eval (fst (partition f p)).
+ Proof. rewrite <- (eval_partition f p); nsatz. Qed.
+
Definition split (s:Z) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
:= let hi_lo := partition (fun t => fst t mod s =? 0) p in
(snd hi_lo, map (fun t => (fst t / s, snd t)) (fst hi_lo)).
- Lemma eval_split s p (s_nz:s<>0) :
- eval (fst (split s p)) + s * eval (snd (split s p)) = eval p.
- Proof. cbv [Let_In split]; induction p;
+ Lemma eval_snd_split s p (s_nz:s<>0) :
+ s * eval (snd (split s p)) = eval (fst (partition (fun t => fst t mod s =? 0) p)).
+ Proof. cbv [split Let_In]; induction p;
repeat match goal with
| |- context[?a/?b] =>
unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial))
| _ => progress push
| _ => progress break_match
| _ => progress nsatz end. Qed.
+ Lemma eval_split s p (s_nz:s<>0) :
+ eval (fst (split s p)) + s * eval (snd (split s p)) = eval p.
+ Proof. rewrite eval_snd_split, eval_fst_partition by assumption; cbv [split Let_In]; cbn [fst snd]; omega. Qed.
Lemma reduction_rule a b s c (modulus_nz:s-c<>0) :
(a + s * b) mod (s - c) = (a + c * b) mod (s - c).
@@ -129,6 +152,142 @@ Module Associational.
rewrite <-reduction_rule, eval_split; trivial. Qed.
Hint Rewrite eval_reduce : push_eval.
+ Definition splitQ (s:Q) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
+ := let hi_lo := partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p in
+ (snd hi_lo, map (fun t => ((fst t * Zpos (Qden s)) / Qnum s, snd t)) (fst hi_lo)).
+ Lemma eval_snd_splitQ s p (s_nz:Qnum s<>0) :
+ Qnum s * eval (snd (splitQ s p)) = eval (fst (partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p)) * Zpos (Qden s).
+ Proof.
+ (* Work around https://github.com/mit-plv/fiat-crypto/issues/381 ([nsatz] can't handle [Zpos]) *)
+ cbv [splitQ Let_In]; cbn [fst snd]; zify; generalize dependent (Zpos (Qden s)); generalize dependent (Qnum s); clear s; intros.
+ induction p;
+ repeat match goal with
+ | |- context[?a/?b] =>
+ unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial))
+ | _ => progress push
+ | _ => progress break_match
+ | _ => progress nsatz end. Qed.
+ Lemma eval_splitQ s p (s_nz:Qnum s<>0) :
+ eval (fst (splitQ s p)) + (Qnum s * eval (snd (splitQ s p))) / Zpos (Qden s) = eval p.
+ Proof. rewrite eval_snd_splitQ, eval_fst_partition by assumption; cbv [splitQ Let_In]; cbn [fst snd]; Z.div_mod_to_quot_rem; nia. Qed.
+ Lemma eval_splitQ_mul s p (s_nz:Qnum s<>0) :
+ eval (fst (splitQ s p)) * Zpos (Qden s) + (Qnum s * eval (snd (splitQ s p))) = eval p * Zpos (Qden s).
+ Proof. rewrite eval_snd_splitQ, eval_fst_partition by assumption; cbv [splitQ Let_In]; cbn [fst snd]; nia. Qed.
+
+ Lemma eval_rev p : eval (rev p) = eval p.
+ Proof. induction p; cbn [rev]; push; lia. Qed.
+ Hint Rewrite eval_rev : push_eval.
+
+ Lemma eval_permutation (p q : list (Z * Z)) : Permutation p q -> eval p = eval q.
+ Proof. induction 1; push; nsatz. Qed.
+
+ Module RevWeightOrder <: TotalLeBool.
+ Definition t := (Z * Z)%type.
+ Definition leb (x y : t) := Z.leb (fst y) (fst x).
+ Infix "<=?" := leb.
+ Local Coercion is_true : bool >-> Sortclass.
+ Theorem leb_total : forall a1 a2, a1 <=? a2 \/ a2 <=? a1.
+ Proof.
+ cbv [is_true leb]; intros x y; rewrite !Z.leb_le; pose proof (Z.le_ge_cases (fst x) (fst y)).
+ omega.
+ Qed.
+ Global Instance leb_Transitive : Transitive leb.
+ Proof. repeat intro; unfold is_true, leb in *; Z.ltb_to_lt; omega. Qed.
+ End RevWeightOrder.
+
+ Module RevWeightSort := Mergesort.Sort RevWeightOrder.
+
+ Lemma eval_sort p : eval (RevWeightSort.sort p) = eval p.
+ Proof. symmetry; apply eval_permutation, RevWeightSort.Permuted_sort. Qed.
+ Hint Rewrite eval_sort : push_eval.
+
+ (* rough template (we actually have to do things a bit differently to account for duplicate weights):
+[ dlet fi_c := c * fi in
+ let (fj_high, fj_low) := split fj at s/fi.weight in
+ dlet fi_2 := 2 * fi in
+ dlet fi_2_c := 2 * fi_c in
+ (if fi.weight^2 >= s then fi_c * fi else fi * fi)
+ ++ fi_2_c * fj_high
+ ++ fi_2 * fj_low
+ | fi <- f , fj := (f weight less than i) ]
+ *)
+ (** N.B. We take advantage of dead code elimination to allow us to
+ let-bind partial products that we don't end up using *)
+ Definition reduce_square (s:Z) (c:list (Z*Z)) (p:list (Z*Z)) : list (Z*Z) :=
+ let two := [(1,2)] (* (weight, value) *) in
+ let p := RevWeightSort.sort p in (* now the highest weight limbs come first *)
+ list_rect
+ _
+ nil
+ (fun t ts acc
+ => (let hi_low_ts := partition (fun t' => (fst t * fst t') mod s =? 0) ts in
+ if ((snd t) * (snd t) <=? s) (* reduce doesn't apply; no multiplication by [c] *)
+ then (dlet two_t2 := 2 * snd t in
+ (fst t * fst t, snd t * snd t)
+ :: (map (fun t'
+ => (fst t * fst t', two_t2 * snd t'))
+ ts))
+ else (dlet c_t := mul [t] c in
+ dlet two_t := mul [t] two in
+ dlet two_c_t := mul c_t two in
+ (if ((snd t) * (snd t)) mod s =? 0 then [] else mul [t] [t])
+ ++ mul two_t (snd hi_low_ts)
+ ++ (map
+ (fun t => (fst t / s, snd t))
+ ((if ((snd t) * (snd t)) mod s =? 0 then mul c_t [t] else [])
+ ++ mul two_c_t (fst hi_low_ts)))))
+ ++ acc)
+ p.
+ Lemma eval_reduce_square s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0)
+ : eval (reduce_square s c p) mod (s - eval c)
+ = (eval p * eval p) mod (s - eval c).
+ Proof.
+ (*rewrite <- (eval_sort p).
+ cbv [reduce_square Let_In]; generalize (RevWeightSort.sort p) (RevWeightSort.StronglySorted_sort p _); clear p; intros p Hpsort.
+ induction Hpsort; cbn [list_rect]; push; [ nsatz | ].
+ rewrite Z.mul_add_distr_r, !Z.mul_add_distr_l, !Z.add_assoc.
+ apply Z.add_mod_Proper; cbv [Z.equiv_modulo]; [ clear IHHpsort | exact IHHpsort ].
+ break_innermost_match_step; Z.ltb_to_lt; push; [ f_equal; nsatz | ].
+
+ SearchAbout partition.
+
+ rewrite <- !Z.add_assoc; apply Z.add_mod_Proper; cbv [Z.equiv_modulo].
+
+ { break_match; Z.ltb_to_lt; push; try (f_equal; nsatz).
+
+ repeat match goal with H : context[Z.modulo] |- _ => revert H end.
+ Z.div_mod_to_quot_rem; nsatz.
+ break_match; push; try (f_equal; nsatz); Z.ltb_to_lt; autorewrite with zsimplify_const.
+
+
+ {
+ Focus 2.
+ rewrite ?(Z.mul_comm (fst _) (snd _)), <- !Z.mul_assoc.
+ Z.rewrite_mod_small.
+ push_Zmod;
+ rewrite <- Z.mul_mod_full; Z.rewrite_mod_small; pull_Zmod.
+
+ pull_Zmod.
+ SearchAbout ((_ mod _) * (_ mod _)).
+
+ SearchAbout
+
+ cbn [mul map].
+ 2:apply IHHpsort.
+ rewrite Z
+ break_innermost_match; Z.ltb_to_lt; push.
+
+ cbn [filter].
+ SearchAbout filter cons.
+ apply Z.add_mod_Proper; cbv [Z.equiv_modulo].
+ Focus 2.
+ 2:
+ SearchAbout Proper Z.equiv_modulo.
+ SearchAbout ((_ + _) mod _).
+ eta_expand; push.
+ nsatz.*)Admitted.
+ Hint Rewrite eval_reduce_square : push_eval.
+
Definition bind_snd (p : list (Z*Z)) :=
map (fun t => dlet_nd t2 := snd t in (fst t, t2)) p.
@@ -138,9 +297,6 @@ Module Associational.
push; [|rewrite IHp]; reflexivity.
Qed.
- Lemma eval_rev p : eval (rev p) = eval p.
- Proof. induction p; cbn [rev]; push; lia. Qed.
-
Section Carries.
Definition carryterm (w fw:Z) (t:Z * Z) :=
if (Z.eqb (fst t) w)
@@ -308,17 +464,14 @@ Module Positional. Section Positional.
Definition squaremod (n:nat) (a:list Z) : list Z
:= let a_a := to_associational n a in
- let aa_a := Associational.square a_a in
- let aam_a := Associational.reduce s c aa_a in
- from_associational n aam_a.
+ let aa_a := Associational.reduce_square s c a_a in
+ from_associational n aa_a.
Lemma eval_squaremod n (f:list Z)
(Hf : length f = n) :
eval n (squaremod n f) mod (s - Associational.eval c)
= (eval n f * eval n f) mod (s - Associational.eval c).
Proof. cbv [squaremod]; push; trivial.
- destruct f; simpl in *; [ right; subst n | left; try omega.. ].
- clear; cbv -[Associational.reduce].
- induction c as [|?? IHc]; simpl; trivial. Qed.
+ destruct f; simpl in *; [ right; subst n; reflexivity | left; try omega.. ]. Qed.
End mulmod.
Hint Rewrite @eval_mulmod @eval_squaremod : push_eval.