aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@mit.edu>2017-06-14 14:51:32 -0400
committerGravatar Andres Erbsen <andreser@mit.edu>2017-06-14 14:51:32 -0400
commit0eb8eeff3ddab8d27ae87dfdcbbc3d15065d275b (patch)
tree966fa1411928fe502459bee200b8dac2ae6aead7
parentaf91e66e42f98c9fa09d27a42d4d27e9015de829 (diff)
fix goldilocks karatsuba; TODO implement reduce
-rw-r--r--src/Arithmetic/Karatsuba.v264
-rw-r--r--src/Specific/Karatsuba.v118
2 files changed, 120 insertions, 262 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v
index d53351934..3c3009fde 100644
--- a/src/Arithmetic/Karatsuba.v
+++ b/src/Arithmetic/Karatsuba.v
@@ -110,227 +110,46 @@ Context (weight : nat -> Z)
*)
+ (*
Definition goldilocks_mul_cps_for_bounds_checker
s (xs ys : T2) {R} (f:T2->R) :=
split_cps (m1:=n) (m2:=n) weight s xs
(fun x0_x1 => split_cps weight s ys
- (fun y0_y1 => mul_cps weight (fst x0_x1) (snd y0_y1)
- (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1)
- (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1)
- (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
- (fun z2 => add_cps weight z0 z2
- (fun sum_z => add_cps weight x0_y1 x1_y0
- (fun z1' => add_cps weight z1' z2
- (fun z1 => scmul_cps weight s z1
- (fun sz1 => add_cps weight sum_z sz1 f)))))))))).
- Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) :=
+ (fun z1 => Positional.to_associational_cps weight z1
+ (fun z1 => Associational.mul_cps (pair s 1::nil) z1
+ (fun sz1 => Positional.from_associational_cps weight n2 sz1
+ (fun sz1 => add_cps weight sum_z sz1 f)))))))))))).
+ *)
+
+ Let T3 := tuple Z (n2+n).
+ Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T3->R) :=
split_cps (m1:=n) (m2:=n) weight s xs
(fun x0_x1 => split_cps weight s ys
(fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
(fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
(fun z2 => add_cps weight z0 z2
- (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1)
+ (fun sum_z : tuple _ n2 => add_cps weight (fst x0_x1) (snd x0_x1)
(fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
(fun sum_y => mul_cps weight sum_x sum_y
- (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy z0
- (fun z1 => scmul_cps weight s z1
- (fun sz1 => add_cps weight sum_z sz1 f)))))))))).
-
-
- Lemma to_list_left_append {A N} t0 (t : tuple A N) :
- to_list (S N) (left_append t0 t) = (to_list N t ++ t0 :: nil)%list.
- Proof.
- induction N;
- repeat match goal with
- | _ => destruct x
- | _ => rewrite (subst_append (left_append t0 t));
- rewrite (subst_append t); rewrite !to_list_append;
- rewrite <-!subst_append
- | _ => progress (rewrite ?hd_left_append, ?tl_left_append)
- | _ => rewrite IHN
- | _ => reflexivity
- end.
- Qed.
+ (fun mul_sumxy =>
- Lemma seq_S_snoc len : forall start,
- List.seq start (S len) = (List.seq start len ++ (len + start)%nat :: nil)%list.
- Proof.
- induction len; intros; [reflexivity|].
- transitivity (start :: List.seq (S start) (S len))%list;
- [reflexivity|]. rewrite (IHlen (S start)).
- simpl List.seq; rewrite plus_Snm_nSm.
- apply List.app_comm_cons.
- Qed.
-
- Require Import Crypto.Util.ListUtil.
- Require Import Coq.Lists.List.
- Lemma repeat_left_append {A N} (a : A) :
- Tuple.repeat a (S N) = left_append a (Tuple.repeat a N).
- Admitted.
-
- Lemma from_to_associational_id wt N x :
- from_associational wt N (to_associational wt x) = x.
- Proof.
- cbv [from_associational to_associational from_associational_cps to_associational_cps].
- autorewrite with push_id uncps.
- induction N.
- { destruct x. reflexivity. }
- {
- rewrite (subst_left_append x).
- rewrite to_list_left_append.
- rewrite seq_S_snoc, plus_0_r.
- rewrite map_app, map_cons, map_nil.
- rewrite combine_app_samelength by distr_length.
- rewrite combine_cons, combine_nil_r.
- rewrite fold_right_app.
- Admitted.
-
- Local Infix "**" := Associational.mul (at level 40).
-
- Local Definition multerm terms :=
- Associational.multerm (fst terms) (snd terms).
-
- Lemma mul_power_equiv (p q : list limb) :
- Permutation.permutation
- (p ** q)
- (List.map multerm (list_prod p q)).
- Admitted.
-
- Lemma permutation_from_associational (p q : list limb) :
- Permutation.permutation p q -> forall wt N,
- from_associational wt N p = from_associational wt N q.
- Admitted.
-
- Lemma prod_append_binary_expansion {A : Type} {B : Set} (f:(A*A)->B)
- (ws xs ys zs : list A) :
- @Permutation.permutation B
- (map f (list_prod (ws ++ xs) (ys ++ zs)))
- (map f ((list_prod ws ys) ++ (list_prod ws zs) ++ (list_prod xs ys) ++ (list_prod xs zs))).
- Admitted.
-
- Lemma to_from_associational_append wt N p q :
- to_associational wt (from_associational wt N (p ++ q))
- = to_associational wt (from_associational wt N p) ++ to_associational wt (from_associational wt N q).
- Admitted.
-
- Lemma binary_expansion wt N a b c d :
- let to_from x := to_associational wt (from_associational wt N x) in
- (to_from ((a ++ b) ** (c ++ d)) = to_from (to_from (a ** c) ++ (to_from (to_from (a ** d) ++ (to_from (b ** c))) ++ to_from (b ** d))))%list.
- Proof.
- intro.
- pose proof (prod_append_binary_expansion multerm a b c d).
- pose proof (mul_power_equiv (a ++ b) (c ++ d)).
- let P := fresh "P" in
- remember (fun w z x y H => Permutation.permutation_app_comp _ w z (x ** y) (map multerm (list_prod x y)) H (mul_power_equiv _ _)) as P;
- pose proof (P _ _ b d (P _ _ b c (P _ _ a d (mul_power_equiv a c))));
- subst P.
- rewrite !map_app, !app_assoc_reverse in *.
- let H := fresh "H" in
- match goal with
- HA : Permutation.permutation ?x ?y,
- HB : Permutation.permutation ?z ?x,
- HC : Permutation.permutation ?w ?y |- _ =>
- assert (Permutation.permutation z w) as H by
- eauto using Permutation.permutation_sym, Permutation.permutation_trans;
- clear HA HB HC
- end; apply permutation_from_associational with (wt := wt) (N := N) in H.
- subst to_from. cbv beta.
- f_equal. etransitivity; [eassumption|].
- rewrite !to_from_associational_append.
- rewrite !from_to_associational_id.
- rewrite <-!to_from_associational_append.
- rewrite !from_to_associational_id.
- rewrite !app_assoc_reverse.
- reflexivity.
- Qed.
-
- Local Notation from := (from_associational weight).
- Local Notation to := (to_associational weight).
-
- Lemma subtraction_id N p q :
- from N ((p ++ Associational.negate_snd p) ++ q) = from N q.
- Admitted.
-
- Lemma goldilocks_mul_equiv' x0 x1 y0 y1 :
- let X0 := to (from n x0) in
- let X1 := to (from n x1) in
- let Y0 := to (from n y0) in
- let Y1 := to (from n y1) in
- from n2
- (to (from n2 (to (from n2 (X0 ** Y1)) ++ to (from n2 (X1 ** Y0)))) ++ to (from n2 (X1 ** Y1))) =
- from n2
- (to (from n2 (to (from n (X0 ++ X1)) ** to (from n (Y0 ++ Y1)))) ++ Associational.negate_snd (to (from n2 (X0 ** Y0)))).
- Proof.
- intros.
- repeat match goal with
- | _ => progress
- (rewrite !to_from_associational_append,
- !from_to_associational_id)
- | _ => progress
- (rewrite <-!to_from_associational_append,
- !from_to_associational_id)
- | _ => rewrite app_assoc_reverse
- | _ => rewrite binary_expansion
- | _ => subst X0 X1 Y0 Y1
- end.
- match goal with
- | |- _ = from ?n (?a ++ ?b ++ ?c ++ ?d ++ Associational.negate_snd ?a) =>
- transitivity (from n ((a ++ Associational.negate_snd a) ++ b ++ c ++ d));
- [|remember a as A; remember b as B; remember c as C; remember d as D; remember (Associational.negate_snd A) as negA]
+ dlet z1 := id_with_alt_bounds (unbalanced_sub_cps weight mul_sumxy z0 id) (
- end.
- Focus 2.
- { rewrite app_assoc_reverse.
- apply permutation_from_associational.
- replace (A ++ B ++ C ++ D ++ negA) with (A ++ (B ++ C ++ D) ++ negA).
- auto using app_assoc, app_assoc_reverse.
- rewrite !app_assoc_reverse; reflexivity. } Unfocus.
- rewrite subtraction_id.
- repeat match goal with
- | _ => progress
- (rewrite <-!to_from_associational_append,
- !from_to_associational_id)
- | _ => rewrite app_assoc_reverse
- end.
- reflexivity.
- Qed.
+ (mul_cps weight (fst x0_x1) (snd y0_y1)
+ (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1)
+ (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1)
+ (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
+ (fun z2 => add_cps weight z0 z2
+ (fun sum_z => add_cps weight x0_y1 x1_y0
+ (fun z1' => add_cps weight z1' z2 id)))))))) in
- Lemma goldilocks_mul_equiv s xs ys {R} f:
- @goldilocks_mul_cps s xs ys R f =
- @goldilocks_mul_cps_for_bounds_checker s xs ys R f.
- Proof.
- cbv [goldilocks_mul_cps_for_bounds_checker goldilocks_mul_cps].
- repeat autounfold.
- autorewrite with cancel_pair push_id uncps.
- apply f_equal.
- repeat match goal with
- |- context [Associational.mul ?x ?y] =>
- let m := fresh "m" in
- remember (Associational.mul x y) as m end.
- apply f_equal.
- apply f_equal.
- apply f_equal.
- apply f_equal.
- subst m m0 m1 m2.
- apply f_equal2; try reflexivity.
- apply f_equal.
- symmetry.
- apply goldilocks_mul_equiv'.
- Qed.
+ Positional.to_associational_cps weight z1
+ (fun z1 => Associational.mul_cps (pair s 1::nil) z1
+ (fun sz1 => Positional.to_associational_cps weight sum_z
+ (fun sum_z => Positional.from_associational_cps weight _ (sum_z++sz1) f
+ ))))))))))).
- Definition goldilocks_mul s xs ys :=
- id_with_alt_bounds
- (@goldilocks_mul_cps s xs ys _ id)
- (@goldilocks_mul_cps_for_bounds_checker s xs ys _ id).
- Lemma goldilocks_mul_id s xs ys {R} f :
- @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
- Proof.
- cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt.
- repeat autounfold.
- autorewrite with cancel_pair push_id uncps.
- reflexivity.
- Qed.
Local Existing Instances Z.equiv_modulo_Reflexive
RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
@@ -338,22 +157,35 @@ Context (weight : nat -> Z)
Z.modulo_equiv_modulo_Proper.
Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
- (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p.
+ (eval weight (goldilocks_mul_cps s xs ys id)) mod p = (eval weight xs * eval weight ys) mod p.
Proof.
- cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt.
+ cbv [goldilocks_mul_cps Let_In].
Zmod_to_equiv_modulo.
- repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval.
+ progress autounfold.
+ progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
+ rewrite !unfold_id_tuple_with_alt.
repeat match goal with
- | _ => rewrite <-eval_to_associational
- | |- context [(to_associational ?w ?x)] =>
- rewrite <-(Associational.eval_split
- s (to_associational w x)) by assumption
- | _ => rewrite <-Associational.eval_split by assumption
- | _ => setoid_rewrite Associational.eval_nil
- end.
-
+ | _ => rewrite <-eval_to_associational
+ | |- context [(to_associational ?w ?x)] =>
+ rewrite <-(Associational.eval_split
+ s (to_associational w x)) by assumption
+ | _ => rewrite <-Associational.eval_split by assumption
+ | _ => setoid_rewrite Associational.eval_nil
+ end.
+ progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
+ repeat (rewrite ?eval_from_associational, ?eval_to_associational).
+ progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
+ repeat match goal with
+ | _ => rewrite <-eval_to_associational
+ | |- context [(to_associational ?w ?x)] =>
+ rewrite <-(Associational.eval_split
+ s (to_associational w x)) by assumption
+ | _ => rewrite <-Associational.eval_split by assumption
+ | _ => setoid_rewrite Associational.eval_nil
+ end.
ring_simplify.
setoid_rewrite s2_modp.
apply f_equal2; nsatz.
+ assumption. assumption. omega.
Qed.
End Karatsuba.
diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v
index 39f76250c..ce8bb86fa 100644
--- a/src/Specific/Karatsuba.v
+++ b/src/Specific/Karatsuba.v
@@ -153,8 +153,69 @@ Section Ops51.
Definition half_sz : nat := Eval compute in (sz / 2).
Lemma half_sz_nonzero : half_sz <> 0%nat. Proof. cbv; congruence. Qed.
+Ltac basesystem_partial_evaluation_RHS :=
+ let t0 := (match goal with
+ | |- _ _ ?t => t
+ end) in
+ let t :=
+ eval
+ cbv
+ delta [Positional.to_associational_cps Positional.to_associational
+ Positional.eval Positional.zeros Positional.add_to_nth_cps
+ Positional.add_to_nth Positional.place_cps Positional.place
+ Positional.from_associational_cps Positional.from_associational
+ Positional.carry_cps Positional.carry
+ Positional.chained_carries_cps Positional.chained_carries
+ Positional.sub_cps Positional.sub Positional.split_cps
+ Positional.scmul_cps Positional.unbalanced_sub_cps
+ Positional.negate_snd_cps Positional.add_cps Positional.opp_cps
+ Associational.eval Associational.multerm Associational.mul_cps
+ Associational.mul Associational.split_cps Associational.split
+ Associational.reduce_cps Associational.reduce
+ Associational.carryterm_cps Associational.carryterm
+ Associational.carry_cps Associational.carry
+ Associational.negate_snd_cps Associational.negate_snd div modulo
+ id_tuple_with_alt id_tuple'_with_alt
+ ]
+ in t0
+ in
+ let t := eval pattern @runtime_mul in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @runtime_add in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @runtime_opp in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @runtime_shr in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @runtime_and in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @Let_In in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t := eval pattern @id_with_alt in t in
+ let t := (match t with
+ | ?t _ => t
+ end) in
+ let t1 := fresh "t1" in
+ pose (t1 := t);
+ transitivity
+ (t1 (@id_with_alt) (@Let_In) (@runtime_and) (@runtime_shr) (@runtime_opp) (@runtime_add)
+ (@runtime_mul));
+ [ replace_with_vm_compute t1; clear t1 | reflexivity ].
+ Print id_tuple_with_alt.
Definition goldilocks_mul_sig :
- {mul : (Z^sz -> Z^sz -> Z^sz)%type |
+ {mul : (Z^sz -> Z^sz -> Z^(sz+half_sz))%type |
forall a b : Z^sz,
mul a b = goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt (2 ^ 224) a b id}.
Proof.
@@ -166,39 +227,16 @@ Section Ops51.
reflexivity.
Defined.
- Definition goldilocks_mul_for_bounds_checker_sig :
- {mul : (Z^sz -> Z^sz -> Z^sz)%type |
- forall a b : Z^sz,
- mul a b = goldilocks_mul_cps_for_bounds_checker (n:=half_sz) (n2:=sz) wt (2 ^ 224) a b id}.
- Proof.
- eexists; cbv beta zeta; intros.
- cbv [goldilocks_mul_cps_for_bounds_checker].
- repeat autounfold.
- basesystem_partial_evaluation_RHS.
- do_replace_match_with_destructuring_match_in_goal.
- reflexivity.
- Defined.
-
- Lemma goldilocks_mul_sig_equiv a b :
- proj1_sig goldilocks_mul_sig a b =
- proj1_sig goldilocks_mul_for_bounds_checker_sig a b.
- Proof.
- rewrite (proj2_sig goldilocks_mul_sig).
- rewrite (proj2_sig goldilocks_mul_for_bounds_checker_sig).
- apply goldilocks_mul_equiv;
- auto using half_sz_nonzero, sz_nonzero, wt_nonzero.
- Qed.
-
Definition mul_sig :
- {mul : (Z^sz -> Z^sz -> Z^sz)%type |
+ {mul : (Z^sz -> Z^sz -> Z^(sz+half_sz))%type |
forall a b : Z^sz,
let eval := Positional.Fdecode (m := m) wt in
- eval (mul a b) = (eval a * eval b)%F}.
+ Positional.Fdecode (m := m) wt (mul a b) = (eval a * eval b)%F}.
Proof.
eexists; cbv beta zeta; intros.
pose proof wt_nonzero.
let x := constr:(
- goldilocks_mul (n:=half_sz) (n2:=sz) wt (2^224) a b ) in
+ goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt (2^224) a b id) in
F_mod_eq;
transitivity (Positional.eval wt x); repeat autounfold;
@@ -207,29 +245,16 @@ Section Ops51.
apply goldilocks_mul_correct; try assumption; cbv; congruence ].
cbv [mod_eq]; apply f_equal2;
[ | reflexivity ]; apply f_equal.
- cbv [goldilocks_mul].
- transitivity
- (Tuple.eta_tuple
- (fun a
- => Tuple.eta_tuple
- (fun b
- => id_tuple_with_alt
- ((proj1_sig goldilocks_mul_sig) a b)
- ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b))
- b)
- a).
- { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig Tuple.eta_tuple Tuple.eta_tuple_dep sz Tuple.eta_tuple'_dep id_tuple_with_alt id_tuple'_with_alt];
- cbn [fst snd].
- reflexivity. }
- { rewrite !Tuple.strip_eta_tuple, !unfold_id_tuple_with_alt.
- rewrite (proj2_sig goldilocks_mul_sig). reflexivity. }
+ etransitivity;[|apply (proj2_sig (goldilocks_mul_sig))].
+ cbv [proj1_sig goldilocks_mul_sig].
+ reflexivity.
Defined.
Definition square_sig :
- {square : (Z^sz -> Z^sz)%type |
+ {square : (Z^sz -> Z^(sz+half_sz))%type |
forall a : Z^sz,
let eval := Positional.Fdecode (m := m) wt in
- eval (square a) = (eval a * eval a)%F}.
+ Positional.Fdecode (m := m) wt (square a) = (eval a * eval a)%F}.
Proof.
eexists; cbv beta zeta; intros.
rewrite <-(proj2_sig mul_sig).
@@ -306,6 +331,7 @@ Section Ops51.
reflexivity.
Defined.
+ (* TODO: implement reduce, reduce after mul and square
Definition ring_56 :=
(Ring.ring_by_isomorphism
(F := F m)
@@ -329,7 +355,7 @@ Section Ops51.
(proj2_sig add_sig)
(proj2_sig sub_sig)
(proj2_sig mul_sig)
- ).
+ ). *)
(*
Eval cbv [proj1_sig add_sig] in (proj1_sig add_sig).