aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Karatsuba.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Arithmetic/Karatsuba.v')
-rw-r--r--src/Arithmetic/Karatsuba.v228
1 files changed, 0 insertions, 228 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v
deleted file mode 100644
index 1873e5ef1..000000000
--- a/src/Arithmetic/Karatsuba.v
+++ /dev/null
@@ -1,228 +0,0 @@
-Require Import Coq.ZArith.ZArith.
-Require Import Coq.micromega.Lia.
-Require Import Crypto.Algebra.Nsatz.
-Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
-Require Import Crypto.Arithmetic.Core. Import B. Import Positional.
-Require Import Crypto.Util.Tuple.
-Require Import Crypto.Util.IdfunWithAlt.
-Require Import Crypto.Util.ZUtil.EquivModulo.
-Local Open Scope Z_scope.
-
-Section Karatsuba.
-Context (weight : nat -> Z)
- (weight_0 : weight 0%nat = 1%Z)
- (weight_nonzero : forall i, weight i <> 0).
- (* [tuple Z n] is the "half-length" type,
- [tuple Z n2] is the "full-length" type *)
- Context {n n2 : nat} (n_nonzero : n <> 0%nat) (n2_nonzero : n2 <> 0%nat).
- Let T := tuple Z n.
- Let T2 := tuple Z n2.
-
- (*
- If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0,
- with:
-
- z2 = x1y1
- z0 = x0y0
- z1 = (x1+x0)(y1+y0) - (z2 + z0)
-
- Computing z1 one operation at a time:
- sum_z = z0 + z2
- sum_x = x1 + x0
- sum_y = y1 + y0
- mul_sumxy = sum_x * sum_y
- z1 = mul_sumxy - sum_z
- *)
- Definition karatsuba_mul_cps s (x y : T2) {R} (f:T2->R) :=
- split_cps (n:=n2) (m1:=n) (m2:=n) weight s x
- (fun x0_x1 => split_cps weight s y
- (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
- (fun z0 => mul_cps weight(snd x0_x1) (snd y0_y1)
- (fun z2 => add_cps weight z0 z2
- (fun sum_z => add_cps weight (fst x0_x1) (snd x0_x1)
- (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
- (fun sum_y => mul_cps weight sum_x sum_y
- (fun mul_sumxy => unbalanced_sub_cps weight mul_sumxy sum_z
- (fun z1 => scmul_cps weight s z1
- (fun sz1 => scmul_cps weight (s^2) z2
- (fun s2z2 => add_cps weight s2z2 sz1
- (fun add_s2z2_sz1 => add_cps weight add_s2z2_sz1 z0 f)))))))))))).
-
- Definition karatsuba_mul s x y := @karatsuba_mul_cps s x y _ id.
- Lemma karatsuba_mul_id s x y R f :
- @karatsuba_mul_cps s x y R f = f (karatsuba_mul s x y).
- Proof.
- cbv [karatsuba_mul karatsuba_mul_cps].
- repeat autounfold.
- autorewrite with cancel_pair push_id uncps.
- reflexivity.
- Qed.
- Hint Opaque karatsuba_mul : uncps.
- Hint Rewrite karatsuba_mul_id : uncps.
-
- Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
- eval weight (karatsuba_mul s x y) = eval weight x * eval weight y.
- Proof.
- cbv [karatsuba_mul karatsuba_mul_cps]; repeat autounfold.
- autorewrite with cancel_pair push_id uncps push_basesystem_eval.
- repeat match goal with
- | _ => rewrite <-eval_to_associational
- | |- context [(to_associational ?w ?x)] =>
- rewrite <-(Associational.eval_split
- s (to_associational w x)) by assumption
- | _ => rewrite <-Associational.eval_split by assumption
- | _ => setoid_rewrite Associational.eval_nil
- end.
- ring_simplify.
- nsatz.
- Qed.
-
- (* These definitions are intended to make bounds analysis go through
- for karatsuba. Essentially, we provide a version of the code to
- actually run and a version to bounds-check, along with a proof
- that they are exactly equal. This works around cases where the
- bounds proof requires high-level reasoning. *)
- Local Notation id_with_alt_bounds_cps := id_tuple_with_alt_cps'.
-
- (*
- If:
- s^2 mod p = (s + 1) mod p
- x = x0 + sx1
- y = y0 + sy1
- Then, with z0 and z2 as before (x0y0 and x1y1 respectively), let z1 = ((x0 + x1) * (y0 + y1)) - z0.
-
- Computing xy one operation at a time:
- sum_z = z0 + z2
- sum_x = x0 + x1
- sum_y = y0 + y1
- mul_sumxy = sum_x * sum_y
- z1 = mul_sumxy - z0
- sz1 = s * z1
- xy = sum_z - sz1
-
- The subtraction in the computation of z1 presents issues for
- bounds analysis. In particular, just analyzing the upper and lower
- bounds of the values would indicate that it could underflow--we
- know it won't because
-
- mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0
- = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0
- = x1y0 + x0y1 + x1y1
-
- Therefore, we use id_with_alt_bounds to indicate that the
- bounds-checker should check the non-subtracting form.
-
- *)
-
- (*
- Definition goldilocks_mul_cps_for_bounds_checker
- s (xs ys : T2) {R} (f:T2->R) :=
- split_cps (m1:=n) (m2:=n) weight s xs
- (fun x0_x1 => split_cps weight s ys
-
- (fun z1 => Positional.to_associational_cps weight z1
- (fun z1 => Associational.mul_cps (pair s 1::nil) z1
- (fun sz1 => Positional.from_associational_cps weight n2 sz1
- (fun sz1 => add_cps weight sum_z sz1 f)))))))))))).
- *)
-
- Let T3 := tuple Z (n2+n).
- Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T3->R) :=
- split_cps (m1:=n) (m2:=n) weight s xs
- (fun x0_x1 => split_cps weight s ys
- (fun y0_y1 => mul_cps weight (fst x0_x1) (fst y0_y1)
- (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
- (fun z2 => add_cps weight z0 z2
- (fun sum_z : tuple _ n2 => add_cps weight (fst x0_x1) (snd x0_x1)
- (fun sum_x => add_cps weight (fst y0_y1) (snd y0_y1)
- (fun sum_y => mul_cps weight sum_x sum_y
- (fun mul_sumxy =>
-
- id_with_alt_bounds_cps (fun f =>
- (unbalanced_sub_cps weight mul_sumxy z0 f)) (fun f =>
-
- (mul_cps weight (fst x0_x1) (snd y0_y1)
- (fun x0_y1 => mul_cps weight (snd x0_x1) (fst y0_y1)
- (fun x1_y0 => mul_cps weight (fst x0_x1) (fst y0_y1)
- (fun z0 => mul_cps weight (snd x0_x1) (snd y0_y1)
- (fun z2 => add_cps weight z0 z2
- (fun sum_z => add_cps weight x0_y1 x1_y0
- (fun z1' => add_cps weight z1' z2 f)))))))) (fun z1 =>
-
- Positional.to_associational_cps weight z1
- (fun z1 => Associational.mul_cps (pair s 1::nil) z1
- (fun sz1 => Positional.to_associational_cps weight sum_z
- (fun sum_z => Positional.from_associational_cps weight _ (sum_z++sz1) f
- )))))))))))).
-
- Definition goldilocks_mul s xs ys := goldilocks_mul_cps s xs ys id.
- Lemma goldilocks_mul_id s xs ys R f :
- @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
- Proof.
- cbv [goldilocks_mul goldilocks_mul_cps Let_In].
- repeat autounfold. autorewrite with uncps push_id.
- reflexivity.
- Qed.
- Hint Opaque goldilocks_mul : uncps.
- Hint Rewrite goldilocks_mul_id : uncps.
-
- Local Existing Instances Z.equiv_modulo_Reflexive
- RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
- Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper
- Z.modulo_equiv_modulo_Proper.
-
- Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
- (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p.
- Proof.
- cbv [goldilocks_mul_cps goldilocks_mul Let_In].
- Zmod_to_equiv_modulo.
- progress autounfold.
- progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
- rewrite !unfold_id_tuple_with_alt.
- repeat match goal with
- | _ => rewrite <-eval_to_associational
- | |- context [(to_associational ?w ?x)] =>
- rewrite <-(Associational.eval_split
- s (to_associational w x)) by assumption
- | _ => rewrite <-Associational.eval_split by assumption
- | _ => setoid_rewrite Associational.eval_nil
- end.
- progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
- repeat (rewrite ?eval_from_associational, ?eval_to_associational).
- progress autorewrite with push_id cancel_pair uncps push_basesystem_eval.
- repeat match goal with
- | _ => rewrite <-eval_to_associational
- | |- context [(to_associational ?w ?x)] =>
- rewrite <-(Associational.eval_split
- s (to_associational w x)) by assumption
- | _ => rewrite <-Associational.eval_split by assumption
- | _ => setoid_rewrite Associational.eval_nil
- end.
- ring_simplify.
- setoid_rewrite s2_modp.
- apply f_equal2; nsatz.
- assumption. assumption. omega.
- Qed.
-
- Lemma eval_goldilocks_mul (p : positive) s (s_nonzero : s <> 0) (s2_modp : mod_eq p (s^2) (s+1)) xs ys :
- mod_eq p (eval weight (goldilocks_mul s xs ys)) (eval weight xs * eval weight ys).
- Proof.
- apply goldilocks_mul_correct; auto; lia.
- Qed.
-End Karatsuba.
-Hint Opaque karatsuba_mul goldilocks_mul : uncps.
-Hint Rewrite karatsuba_mul_id goldilocks_mul_id : uncps.
-
-Hint Rewrite
- @eval_karatsuba_mul
- @eval_goldilocks_mul
- @goldilocks_mul_correct
- using (assumption || (div_mod_cps_t; auto)) : push_basesystem_eval.
-
-Ltac basesystem_partial_evaluation_unfolder t :=
- let t := (eval cbv delta [goldilocks_mul karatsuba_mul goldilocks_mul_cps karatsuba_mul_cps] in t) in
- let t := Arithmetic.Core.basesystem_partial_evaluation_unfolder t in
- t.
-
-Ltac Arithmetic.Core.basesystem_partial_evaluation_default_unfolder t ::=
- basesystem_partial_evaluation_unfolder t.