aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic/Karatsuba.v169
1 files changed, 133 insertions, 36 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v
index 0f20bb238..7b2004a2d 100644
--- a/src/Arithmetic/Karatsuba.v
+++ b/src/Arithmetic/Karatsuba.v
@@ -1,49 +1,146 @@
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Algebra.Nsatz.
-Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil.
Local Open Scope Z_scope.
Section Karatsuba.
- Context {T : Type} (eval : T -> Z)
- (sub : T -> T -> T)
- (eval_sub : forall x y, eval (sub x y) = eval x - eval y)
- (mul : T -> T -> T)
- (eval_mul : forall x y, eval (mul x y) = eval x * eval y)
- (add : T -> T -> T)
- (eval_add : forall x y, eval (add x y) = eval x + eval y)
- (scmul : Z -> T -> T)
- (eval_scmul : forall c x, eval (scmul c x) = c * eval x)
- (split : Z -> T -> T * T)
- (eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval x)
- .
-
- Definition karatsuba_mul s (x y : T) : T :=
- let xab := split s x in
- let yab := split s y in
- let xy0 := mul (fst xab) (fst yab) in
- let xy2 := mul (snd xab) (snd yab) in
- let xy1 := sub (mul (add (fst xab) (snd xab)) (add (fst yab) (snd yab))) (add xy2 xy0) in
- add (add (scmul (s^2) xy2) (scmul s xy1)) xy0.
+ (* T is the "half-length" type, T2 is the "full-length" type *)
+ Context {T T2 : Type} (eval : T -> Z) (eval2 : T2 -> Z).
- Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
- eval (karatsuba_mul s x y) = eval x * eval y.
- Proof using Type*. cbv [karatsuba_mul]; repeat rewrite ?eval_sub, ?eval_mul, ?eval_add, ?eval_scmul.
- rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. Qed.
+ (* multiplication takes half-length inputs to full-length output *)
+ Context {mul_cps : T -> T -> forall {R}, (T2->R)->R}
+ {mul : T -> T -> T2}
+ {mul_id : forall x y {R} f, @mul_cps x y R f = f (mul x y)}
+ {eval_mul : forall x y, eval2 (mul x y) = eval x * eval y}.
+
+ (* splitting takes full-length input to half-length outputs *)
+ Context {split_cps : Z -> T2 -> forall {R}, ((T * T)->R)->R}
+ {split : Z -> T2 -> T * T}
+ {split_id : forall s x R f, @split_cps s x R f = f (split s x)}
+ {eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval2 x}.
+
+ (* half-length add *)
+ Context {add_cps : T -> T -> forall {R}, (T->R)->R}
+ {add : T -> T -> T}
+ {add_id : forall x y {R} f, @add_cps x y R f = f (add x y)}
+ {eval_add : forall x y, eval (add x y) = eval x + eval y}.
+
+ (* full-length operations: sub, add, scmul *)
+ Context {sub2_cps : T2 -> T2 -> forall {R}, (T2->R)->R}
+ {sub2 : T2 -> T2 -> T2}
+ {sub2_id : forall x y {R} f, @sub2_cps x y R f = f (sub2 x y)}
+ {eval_sub2 : forall x y, eval2 (sub2 x y) = eval2 x - eval2 y}
+ {add2_cps : T2 -> T2 -> forall {R}, (T2->R)->R}
+ {add2 : T2 -> T2 -> T2}
+ {add2_id : forall x y {R} f, @add2_cps x y R f = f (add2 x y)}
+ {eval_add2 : forall x y, eval2 (add2 x y) = eval2 x + eval2 y}
+ {scmul2_cps : Z -> T2 -> forall {R}, (T2->R)->R}
+ {scmul2 : Z -> T2 -> T2}
+ {scmul2_id : forall z x {R} f, @scmul2_cps z x R f = f (scmul2 z x)}
+ {eval_scmul2 : forall c x, eval2 (scmul2 c x) = c * eval2 x}.
+
+ Local Ltac rewrite_id :=
+ repeat progress rewrite ?mul_id, ?split_id, ?add_id, ?sub2_id, ?add2_id, ?scmul2_id.
+ Local Ltac rewrite_eval :=
+ repeat progress rewrite ?eval_mul, ?eval_split, ?eval_add, ?eval_sub2, ?eval_add2, ?eval_scmul2.
+ (*
+ If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + z0,
+ with:
+
+ z2 = x1y1
+ z0 = x0y0
+ z1 = (x1+x0)(y1+y0) - (z2 + z0)
+
+ Computing z1 one operation at a time:
+ sum_z = z0 + z2
+ sum_x = x1 + x0
+ sum_y = y1 + y0
+ mul_sumxy = sum_x * sum_y
+ z1 = mul_sumxy - sum_z
+ *)
+ Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) :=
+ split_cps s x _
+ (fun x0_x1 => split_cps s y _
+ (fun y0_y1 => mul_cps (fst x0_x1) (fst y0_y1) _
+ (fun z0 => mul_cps (snd x0_x1) (snd y0_y1) _
+ (fun z2 => add2_cps z0 z2 _
+ (fun sum_z => add_cps (fst x0_x1) (snd x0_x1) _
+ (fun sum_x => add_cps (fst y0_y1) (snd y0_y1) _
+ (fun sum_y => mul_cps sum_x sum_y _
+ (fun mul_sumxy => sub2_cps mul_sumxy sum_z _
+ (fun z1 => scmul2_cps s z1 _
+ (fun sz1 => scmul2_cps (s^2) z2 _
+ (fun s2z2 => add2_cps s2z2 sz1 _
+ (fun add_s2z2_sz1 => add2_cps add_s2z2_sz1 z0 _ f)))))))))))).
+
+ Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id.
+ Lemma karatsuba_mul_id s x y R f :
+ @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y).
+ Proof.
+ cbv [karatsuba_mul karatsuba_mul_cps]. rewrite_id.
+ reflexivity.
+ Qed.
+
+ Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
+ eval2 (karatsuba_mul s x y) = eval2 x * eval2 y.
+ Proof.
+ cbv [karatsuba_mul karatsuba_mul_cps]. rewrite_id.
+ repeat rewrite push_id. rewrite_eval.
+ rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring.
+ Qed.
- Definition goldilocks_mul s (xs ys : T) : T :=
- let a_b := split s xs in
- let c_d := split s ys in
- let ac := mul (fst a_b) (fst c_d) in
- (add (add ac (mul (snd a_b) (snd c_d)))
- (scmul s (sub (mul (add (fst a_b) (snd a_b)) (add (fst c_d) (snd c_d))) ac))).
+ (*
+ If:
+ s^2 mod p = (s + 1) mod p
+ x = x0 + sx1
+ y = y0 + sy1
+ Then, with z0 and z2 as before and z1 = ((x0 + x1) * (y0 + y1)) - z0,
+ xy mod p = (z0 + z2 + sz1) mod p
+
+ Computing xy one operation at a time:
+ sum_z = z0 + z2
+ sum_x = x0 + x1
+ sum_y = y0 + y1
+ mul_sumxy = sum_x * sum_y
+ z1 = mul_sumxy - z0
+ sz1 = s * z1
+ xy = sum_z + sz1
+
+ *)
+ Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) :=
+ split_cps s xs _
+ (fun x0_x1 => split_cps s ys _
+ (fun y0_y1 => mul_cps (fst x0_x1) (fst y0_y1) _
+ (fun z0 => mul_cps (snd x0_x1) (snd y0_y1) _
+ (fun z2 => add2_cps z0 z2 _
+ (fun sum_z => add_cps (fst x0_x1) (snd x0_x1) _
+ (fun sum_x => add_cps (fst y0_y1) (snd y0_y1) _
+ (fun sum_y => mul_cps sum_x sum_y _
+ (fun mul_sumxy => sub2_cps mul_sumxy z0 _
+ (fun z1 => scmul2_cps s z1 _
+ (fun sz1 => add2_cps sum_z sz1 _ f)))))))))).
- Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper Z.modulo_equiv_modulo_Proper.
+ Definition goldilocks_mul s xs ys := @goldilocks_mul_cps s xs ys _ id.
+ Lemma goldilocks_mul_id s xs ys {R} f :
+ @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
+ Proof.
+ cbv [goldilocks_mul goldilocks_mul_cps]. rewrite_id.
+ reflexivity.
+ Qed.
+
+ Local Existing Instances Z.equiv_modulo_Reflexive
+ RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
+ Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper
+ Z.modulo_equiv_modulo_Proper.
Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
- (eval (goldilocks_mul s xs ys)) mod p = (eval xs * eval ys) mod p.
- Proof using Type*. cbv [goldilocks_mul]; Zmod_to_equiv_modulo.
- repeat rewrite ?eval_mul, ?eval_add, ?eval_sub, ?eval_scmul, <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify.
+ (eval2 (goldilocks_mul s xs ys)) mod p = (eval2 xs * eval2 ys) mod p.
+ Proof.
+ cbv [goldilocks_mul_cps goldilocks_mul]; Zmod_to_equiv_modulo.
+ rewrite_id. rewrite push_id. rewrite_eval.
+ repeat progress rewrite <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify.
setoid_rewrite s2_modp.
- apply f_equal2; nsatz. Qed.
+ apply f_equal2; nsatz.
+ Qed.
End Karatsuba.