aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2018-06-29 19:30:37 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-07-03 19:28:55 -0400
commit63e0712c217944df2a0a06c439c834e1b0a7295e (patch)
treed3db3ddb08201abbc8eb01fba4e9747cb57dcc61 /src
parent5cc357b06d2e1abfb7fd5cfd492191d7e038c168 (diff)
Try a different reduce_square
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/NewPipeline/Arithmetic.v49
1 files changed, 26 insertions, 23 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v
index b2ec86a57..e6b25efaa 100644
--- a/src/Experiments/NewPipeline/Arithmetic.v
+++ b/src/Experiments/NewPipeline/Arithmetic.v
@@ -152,6 +152,7 @@ Module Associational.
rewrite <-reduction_rule, eval_split; trivial. Qed.
Hint Rewrite eval_reduce : push_eval.
+ (*
Definition splitQ (s:Q) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
:= let hi_lo := partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p in
(snd hi_lo, map (fun t => ((fst t * Zpos (Qden s)) / Qnum s, snd t)) (fst hi_lo)).
@@ -173,11 +174,11 @@ Module Associational.
Lemma eval_splitQ_mul s p (s_nz:Qnum s<>0) :
eval (fst (splitQ s p)) * Zpos (Qden s) + (Qnum s * eval (snd (splitQ s p))) = eval p * Zpos (Qden s).
Proof. rewrite eval_snd_splitQ, eval_fst_partition by assumption; cbv [splitQ Let_In]; cbn [fst snd]; nia. Qed.
-
+ *)
Lemma eval_rev p : eval (rev p) = eval p.
Proof. induction p; cbn [rev]; push; lia. Qed.
Hint Rewrite eval_rev : push_eval.
-
+ (*
Lemma eval_permutation (p q : list (Z * Z)) : Permutation p q -> eval p = eval q.
Proof. induction 1; push; nsatz. Qed.
@@ -200,7 +201,7 @@ Module Associational.
Lemma eval_sort p : eval (RevWeightSort.sort p) = eval p.
Proof. symmetry; apply eval_permutation, RevWeightSort.Permuted_sort. Qed.
Hint Rewrite eval_sort : push_eval.
-
+ *)
(* rough template (we actually have to do things a bit differently to account for duplicate weights):
[ dlet fi_c := c * fi in
let (fj_high, fj_low) := split fj at s/fi.weight in
@@ -213,30 +214,32 @@ Module Associational.
*)
(** N.B. We take advantage of dead code elimination to allow us to
let-bind partial products that we don't end up using *)
- Definition reduce_square (s:Z) (c:list (Z*Z)) (p:list (Z*Z)) : list (Z*Z) :=
+ (** [v] -> [(v, v*c, v*c*2, v*2)] *)
+ Definition let_bind_for_reduce_square (c:list (Z*Z)) (p:list (Z*Z)) : list ((Z*Z) * list(Z*Z) * list(Z*Z) * list(Z*Z)) :=
let two := [(1,2)] (* (weight, value) *) in
- let p := RevWeightSort.sort p in (* now the highest weight limbs come first *)
+ map (fun t => dlet c_t := mul [t] c in dlet two_c_t := mul c_t two in dlet two_t := mul [t] two in (t, c_t, two_c_t, two_t)) p.
+ Definition reduce_square (s:Z) (c:list (Z*Z)) (p:list (Z*Z)) : list (Z*Z) :=
+ let p := let_bind_for_reduce_square c p in
+ let div_s := map (fun t => (fst t / s, snd t)) in
list_rect
_
nil
- (fun t ts acc
- => (let hi_low_ts := partition (fun t' => (fst t * fst t') mod s =? 0) ts in
- if ((snd t) * (snd t) <=? s) (* reduce doesn't apply; no multiplication by [c] *)
- then (dlet two_t2 := 2 * snd t in
- (fst t * fst t, snd t * snd t)
- :: (map (fun t'
- => (fst t * fst t', two_t2 * snd t'))
- ts))
- else (dlet c_t := mul [t] c in
- dlet two_t := mul [t] two in
- dlet two_c_t := mul c_t two in
- (if ((snd t) * (snd t)) mod s =? 0 then [] else mul [t] [t])
- ++ mul two_t (snd hi_low_ts)
- ++ (map
- (fun t => (fst t / s, snd t))
- ((if ((snd t) * (snd t)) mod s =? 0 then mul c_t [t] else [])
- ++ mul two_c_t (fst hi_low_ts)))))
- ++ acc)
+ (fun '(t, c_t, two_c_t, two_t) ts acc
+ => (if (fst t mod s =? 0)
+ then div_s (mul c_t [t])
+ else mul [t] [t])
+ ++ (flat_map
+ (fun '(t', c_t', two_c_t', two_t')
+ => if ((fst t * fst t') mod s =? 0)
+ then div_s
+ (if fst t' <=? fst t
+ then mul two_c_t [t']
+ else mul [t] two_c_t)
+ else (if fst t' <=? fst t
+ then mul two_t [t']
+ else mul [t] two_t))
+ ts)
+ ++ acc)
p.
Lemma eval_reduce_square s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0)
: eval (reduce_square s c p) mod (s - eval c)