diff options
author | Andres Erbsen <andreser@mit.edu> | 2017-06-14 14:51:32 -0400 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2017-06-14 14:51:32 -0400 |
commit | 0eb8eeff3ddab8d27ae87dfdcbbc3d15065d275b (patch) | |
tree | 966fa1411928fe502459bee200b8dac2ae6aead7 | |
parent | af91e66e42f98c9fa09d27a42d4d27e9015de829 (diff) |
fix goldilocks karatsuba; TODO implement reduce
-rw-r--r-- | src/Arithmetic/Karatsuba.v | 264 | ||||
-rw-r--r-- | src/Specific/Karatsuba.v | 118 |
2 files changed, 120 insertions, 262 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v index d53351934..3c3009fde 100644 --- a/src/Arithmetic/Karatsuba.v +++ b/src/Arithmetic/Karatsuba.v @@ -110,227 +110,46 @@ Context (weight : nat -> Z) *) + (* Definition goldilocks_mul_cps_for_bounds_checker s (xs ys : T2) {R} (f:T2->R) := split_cps (m1:=n) (m2:=n) weight s xs (fun x0_x1 => split_cps weight s ys - (fun y0_y1 => mul_cps weight (fst x0_x1) (snd y0_y1) - (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1) - (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1) - (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) - (fun z2 => add_cps weight z0 z2 - (fun sum_z => add_cps weight x0_y1 x1_y0 - (fun z1' => add_cps weight z1' z2 - (fun z1 => scmul_cps weight s z1 - (fun sz1 => add_cps weight sum_z sz1 f)))))))))). - Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) := + (fun z1 => Positional.to_associational_cps weight z1 + (fun z1 => Associational.mul_cps (pair s 1::nil) z1 + (fun sz1 => Positional.from_associational_cps weight n2 sz1 + (fun sz1 => add_cps weight sum_z sz1 f)))))))))))). + *) + + Let T3 := tuple Z (n2+n). + Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T3->R) := split_cps (m1:=n) (m2:=n) weight s xs (fun x0_x1 => split_cps weight s ys (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1) (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) (fun z2 => add_cps weight z0 z2 - (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1) + (fun sum_z : tuple _ n2 => add_cps weight (fst x0_x1) (snd x0_x1) (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1) (fun sum_y => mul_cps weight sum_x sum_y - (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy z0 - (fun z1 => scmul_cps weight s z1 - (fun sz1 => add_cps weight sum_z sz1 f)))))))))). - - - Lemma to_list_left_append {A N} t0 (t : tuple A N) : - to_list (S N) (left_append t0 t) = (to_list N t ++ t0 :: nil)%list. - Proof. - induction N; - repeat match goal with - | _ => destruct x - | _ => rewrite (subst_append (left_append t0 t)); - rewrite (subst_append t); rewrite !to_list_append; - rewrite <-!subst_append - | _ => progress (rewrite ?hd_left_append, ?tl_left_append) - | _ => rewrite IHN - | _ => reflexivity - end. - Qed. + (fun mul_sumxy => - Lemma seq_S_snoc len : forall start, - List.seq start (S len) = (List.seq start len ++ (len + start)%nat :: nil)%list. - Proof. - induction len; intros; [reflexivity|]. - transitivity (start :: List.seq (S start) (S len))%list; - [reflexivity|]. rewrite (IHlen (S start)). - simpl List.seq; rewrite plus_Snm_nSm. - apply List.app_comm_cons. - Qed. - - Require Import Crypto.Util.ListUtil. - Require Import Coq.Lists.List. - Lemma repeat_left_append {A N} (a : A) : - Tuple.repeat a (S N) = left_append a (Tuple.repeat a N). - Admitted. - - Lemma from_to_associational_id wt N x : - from_associational wt N (to_associational wt x) = x. - Proof. - cbv [from_associational to_associational from_associational_cps to_associational_cps]. - autorewrite with push_id uncps. - induction N. - { destruct x. reflexivity. } - { - rewrite (subst_left_append x). - rewrite to_list_left_append. - rewrite seq_S_snoc, plus_0_r. - rewrite map_app, map_cons, map_nil. - rewrite combine_app_samelength by distr_length. - rewrite combine_cons, combine_nil_r. - rewrite fold_right_app. - Admitted. - - Local Infix "**" := Associational.mul (at level 40). - - Local Definition multerm terms := - Associational.multerm (fst terms) (snd terms). - - Lemma mul_power_equiv (p q : list limb) : - Permutation.permutation - (p ** q) - (List.map multerm (list_prod p q)). - Admitted. - - Lemma permutation_from_associational (p q : list limb) : - Permutation.permutation p q -> forall wt N, - from_associational wt N p = from_associational wt N q. - Admitted. - - Lemma prod_append_binary_expansion {A : Type} {B : Set} (f:(A*A)->B) - (ws xs ys zs : list A) : - @Permutation.permutation B - (map f (list_prod (ws ++ xs) (ys ++ zs))) - (map f ((list_prod ws ys) ++ (list_prod ws zs) ++ (list_prod xs ys) ++ (list_prod xs zs))). - Admitted. - - Lemma to_from_associational_append wt N p q : - to_associational wt (from_associational wt N (p ++ q)) - = to_associational wt (from_associational wt N p) ++ to_associational wt (from_associational wt N q). - Admitted. - - Lemma binary_expansion wt N a b c d : - let to_from x := to_associational wt (from_associational wt N x) in - (to_from ((a ++ b) ** (c ++ d)) = to_from (to_from (a ** c) ++ (to_from (to_from (a ** d) ++ (to_from (b ** c))) ++ to_from (b ** d))))%list. - Proof. - intro. - pose proof (prod_append_binary_expansion multerm a b c d). - pose proof (mul_power_equiv (a ++ b) (c ++ d)). - let P := fresh "P" in - remember (fun w z x y H => Permutation.permutation_app_comp _ w z (x ** y) (map multerm (list_prod x y)) H (mul_power_equiv _ _)) as P; - pose proof (P _ _ b d (P _ _ b c (P _ _ a d (mul_power_equiv a c)))); - subst P. - rewrite !map_app, !app_assoc_reverse in *. - let H := fresh "H" in - match goal with - HA : Permutation.permutation ?x ?y, - HB : Permutation.permutation ?z ?x, - HC : Permutation.permutation ?w ?y |- _ => - assert (Permutation.permutation z w) as H by - eauto using Permutation.permutation_sym, Permutation.permutation_trans; - clear HA HB HC - end; apply permutation_from_associational with (wt := wt) (N := N) in H. - subst to_from. cbv beta. - f_equal. etransitivity; [eassumption|]. - rewrite !to_from_associational_append. - rewrite !from_to_associational_id. - rewrite <-!to_from_associational_append. - rewrite !from_to_associational_id. - rewrite !app_assoc_reverse. - reflexivity. - Qed. - - Local Notation from := (from_associational weight). - Local Notation to := (to_associational weight). - - Lemma subtraction_id N p q : - from N ((p ++ Associational.negate_snd p) ++ q) = from N q. - Admitted. - - Lemma goldilocks_mul_equiv' x0 x1 y0 y1 : - let X0 := to (from n x0) in - let X1 := to (from n x1) in - let Y0 := to (from n y0) in - let Y1 := to (from n y1) in - from n2 - (to (from n2 (to (from n2 (X0 ** Y1)) ++ to (from n2 (X1 ** Y0)))) ++ to (from n2 (X1 ** Y1))) = - from n2 - (to (from n2 (to (from n (X0 ++ X1)) ** to (from n (Y0 ++ Y1)))) ++ Associational.negate_snd (to (from n2 (X0 ** Y0)))). - Proof. - intros. - repeat match goal with - | _ => progress - (rewrite !to_from_associational_append, - !from_to_associational_id) - | _ => progress - (rewrite <-!to_from_associational_append, - !from_to_associational_id) - | _ => rewrite app_assoc_reverse - | _ => rewrite binary_expansion - | _ => subst X0 X1 Y0 Y1 - end. - match goal with - | |- _ = from ?n (?a ++ ?b ++ ?c ++ ?d ++ Associational.negate_snd ?a) => - transitivity (from n ((a ++ Associational.negate_snd a) ++ b ++ c ++ d)); - [|remember a as A; remember b as B; remember c as C; remember d as D; remember (Associational.negate_snd A) as negA] + dlet z1 := id_with_alt_bounds (unbalanced_sub_cps weight mul_sumxy z0 id) ( - end. - Focus 2. - { rewrite app_assoc_reverse. - apply permutation_from_associational. - replace (A ++ B ++ C ++ D ++ negA) with (A ++ (B ++ C ++ D) ++ negA). - auto using app_assoc, app_assoc_reverse. - rewrite !app_assoc_reverse; reflexivity. } Unfocus. - rewrite subtraction_id. - repeat match goal with - | _ => progress - (rewrite <-!to_from_associational_append, - !from_to_associational_id) - | _ => rewrite app_assoc_reverse - end. - reflexivity. - Qed. + (mul_cps weight (fst x0_x1) (snd y0_y1) + (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1) + (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1) + (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) + (fun z2 => add_cps weight z0 z2 + (fun sum_z => add_cps weight x0_y1 x1_y0 + (fun z1' => add_cps weight z1' z2 id)))))))) in - Lemma goldilocks_mul_equiv s xs ys {R} f: - @goldilocks_mul_cps s xs ys R f = - @goldilocks_mul_cps_for_bounds_checker s xs ys R f. - Proof. - cbv [goldilocks_mul_cps_for_bounds_checker goldilocks_mul_cps]. - repeat autounfold. - autorewrite with cancel_pair push_id uncps. - apply f_equal. - repeat match goal with - |- context [Associational.mul ?x ?y] => - let m := fresh "m" in - remember (Associational.mul x y) as m end. - apply f_equal. - apply f_equal. - apply f_equal. - apply f_equal. - subst m m0 m1 m2. - apply f_equal2; try reflexivity. - apply f_equal. - symmetry. - apply goldilocks_mul_equiv'. - Qed. + Positional.to_associational_cps weight z1 + (fun z1 => Associational.mul_cps (pair s 1::nil) z1 + (fun sz1 => Positional.to_associational_cps weight sum_z + (fun sum_z => Positional.from_associational_cps weight _ (sum_z++sz1) f + ))))))))))). - Definition goldilocks_mul s xs ys := - id_with_alt_bounds - (@goldilocks_mul_cps s xs ys _ id) - (@goldilocks_mul_cps_for_bounds_checker s xs ys _ id). - Lemma goldilocks_mul_id s xs ys {R} f : - @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys). - Proof. - cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt. - repeat autounfold. - autorewrite with cancel_pair push_id uncps. - reflexivity. - Qed. Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric @@ -338,22 +157,35 @@ Context (weight : nat -> Z) Z.modulo_equiv_modulo_Proper. Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : - (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p. + (eval weight (goldilocks_mul_cps s xs ys id)) mod p = (eval weight xs * eval weight ys) mod p. Proof. - cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt. + cbv [goldilocks_mul_cps Let_In]. Zmod_to_equiv_modulo. - repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval. + progress autounfold. + progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. + rewrite !unfold_id_tuple_with_alt. repeat match goal with - | _ => rewrite <-eval_to_associational - | |- context [(to_associational ?w ?x)] => - rewrite <-(Associational.eval_split - s (to_associational w x)) by assumption - | _ => rewrite <-Associational.eval_split by assumption - | _ => setoid_rewrite Associational.eval_nil - end. - + | _ => rewrite <-eval_to_associational + | |- context [(to_associational ?w ?x)] => + rewrite <-(Associational.eval_split + s (to_associational w x)) by assumption + | _ => rewrite <-Associational.eval_split by assumption + | _ => setoid_rewrite Associational.eval_nil + end. + progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. + repeat (rewrite ?eval_from_associational, ?eval_to_associational). + progress autorewrite with push_id cancel_pair uncps push_basesystem_eval. + repeat match goal with + | _ => rewrite <-eval_to_associational + | |- context [(to_associational ?w ?x)] => + rewrite <-(Associational.eval_split + s (to_associational w x)) by assumption + | _ => rewrite <-Associational.eval_split by assumption + | _ => setoid_rewrite Associational.eval_nil + end. ring_simplify. setoid_rewrite s2_modp. apply f_equal2; nsatz. + assumption. assumption. omega. Qed. End Karatsuba. diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v index 39f76250c..ce8bb86fa 100644 --- a/src/Specific/Karatsuba.v +++ b/src/Specific/Karatsuba.v @@ -153,8 +153,69 @@ Section Ops51. Definition half_sz : nat := Eval compute in (sz / 2). Lemma half_sz_nonzero : half_sz <> 0%nat. Proof. cbv; congruence. Qed. +Ltac basesystem_partial_evaluation_RHS := + let t0 := (match goal with + | |- _ _ ?t => t + end) in + let t := + eval + cbv + delta [Positional.to_associational_cps Positional.to_associational + Positional.eval Positional.zeros Positional.add_to_nth_cps + Positional.add_to_nth Positional.place_cps Positional.place + Positional.from_associational_cps Positional.from_associational + Positional.carry_cps Positional.carry + Positional.chained_carries_cps Positional.chained_carries + Positional.sub_cps Positional.sub Positional.split_cps + Positional.scmul_cps Positional.unbalanced_sub_cps + Positional.negate_snd_cps Positional.add_cps Positional.opp_cps + Associational.eval Associational.multerm Associational.mul_cps + Associational.mul Associational.split_cps Associational.split + Associational.reduce_cps Associational.reduce + Associational.carryterm_cps Associational.carryterm + Associational.carry_cps Associational.carry + Associational.negate_snd_cps Associational.negate_snd div modulo + id_tuple_with_alt id_tuple'_with_alt + ] + in t0 + in + let t := eval pattern @runtime_mul in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @runtime_add in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @runtime_opp in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @runtime_shr in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @runtime_and in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @Let_In in t in + let t := (match t with + | ?t _ => t + end) in + let t := eval pattern @id_with_alt in t in + let t := (match t with + | ?t _ => t + end) in + let t1 := fresh "t1" in + pose (t1 := t); + transitivity + (t1 (@id_with_alt) (@Let_In) (@runtime_and) (@runtime_shr) (@runtime_opp) (@runtime_add) + (@runtime_mul)); + [ replace_with_vm_compute t1; clear t1 | reflexivity ]. + Print id_tuple_with_alt. Definition goldilocks_mul_sig : - {mul : (Z^sz -> Z^sz -> Z^sz)%type | + {mul : (Z^sz -> Z^sz -> Z^(sz+half_sz))%type | forall a b : Z^sz, mul a b = goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt (2 ^ 224) a b id}. Proof. @@ -166,39 +227,16 @@ Section Ops51. reflexivity. Defined. - Definition goldilocks_mul_for_bounds_checker_sig : - {mul : (Z^sz -> Z^sz -> Z^sz)%type | - forall a b : Z^sz, - mul a b = goldilocks_mul_cps_for_bounds_checker (n:=half_sz) (n2:=sz) wt (2 ^ 224) a b id}. - Proof. - eexists; cbv beta zeta; intros. - cbv [goldilocks_mul_cps_for_bounds_checker]. - repeat autounfold. - basesystem_partial_evaluation_RHS. - do_replace_match_with_destructuring_match_in_goal. - reflexivity. - Defined. - - Lemma goldilocks_mul_sig_equiv a b : - proj1_sig goldilocks_mul_sig a b = - proj1_sig goldilocks_mul_for_bounds_checker_sig a b. - Proof. - rewrite (proj2_sig goldilocks_mul_sig). - rewrite (proj2_sig goldilocks_mul_for_bounds_checker_sig). - apply goldilocks_mul_equiv; - auto using half_sz_nonzero, sz_nonzero, wt_nonzero. - Qed. - Definition mul_sig : - {mul : (Z^sz -> Z^sz -> Z^sz)%type | + {mul : (Z^sz -> Z^sz -> Z^(sz+half_sz))%type | forall a b : Z^sz, let eval := Positional.Fdecode (m := m) wt in - eval (mul a b) = (eval a * eval b)%F}. + Positional.Fdecode (m := m) wt (mul a b) = (eval a * eval b)%F}. Proof. eexists; cbv beta zeta; intros. pose proof wt_nonzero. let x := constr:( - goldilocks_mul (n:=half_sz) (n2:=sz) wt (2^224) a b ) in + goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt (2^224) a b id) in F_mod_eq; transitivity (Positional.eval wt x); repeat autounfold; @@ -207,29 +245,16 @@ Section Ops51. apply goldilocks_mul_correct; try assumption; cbv; congruence ]. cbv [mod_eq]; apply f_equal2; [ | reflexivity ]; apply f_equal. - cbv [goldilocks_mul]. - transitivity - (Tuple.eta_tuple - (fun a - => Tuple.eta_tuple - (fun b - => id_tuple_with_alt - ((proj1_sig goldilocks_mul_sig) a b) - ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b)) - b) - a). - { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig Tuple.eta_tuple Tuple.eta_tuple_dep sz Tuple.eta_tuple'_dep id_tuple_with_alt id_tuple'_with_alt]; - cbn [fst snd]. - reflexivity. } - { rewrite !Tuple.strip_eta_tuple, !unfold_id_tuple_with_alt. - rewrite (proj2_sig goldilocks_mul_sig). reflexivity. } + etransitivity;[|apply (proj2_sig (goldilocks_mul_sig))]. + cbv [proj1_sig goldilocks_mul_sig]. + reflexivity. Defined. Definition square_sig : - {square : (Z^sz -> Z^sz)%type | + {square : (Z^sz -> Z^(sz+half_sz))%type | forall a : Z^sz, let eval := Positional.Fdecode (m := m) wt in - eval (square a) = (eval a * eval a)%F}. + Positional.Fdecode (m := m) wt (square a) = (eval a * eval a)%F}. Proof. eexists; cbv beta zeta; intros. rewrite <-(proj2_sig mul_sig). @@ -306,6 +331,7 @@ Section Ops51. reflexivity. Defined. + (* TODO: implement reduce, reduce after mul and square Definition ring_56 := (Ring.ring_by_isomorphism (F := F m) @@ -329,7 +355,7 @@ Section Ops51. (proj2_sig add_sig) (proj2_sig sub_sig) (proj2_sig mul_sig) - ). + ). *) (* Eval cbv [proj1_sig add_sig] in (proj1_sig add_sig). |