aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-06-11 19:15:15 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-06-11 19:15:41 -0400
commit97c72ad6da000682171c819ba712c6c68a09686f (patch)
treefcf95c073028869bfbb88acc03377e5a5fd70b76 /src
parente8b12aeec4abea243b7f0be1100a6f33a6ca15ad (diff)
Factor karatsuba through IdfunWithAlt, add test
Currently the refinement is commented out. Also, we drop the proof of equality early (and probably won't use it in the first place); there's no way we can carry around such a proof in reflective-land, so we'll need to add an arithmetical-equality semi-decider to reflective-land that can prove the relevant equalities (or we'll need to leave them over as side-conditions). I expeect this may make things significantly easier on @jadephilipoom.
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic/Karatsuba.v51
-rw-r--r--src/Specific/IntegrationTestKaratsubaMul.v70
-rw-r--r--src/Specific/Karatsuba.v20
3 files changed, 111 insertions, 30 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v
index 3c1471365..d518d1dbd 100644
--- a/src/Arithmetic/Karatsuba.v
+++ b/src/Arithmetic/Karatsuba.v
@@ -3,6 +3,7 @@ Require Import Crypto.Algebra.Nsatz.
Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil Crypto.Util.Tactics.
Require Import Crypto.Arithmetic.Core. Import B. Import Positional.
Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.IdfunWithAlt.
Local Open Scope Z_scope.
Section Karatsuba.
@@ -15,10 +16,10 @@ Context (weight : nat -> Z)
Let T := tuple Z n.
Let T2 := tuple Z n2.
- (*
- If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0,
+ (*
+ If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0,
with:
-
+
z2 = x1y1
z0 = x0y0
z1 = (x1+x0)(y1+y0) - (z2 + z0)
@@ -77,19 +78,16 @@ Context (weight : nat -> Z)
actually run and a version to bounds-check, along with a proof
that they are exactly equal. This works around cases where the
bounds proof requires high-level reasoning. *)
- Definition id_with_alt_bounds {A} (value : A) (value_for_alt_bounds : A) : A
- := value.
- Definition id_with_alt_bounds_and_proof {A} (value : A) (value_for_alt_bounds : A)
- {pf : value = value_for_alt_bounds}
- := id_with_alt_bounds value value_for_alt_bounds.
-
+ Local Notation id_with_alt_bounds := id_tuple_with_alt.
+ Local Notation id_with_alt_bounds_and_proof := id_tuple_with_alt_proof.
+
(*
If:
s^2 mod p = (s + 1) mod p
x = x0 + sx1
y = y0 + sy1
Then, with z0 and z2 as before (x0y0 and x1y1 respectively), let z1 = ((x0 + x1) * (y0 + y1)) - z0.
-
+
Computing xy one operation at a time:
sum_z = z0 + z2
sum_x = x0 + x1
@@ -104,13 +102,13 @@ Context (weight : nat -> Z)
bounds of the values would indicate that it could underflow--we
know it won't because
- mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0
- = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0
+ mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0
+ = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0
= x1y0 + x0y1 + x1y1
Therefore, we use id_with_alt_bounds to indicate that the
bounds-checker should check the non-subtracting form.
-
+
*)
Definition goldilocks_mul_cps_for_bounds_checker
@@ -126,7 +124,7 @@ Context (weight : nat -> Z)
(fun z1' => add_cps weight z1' z2
(fun z1 => scmul_cps weight s z1
(fun sz1 => add_cps weight sum_z sz1 f)))))))))).
-
+
Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) :=
split_cps (m1:=n) (m2:=n) weight s xs
(fun x0_x1 => split_cps weight s ys
@@ -190,11 +188,11 @@ Context (weight : nat -> Z)
Admitted.
Local Infix "**" := Associational.mul (at level 40).
-
+
Local Definition multerm terms :=
Associational.multerm (fst terms) (snd terms).
-
- Lemma mul_power_equiv (p q : list limb) :
+
+ Lemma mul_power_equiv (p q : list limb) :
Permutation.permutation
(p ** q)
(List.map multerm (list_prod p q)).
@@ -254,7 +252,7 @@ Context (weight : nat -> Z)
Lemma subtraction_id N p q :
from N ((p ++ Associational.negate_snd p) ++ q) = from N q.
Admitted.
-
+
Lemma goldilocks_mul_equiv' x0 x1 y0 y1 :
let X0 := to (from n x0) in
let X1 := to (from n x1) in
@@ -281,10 +279,10 @@ Context (weight : nat -> Z)
| |- _ = from ?n (?a ++ ?b ++ ?c ++ ?d ++ Associational.negate_snd ?a) =>
transitivity (from n ((a ++ Associational.negate_snd a) ++ b ++ c ++ d));
[|remember a as A; remember b as B; remember c as C; remember d as D; remember (Associational.negate_snd A) as negA]
-
+
end.
Focus 2.
- { rewrite app_assoc_reverse.
+ { rewrite app_assoc_reverse.
apply permutation_from_associational.
replace (A ++ B ++ C ++ D ++ negA) with (A ++ (B ++ C ++ D) ++ negA).
auto using app_assoc, app_assoc_reverse.
@@ -298,7 +296,7 @@ Context (weight : nat -> Z)
end.
reflexivity.
Qed.
-
+
Lemma goldilocks_mul_equiv s xs ys {R} f:
@goldilocks_mul_cps s xs ys R f =
@goldilocks_mul_cps_for_bounds_checker s xs ys R f.
@@ -323,19 +321,19 @@ Context (weight : nat -> Z)
Qed.
Definition goldilocks_mul s xs ys :=
- id_with_alt_bounds_and_proof (pf := goldilocks_mul_equiv _ _ _ _)
+ id_with_alt_bounds_and_proof
+ (pf := goldilocks_mul_equiv _ _ _ _)
(@goldilocks_mul_cps s xs ys _ id)
(@goldilocks_mul_cps_for_bounds_checker s xs ys _ id).
Lemma goldilocks_mul_id s xs ys {R} f :
@goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys).
Proof.
- cbv [id_with_alt_bounds_and_proof
- id_with_alt_bounds goldilocks_mul goldilocks_mul_cps].
+ cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt_proof.
repeat autounfold.
autorewrite with cancel_pair push_id uncps.
reflexivity.
Qed.
-
+
Local Existing Instances Z.equiv_modulo_Reflexive
RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric
Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper
@@ -344,8 +342,7 @@ Context (weight : nat -> Z)
Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
(eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p.
Proof.
- cbv [id_with_alt_bounds_and_proof
- id_with_alt_bounds goldilocks_mul goldilocks_mul_cps].
+ cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt_proof.
Zmod_to_equiv_modulo.
repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval.
repeat match goal with
diff --git a/src/Specific/IntegrationTestKaratsubaMul.v b/src/Specific/IntegrationTestKaratsubaMul.v
new file mode 100644
index 000000000..7f48c0c9f
--- /dev/null
+++ b/src/Specific/IntegrationTestKaratsubaMul.v
@@ -0,0 +1,70 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Util.FixedWordSizes.
+Require Import Crypto.Specific.Karatsuba.
+Require Import Crypto.Arithmetic.PrimeFieldTheorems.
+Require Import Crypto.Util.Tuple Crypto.Util.Sigma Crypto.Util.Sigma.MapProjections Crypto.Util.Sigma.Lift Crypto.Util.Notations Crypto.Util.ZRange Crypto.Util.BoundedWord.
+Require Import Crypto.Util.Tactics.Head.
+Require Import Crypto.Util.Tactics.MoveLetIn.
+Import ListNotations.
+
+Require Import Crypto.Specific.IntegrationTestTemporaryMiscCommon.
+
+Require Import Crypto.Compilers.Z.Bounds.Pipeline.
+
+Section BoundedField25p5.
+ Local Coercion Z.of_nat : nat >-> Z.
+
+ Let limb_widths := Eval vm_compute in (List.map (fun i => Z.log2 (wt (S i) / wt i)) (seq 0 sz)).
+ Let length_lw := Eval compute in List.length limb_widths.
+
+ Local Notation b_of exp := {| lower := 0 ; upper := 2^exp + 2^(exp-3) |}%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *)
+ (* The definition [bounds_exp] is a tuple-version of the
+ limb-widths, which are the [exp] argument in [b_of] above, i.e.,
+ the approximate base-2 exponent of the bounds on the limb in that
+ position. *)
+ Let bounds_exp : Tuple.tuple Z length_lw
+ := Eval compute in
+ Tuple.from_list length_lw limb_widths eq_refl.
+ Let bounds : Tuple.tuple zrange length_lw
+ := Eval compute in
+ Tuple.map (fun e => b_of e) bounds_exp.
+
+ Let lgbitwidth := Eval compute in (Z.to_nat (Z.log2_up (List.fold_right Z.max 0 limb_widths))).
+ Let bitwidth := Eval compute in (2^lgbitwidth)%nat.
+ Let feZ : Type := tuple Z sz.
+ Let feW : Type := tuple (wordT lgbitwidth) sz.
+ Let feBW : Type := BoundedWord sz bitwidth bounds.
+ Let phi : feBW -> F m :=
+ fun x => B.Positional.Fdecode wt (BoundedWordToZ _ _ _ x).
+
+ (* TODO : change this to field once field isomorphism happens *)
+ Definition mul :
+ { mul : feBW -> feBW -> feBW
+ | forall a b, phi (mul a b) = F.mul (phi a) (phi b) }.
+ Proof.
+ lazymatch goal with
+ | [ |- { f | forall a b, ?phi (f a b) = @?rhs a b } ]
+ => apply lift2_sig with (P:=fun a b f => phi f = rhs a b)
+ end.
+ intros a b.
+ eexists_sig_etransitivity. all:cbv [phi].
+ rewrite <- (proj2_sig mul_sig).
+ symmetry; rewrite <- (proj2_sig carry_sig); symmetry.
+ set (carry_mulZ := fun a b => proj1_sig carry_sig (proj1_sig mul_sig a b)).
+ change (proj1_sig carry_sig (proj1_sig mul_sig ?a ?b)) with (carry_mulZ a b).
+ context_to_dlet_in_rhs carry_mulZ.
+ cbv beta iota delta [carry_mulZ proj1_sig mul_sig carry_sig fst snd runtime_add runtime_and runtime_mul runtime_opp runtime_shr sz].
+ reflexivity.
+ sig_dlet_in_rhs_to_context.
+ apply (fun f => proj2_sig_map (fun THIS_NAME_MUST_NOT_BE_UNDERSCORE_TO_WORK_AROUND_CONSTR_MATCHING_ANAOMLIES___BUT_NOTE_THAT_IF_THIS_NAME_IS_LOWERCASE_A___THEN_REIFICATION_STACK_OVERFLOWS___AND_I_HAVE_NO_IDEA_WHATS_GOING_ON p => f_equal f p)).
+ (* jgross start here! *)
+ (*Set Ltac Profiling.*)
+ Time admit; refine_reflectively.
+ (*Show Ltac Profile.*)
+ Time Admitted. (* Finished transaction in 10.167 secs (10.123u,0.023s) (successful) *)
+
+End BoundedField25p5.
diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v
index dc2f42d44..dd85203b7 100644
--- a/src/Specific/Karatsuba.v
+++ b/src/Specific/Karatsuba.v
@@ -6,6 +6,7 @@ Require Import (*Crypto.Util.Tactics*) Crypto.Util.Decidable.
Require Import Crypto.Util.LetIn Crypto.Util.ZUtil Crypto.Util.Tactics.
Require Import Crypto.Arithmetic.Karatsuba.
Require Crypto.Util.Tuple.
+Require Import Crypto.Util.IdfunWithAlt.
Local Notation tuple := Tuple.tuple.
Local Open Scope list_scope.
Local Open Scope Z_scope.
@@ -207,9 +208,22 @@ Section Ops51.
cbv [mod_eq]; apply f_equal2;
[ | reflexivity ]; apply f_equal.
cbv [goldilocks_mul].
- transitivity (id_with_alt_bounds_and_proof (pf := goldilocks_mul_sig_equiv a b) ((proj1_sig goldilocks_mul_sig) a b) ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b)).
- { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig]. reflexivity. }
- { cbv [id_with_alt_bounds_and_proof id_with_alt_bounds].
+ transitivity
+ (Tuple.eta_tuple
+ (fun a
+ => Tuple.eta_tuple
+ (fun b
+ => id_tuple_with_alt_proof
+ (pf := goldilocks_mul_sig_equiv a b)
+ ((proj1_sig goldilocks_mul_sig) a b)
+ ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b))
+ b)
+ a).
+ { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig Tuple.eta_tuple Tuple.eta_tuple_dep sz Tuple.eta_tuple'_dep id_tuple_with_alt_proof id_tuple'_with_alt_proof];
+ cbn [fst snd].
+ cbv [id_with_alt_proof].
+ reflexivity. }
+ { rewrite !Tuple.strip_eta_tuple, !unfold_id_tuple_with_alt_proof.
rewrite (proj2_sig goldilocks_mul_sig). reflexivity. }
Defined.