From b9770718be4f65de7b0cdfcd1c08000e5eac8ca4 Mon Sep 17 00:00:00 2001 From: jadep Date: Fri, 2 Jun 2017 16:47:13 -0400 Subject: Make Karatsuba depend on Arithmetic/Core to make calling it less of a pain --- src/Arithmetic/Karatsuba.v | 139 +++++++++++++++++++++------------------------ src/Specific/Karatsuba.v | 118 ++------------------------------------ 2 files changed, 71 insertions(+), 186 deletions(-) (limited to 'src') diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v index 7b2004a2d..f17623da7 100644 --- a/src/Arithmetic/Karatsuba.v +++ b/src/Arithmetic/Karatsuba.v @@ -1,48 +1,19 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Algebra.Nsatz. -Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil. +Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil Crypto.Util.Tactics. +Require Import Crypto.Arithmetic.Core. Import B. Import Positional. +Require Import Crypto.Util.Tuple. Local Open Scope Z_scope. Section Karatsuba. - (* T is the "half-length" type, T2 is the "full-length" type *) - Context {T T2 : Type} (eval : T -> Z) (eval2 : T2 -> Z). - - (* multiplication takes half-length inputs to full-length output *) - Context {mul_cps : T -> T -> forall {R}, (T2->R)->R} - {mul : T -> T -> T2} - {mul_id : forall x y {R} f, @mul_cps x y R f = f (mul x y)} - {eval_mul : forall x y, eval2 (mul x y) = eval x * eval y}. - - (* splitting takes full-length input to half-length outputs *) - Context {split_cps : Z -> T2 -> forall {R}, ((T * T)->R)->R} - {split : Z -> T2 -> T * T} - {split_id : forall s x R f, @split_cps s x R f = f (split s x)} - {eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval2 x}. - - (* half-length add *) - Context {add_cps : T -> T -> forall {R}, (T->R)->R} - {add : T -> T -> T} - {add_id : forall x y {R} f, @add_cps x y R f = f (add x y)} - {eval_add : forall x y, eval (add x y) = eval x + eval y}. - - (* full-length operations: sub, add, scmul *) - Context {sub2_cps : T2 -> T2 -> forall {R}, (T2->R)->R} - {sub2 : T2 -> T2 -> T2} - {sub2_id : forall x y {R} f, @sub2_cps x y R f = f (sub2 x y)} - {eval_sub2 : forall x y, eval2 (sub2 x y) = eval2 x - eval2 y} - {add2_cps : T2 -> T2 -> forall {R}, (T2->R)->R} - {add2 : T2 -> T2 -> T2} - {add2_id : forall x y {R} f, @add2_cps x y R f = f (add2 x y)} - {eval_add2 : forall x y, eval2 (add2 x y) = eval2 x + eval2 y} - {scmul2_cps : Z -> T2 -> forall {R}, (T2->R)->R} - {scmul2 : Z -> T2 -> T2} - {scmul2_id : forall z x {R} f, @scmul2_cps z x R f = f (scmul2 z x)} - {eval_scmul2 : forall c x, eval2 (scmul2 c x) = c * eval2 x}. - - Local Ltac rewrite_id := - repeat progress rewrite ?mul_id, ?split_id, ?add_id, ?sub2_id, ?add2_id, ?scmul2_id. - Local Ltac rewrite_eval := - repeat progress rewrite ?eval_mul, ?eval_split, ?eval_add, ?eval_sub2, ?eval_add2, ?eval_scmul2. +Context (weight : nat -> Z) + (weight_0 : weight 0%nat = 1%Z) + (weight_nonzero : forall i, weight i <> 0). + (* [tuple Z n] is the "half-length" type, + [tuple Z n2] is the "full-length" type *) + Context {n n2 : nat} (n_nonzero : n <> 0%nat) (n2_nonzero : n2 <> 0%nat). + Let T := tuple Z n. + Let T2 := tuple Z n2. (* If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0, @@ -60,34 +31,45 @@ Section Karatsuba. z1 = mul_sumxy - sum_z *) Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) := - split_cps s x _ - (fun x0_x1 => split_cps s y _ - (fun y0_y1 => mul_cps (fst x0_x1) (fst y0_y1) _ - (fun z0 => mul_cps (snd x0_x1) (snd y0_y1) _ - (fun z2 => add2_cps z0 z2 _ - (fun sum_z => add_cps (fst x0_x1) (snd x0_x1) _ - (fun sum_x => add_cps (fst y0_y1) (snd y0_y1) _ - (fun sum_y => mul_cps sum_x sum_y _ - (fun mul_sumxy => sub2_cps mul_sumxy sum_z _ - (fun z1 => scmul2_cps s z1 _ - (fun sz1 => scmul2_cps (s^2) z2 _ - (fun s2z2 => add2_cps s2z2 sz1 _ - (fun add_s2z2_sz1 => add2_cps add_s2z2_sz1 z0 _ f)))))))))))). + split_cps (n:=n2) (m1:=n) (m2:=n) weight s x + (fun x0_x1 => split_cps weight s y + (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1) + (fun z0 => mul_cps weight(snd x0_x1) (snd y0_y1) + (fun z2 => add_cps weight z0 z2 + (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1) + (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1) + (fun sum_y => mul_cps weight sum_x sum_y + (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy sum_z + (fun z1 => scmul_cps weight s z1 + (fun sz1 => scmul_cps weight (s^2) z2 + (fun s2z2 => add_cps weight s2z2 sz1 + (fun add_s2z2_sz1 => add_cps weight add_s2z2_sz1 z0 f)))))))))))). Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id. Lemma karatsuba_mul_id s x y R f : @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y). Proof. - cbv [karatsuba_mul karatsuba_mul_cps]. rewrite_id. + cbv [karatsuba_mul karatsuba_mul_cps]. + repeat autounfold. + autorewrite with cancel_pair push_id uncps. reflexivity. Qed. Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) : - eval2 (karatsuba_mul s x y) = eval2 x * eval2 y. + eval weight (karatsuba_mul s x y) = eval weight x * eval weight y. Proof. - cbv [karatsuba_mul karatsuba_mul_cps]. rewrite_id. - repeat rewrite push_id. rewrite_eval. - rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. + cbv [karatsuba_mul karatsuba_mul_cps]; repeat autounfold. + autorewrite with cancel_pair push_id uncps push_basesystem_eval. + repeat match goal with + | _ => rewrite <-eval_to_associational + | |- context [(to_associational ?w ?x)] => + rewrite <-(Associational.eval_split + s (to_associational w x)) by assumption + | _ => rewrite <-Associational.eval_split by assumption + | _ => setoid_rewrite Associational.eval_nil + end. + ring_simplify. + nsatz. Qed. (* @@ -109,23 +91,25 @@ Section Karatsuba. *) Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) := - split_cps s xs _ - (fun x0_x1 => split_cps s ys _ - (fun y0_y1 => mul_cps (fst x0_x1) (fst y0_y1) _ - (fun z0 => mul_cps (snd x0_x1) (snd y0_y1) _ - (fun z2 => add2_cps z0 z2 _ - (fun sum_z => add_cps (fst x0_x1) (snd x0_x1) _ - (fun sum_x => add_cps (fst y0_y1) (snd y0_y1) _ - (fun sum_y => mul_cps sum_x sum_y _ - (fun mul_sumxy => sub2_cps mul_sumxy z0 _ - (fun z1 => scmul2_cps s z1 _ - (fun sz1 => add2_cps sum_z sz1 _ f)))))))))). + split_cps (m1:=n) (m2:=n) weight s xs + (fun x0_x1 => split_cps weight s ys + (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1) + (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1) + (fun z2 => add_cps weight z0 z2 + (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1) + (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1) + (fun sum_y => mul_cps weight sum_x sum_y + (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy z0 + (fun z1 => scmul_cps weight s z1 + (fun sz1 => add_cps weight sum_z sz1 f)))))))))). Definition goldilocks_mul s xs ys := @goldilocks_mul_cps s xs ys _ id. Lemma goldilocks_mul_id s xs ys {R} f : @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys). Proof. - cbv [goldilocks_mul goldilocks_mul_cps]. rewrite_id. + cbv [goldilocks_mul goldilocks_mul_cps]. + repeat autounfold. + autorewrite with cancel_pair push_id uncps. reflexivity. Qed. @@ -135,11 +119,20 @@ Section Karatsuba. Z.modulo_equiv_modulo_Proper. Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : - (eval2 (goldilocks_mul s xs ys)) mod p = (eval2 xs * eval2 ys) mod p. + (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p. Proof. cbv [goldilocks_mul_cps goldilocks_mul]; Zmod_to_equiv_modulo. - rewrite_id. rewrite push_id. rewrite_eval. - repeat progress rewrite <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify. + repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval. + repeat match goal with + | _ => rewrite <-eval_to_associational + | |- context [(to_associational ?w ?x)] => + rewrite <-(Associational.eval_split + s (to_associational w x)) by assumption + | _ => rewrite <-Associational.eval_split by assumption + | _ => setoid_rewrite Associational.eval_nil + end. + + ring_simplify. setoid_rewrite s2_modp. apply f_equal2; nsatz. Qed. diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v index 0f205f253..70834e9d7 100644 --- a/src/Specific/Karatsuba.v +++ b/src/Specific/Karatsuba.v @@ -149,119 +149,8 @@ Section Ops51. solve_op_F wt x. reflexivity. Defined. - Check goldilocks_mul_cps. Definition half_sz : nat := Eval compute in (sz / 2). - (* TODO: move *) - Definition Positional_split_cps {n m1 m2} (s:Z) (p : tuple Z n) - {T} (f:(tuple Z m1 * tuple Z m2) -> T) := - Positional.to_associational_cps wt p - (fun P => Associational.split_cps s P - (fun split_P => - f (Positional.from_associational wt m1 (fst split_P), - (Positional.from_associational wt m2 (snd split_P))))). - Definition Positional_scmul_cps {n} (x : Z) (p: tuple Z n) - {T} (f:tuple Z n->T) := - Positional.to_associational_cps wt p - (fun P => Associational.mul_cps P [(1, x)] - (fun R => Positional.from_associational_cps wt n R f)). - Definition Positional_sub_cps {n} (p q: tuple Z n) - {T} (f:tuple Z n->T) := - Positional.to_associational_cps wt p - (fun P => Positional.to_associational_cps wt q - (fun Q => Associational.negate_snd_cps Q - (fun negQ => Positional.from_associational_cps wt n (P ++ negQ) f))). - Definition goldilocks448_cps := - (goldilocks_mul_cps - (T := tuple Z half_sz) (T2 := tuple Z sz) - (mul_cps := Positional.mul_cps (n:=half_sz) wt) - (add_cps := Positional.add_cps (n:=half_sz) wt) - (add2_cps := Positional.add_cps (n:=sz) wt) - (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz)) - (scmul2_cps := Positional_scmul_cps (n:=sz)) - (sub2_cps := Positional_sub_cps (n:=sz)) - ). - Hint Unfold goldilocks448_cps. - Check goldilocks_mul_id. - Definition goldilocks448_id - mul_id add_id add2_id split_id scmul2_id sub2_id - := - (goldilocks_mul_id - (T := tuple Z half_sz) (T2 := tuple Z sz) - (mul_cps := Positional.mul_cps (n:=half_sz) wt) - (add_cps := Positional.add_cps (n:=half_sz) wt) - (add2_cps := Positional.add_cps (n:=sz) wt) - (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz)) - (scmul2_cps := Positional_scmul_cps (n:=sz)) - (sub2_cps := Positional_sub_cps (n:=sz)) - (mul := fun a b => Positional.mul_cps (n:= half_sz) wt a b id) - (add := fun a b => Positional.add_cps (n:=half_sz) wt a b id) - (add2 := fun a b => Positional.add_cps (n:=sz) wt a b id) - (split := fun s a => Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz) s a id) - (scmul2 := fun x a => Positional_scmul_cps (n:=sz) x a id) - (sub2 := fun a b => Positional_sub_cps (n:=sz) a b id) - (mul_id := mul_id) - (add_id := add_id) - (add2_id := add2_id) - (split_id := split_id) - (scmul2_id := scmul2_id) - (sub2_id := sub2_id) - ). - Definition goldilocks448_correct' - mul_id add_id add2_id split_id scmul2_id sub2_id - eval_mul eval_add eval_add2 eval_split eval_scmul2 eval_sub2 - := - (goldilocks_mul_correct - (T := tuple Z half_sz) (T2 := tuple Z sz) - (Positional.eval (n:=half_sz) wt) - (Positional.eval (n:=sz) wt) - (mul_cps := Positional.mul_cps (n:=half_sz) wt) - (add_cps := Positional.add_cps (n:=half_sz) wt) - (add2_cps := Positional.add_cps (n:=sz) wt) - (split_cps := Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz)) - (scmul2_cps := Positional_scmul_cps (n:=sz)) - (sub2_cps := Positional_sub_cps (n:=sz)) - (mul := fun a b => Positional.mul_cps (n:= half_sz) wt a b id) - (add := fun a b => Positional.add_cps (n:=half_sz) wt a b id) - (add2 := fun a b => Positional.add_cps (n:=sz) wt a b id) - (split := fun s a => Positional_split_cps (n:=sz) (m1:=half_sz) (m2 := half_sz) s a id) - (scmul2 := fun x a => Positional_scmul_cps (n:=sz) x a id) - (sub2 := fun a b => Positional_sub_cps (n:=sz) a b id) - (mul_id := mul_id) - (add_id := add_id) - (add2_id := add2_id) - (split_id := split_id) - (scmul2_id := scmul2_id) - (sub2_id := sub2_id) - (eval_mul := eval_mul) - (eval_add := eval_add) - (eval_add2 := eval_add2) - (eval_split := eval_split) - (eval_scmul2 := eval_scmul2) - (eval_sub2 := eval_sub2) - ). - Check goldilocks448_correct'. - Hint Unfold Positional_split_cps Positional_scmul_cps Positional_sub_cps. - Lemma goldilocks448_correct : - forall p : positive, - forall s : Z, - s <> 0 -> - s ^ 2 mod p = (s + 1) mod p -> - forall xs ys : Z ^ sz, - mod_eq (Z.to_pos p) - (Positional.eval wt (goldilocks448_cps s xs ys _ id)) - (Positional.eval wt xs * Positional.eval wt ys). - Proof. - pose proof wt_nonzero. - intros; autounfold. cbv [mod_eq]. - rewrite goldilocks448_id by (intros; autounfold; autorewrite with uncps push_id; reflexivity). autorewrite with push_id. - apply goldilocks448_correct'; try assumption; intros; autounfold; - autorewrite with uncps push_id cancel_pair push_basesystem_eval; - try reflexivity. - { setoid_rewrite Associational.eval_nil. ring. } - { rewrite Pos2Z.id; congruence. } - Qed. - Definition mul_sig : {mul : (Z^sz -> Z^sz -> Z^sz)%type | forall a b : Z^sz, @@ -270,16 +159,19 @@ Section Ops51. Proof. eexists; cbv beta zeta; intros. pose proof wt_nonzero. + Print goldilocks_mul_cps. let x := constr:( - goldilocks448_cps (2^224) a b _ id) in + goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt (2^224) a b id) in F_mod_eq; transitivity (Positional.eval wt x); repeat autounfold; [ | autorewrite with uncps push_id push_basesystem_eval; - apply goldilocks448_correct; cbv; congruence ]. + apply goldilocks_mul_correct; try assumption; cbv; congruence ]. cbv[mod_eq]; apply f_equal2; [ | reflexivity ]; apply f_equal. + cbv [goldilocks_mul_cps]. + repeat autounfold. basesystem_partial_evaluation_RHS. do_replace_match_with_destructuring_match_in_goal. reflexivity. -- cgit v1.2.3