diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-11 19:15:15 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-11 19:15:41 -0400 |
commit | 97c72ad6da000682171c819ba712c6c68a09686f (patch) | |
tree | fcf95c073028869bfbb88acc03377e5a5fd70b76 /src | |
parent | e8b12aeec4abea243b7f0be1100a6f33a6ca15ad (diff) |
Factor karatsuba through IdfunWithAlt, add test
Currently the refinement is commented out.
Also, we drop the proof of equality early (and probably won't use it in
the first place); there's no way we can carry around such a proof in
reflective-land, so we'll need to add an arithmetical-equality
semi-decider to reflective-land that can prove the relevant equalities
(or we'll need to leave them over as side-conditions). I expeect this
may make things significantly easier on @jadephilipoom.
Diffstat (limited to 'src')
-rw-r--r-- | src/Arithmetic/Karatsuba.v | 51 | ||||
-rw-r--r-- | src/Specific/IntegrationTestKaratsubaMul.v | 70 | ||||
-rw-r--r-- | src/Specific/Karatsuba.v | 20 |
3 files changed, 111 insertions, 30 deletions
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v index 3c1471365..d518d1dbd 100644 --- a/src/Arithmetic/Karatsuba.v +++ b/src/Arithmetic/Karatsuba.v @@ -3,6 +3,7 @@ Require Import Crypto.Algebra.Nsatz. Require Import Crypto.Util.ZUtil Crypto.Util.LetIn Crypto.Util.CPSUtil Crypto.Util.Tactics. Require Import Crypto.Arithmetic.Core. Import B. Import Positional. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.IdfunWithAlt. Local Open Scope Z_scope. Section Karatsuba. @@ -15,10 +16,10 @@ Context (weight : nat -> Z) Let T := tuple Z n. Let T2 := tuple Z n2. - (* - If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0, + (* + If x = x0 + sx1 and y = y0 + sy1, then xy = s^2 * z2 + s * z1 + s * z0, with: - + z2 = x1y1 z0 = x0y0 z1 = (x1+x0)(y1+y0) - (z2 + z0) @@ -77,19 +78,16 @@ Context (weight : nat -> Z) actually run and a version to bounds-check, along with a proof that they are exactly equal. This works around cases where the bounds proof requires high-level reasoning. *) - Definition id_with_alt_bounds {A} (value : A) (value_for_alt_bounds : A) : A - := value. - Definition id_with_alt_bounds_and_proof {A} (value : A) (value_for_alt_bounds : A) - {pf : value = value_for_alt_bounds} - := id_with_alt_bounds value value_for_alt_bounds. - + Local Notation id_with_alt_bounds := id_tuple_with_alt. + Local Notation id_with_alt_bounds_and_proof := id_tuple_with_alt_proof. + (* If: s^2 mod p = (s + 1) mod p x = x0 + sx1 y = y0 + sy1 Then, with z0 and z2 as before (x0y0 and x1y1 respectively), let z1 = ((x0 + x1) * (y0 + y1)) - z0. - + Computing xy one operation at a time: sum_z = z0 + z2 sum_x = x0 + x1 @@ -104,13 +102,13 @@ Context (weight : nat -> Z) bounds of the values would indicate that it could underflow--we know it won't because - mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0 - = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0 + mul_sumxy -z0 = ((x0+x1) * (y0+y1)) - x0y0 + = (x0y0 + x1y0 + x0y1 + x1y1) - x0y0 = x1y0 + x0y1 + x1y1 Therefore, we use id_with_alt_bounds to indicate that the bounds-checker should check the non-subtracting form. - + *) Definition goldilocks_mul_cps_for_bounds_checker @@ -126,7 +124,7 @@ Context (weight : nat -> Z) (fun z1' => add_cps weight z1' z2 (fun z1 => scmul_cps weight s z1 (fun sz1 => add_cps weight sum_z sz1 f)))))))))). - + Definition goldilocks_mul_cps s (xs ys : T2) {R} (f:T2->R) := split_cps (m1:=n) (m2:=n) weight s xs (fun x0_x1 => split_cps weight s ys @@ -190,11 +188,11 @@ Context (weight : nat -> Z) Admitted. Local Infix "**" := Associational.mul (at level 40). - + Local Definition multerm terms := Associational.multerm (fst terms) (snd terms). - - Lemma mul_power_equiv (p q : list limb) : + + Lemma mul_power_equiv (p q : list limb) : Permutation.permutation (p ** q) (List.map multerm (list_prod p q)). @@ -254,7 +252,7 @@ Context (weight : nat -> Z) Lemma subtraction_id N p q : from N ((p ++ Associational.negate_snd p) ++ q) = from N q. Admitted. - + Lemma goldilocks_mul_equiv' x0 x1 y0 y1 : let X0 := to (from n x0) in let X1 := to (from n x1) in @@ -281,10 +279,10 @@ Context (weight : nat -> Z) | |- _ = from ?n (?a ++ ?b ++ ?c ++ ?d ++ Associational.negate_snd ?a) => transitivity (from n ((a ++ Associational.negate_snd a) ++ b ++ c ++ d)); [|remember a as A; remember b as B; remember c as C; remember d as D; remember (Associational.negate_snd A) as negA] - + end. Focus 2. - { rewrite app_assoc_reverse. + { rewrite app_assoc_reverse. apply permutation_from_associational. replace (A ++ B ++ C ++ D ++ negA) with (A ++ (B ++ C ++ D) ++ negA). auto using app_assoc, app_assoc_reverse. @@ -298,7 +296,7 @@ Context (weight : nat -> Z) end. reflexivity. Qed. - + Lemma goldilocks_mul_equiv s xs ys {R} f: @goldilocks_mul_cps s xs ys R f = @goldilocks_mul_cps_for_bounds_checker s xs ys R f. @@ -323,19 +321,19 @@ Context (weight : nat -> Z) Qed. Definition goldilocks_mul s xs ys := - id_with_alt_bounds_and_proof (pf := goldilocks_mul_equiv _ _ _ _) + id_with_alt_bounds_and_proof + (pf := goldilocks_mul_equiv _ _ _ _) (@goldilocks_mul_cps s xs ys _ id) (@goldilocks_mul_cps_for_bounds_checker s xs ys _ id). Lemma goldilocks_mul_id s xs ys {R} f : @goldilocks_mul_cps s xs ys R f = f (goldilocks_mul s xs ys). Proof. - cbv [id_with_alt_bounds_and_proof - id_with_alt_bounds goldilocks_mul goldilocks_mul_cps]. + cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt_proof. repeat autounfold. autorewrite with cancel_pair push_id uncps. reflexivity. Qed. - + Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper @@ -344,8 +342,7 @@ Context (weight : nat -> Z) Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : (eval weight (goldilocks_mul s xs ys)) mod p = (eval weight xs * eval weight ys) mod p. Proof. - cbv [id_with_alt_bounds_and_proof - id_with_alt_bounds goldilocks_mul goldilocks_mul_cps]. + cbv [goldilocks_mul goldilocks_mul_cps]; rewrite !unfold_id_tuple_with_alt_proof. Zmod_to_equiv_modulo. repeat autounfold; autorewrite with push_id cancel_pair uncps push_basesystem_eval. repeat match goal with diff --git a/src/Specific/IntegrationTestKaratsubaMul.v b/src/Specific/IntegrationTestKaratsubaMul.v new file mode 100644 index 000000000..7f48c0c9f --- /dev/null +++ b/src/Specific/IntegrationTestKaratsubaMul.v @@ -0,0 +1,70 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Util.FixedWordSizes. +Require Import Crypto.Specific.Karatsuba. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.Tuple Crypto.Util.Sigma Crypto.Util.Sigma.MapProjections Crypto.Util.Sigma.Lift Crypto.Util.Notations Crypto.Util.ZRange Crypto.Util.BoundedWord. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Tactics.MoveLetIn. +Import ListNotations. + +Require Import Crypto.Specific.IntegrationTestTemporaryMiscCommon. + +Require Import Crypto.Compilers.Z.Bounds.Pipeline. + +Section BoundedField25p5. + Local Coercion Z.of_nat : nat >-> Z. + + Let limb_widths := Eval vm_compute in (List.map (fun i => Z.log2 (wt (S i) / wt i)) (seq 0 sz)). + Let length_lw := Eval compute in List.length limb_widths. + + Local Notation b_of exp := {| lower := 0 ; upper := 2^exp + 2^(exp-3) |}%Z (only parsing). (* max is [(0, 2^(exp+2) + 2^exp + 2^(exp-1) + 2^(exp-3) + 2^(exp-4) + 2^(exp-5) + 2^(exp-6) + 2^(exp-10) + 2^(exp-12) + 2^(exp-13) + 2^(exp-14) + 2^(exp-15) + 2^(exp-17) + 2^(exp-23) + 2^(exp-24))%Z] *) + (* The definition [bounds_exp] is a tuple-version of the + limb-widths, which are the [exp] argument in [b_of] above, i.e., + the approximate base-2 exponent of the bounds on the limb in that + position. *) + Let bounds_exp : Tuple.tuple Z length_lw + := Eval compute in + Tuple.from_list length_lw limb_widths eq_refl. + Let bounds : Tuple.tuple zrange length_lw + := Eval compute in + Tuple.map (fun e => b_of e) bounds_exp. + + Let lgbitwidth := Eval compute in (Z.to_nat (Z.log2_up (List.fold_right Z.max 0 limb_widths))). + Let bitwidth := Eval compute in (2^lgbitwidth)%nat. + Let feZ : Type := tuple Z sz. + Let feW : Type := tuple (wordT lgbitwidth) sz. + Let feBW : Type := BoundedWord sz bitwidth bounds. + Let phi : feBW -> F m := + fun x => B.Positional.Fdecode wt (BoundedWordToZ _ _ _ x). + + (* TODO : change this to field once field isomorphism happens *) + Definition mul : + { mul : feBW -> feBW -> feBW + | forall a b, phi (mul a b) = F.mul (phi a) (phi b) }. + Proof. + lazymatch goal with + | [ |- { f | forall a b, ?phi (f a b) = @?rhs a b } ] + => apply lift2_sig with (P:=fun a b f => phi f = rhs a b) + end. + intros a b. + eexists_sig_etransitivity. all:cbv [phi]. + rewrite <- (proj2_sig mul_sig). + symmetry; rewrite <- (proj2_sig carry_sig); symmetry. + set (carry_mulZ := fun a b => proj1_sig carry_sig (proj1_sig mul_sig a b)). + change (proj1_sig carry_sig (proj1_sig mul_sig ?a ?b)) with (carry_mulZ a b). + context_to_dlet_in_rhs carry_mulZ. + cbv beta iota delta [carry_mulZ proj1_sig mul_sig carry_sig fst snd runtime_add runtime_and runtime_mul runtime_opp runtime_shr sz]. + reflexivity. + sig_dlet_in_rhs_to_context. + apply (fun f => proj2_sig_map (fun THIS_NAME_MUST_NOT_BE_UNDERSCORE_TO_WORK_AROUND_CONSTR_MATCHING_ANAOMLIES___BUT_NOTE_THAT_IF_THIS_NAME_IS_LOWERCASE_A___THEN_REIFICATION_STACK_OVERFLOWS___AND_I_HAVE_NO_IDEA_WHATS_GOING_ON p => f_equal f p)). + (* jgross start here! *) + (*Set Ltac Profiling.*) + Time admit; refine_reflectively. + (*Show Ltac Profile.*) + Time Admitted. (* Finished transaction in 10.167 secs (10.123u,0.023s) (successful) *) + +End BoundedField25p5. diff --git a/src/Specific/Karatsuba.v b/src/Specific/Karatsuba.v index dc2f42d44..dd85203b7 100644 --- a/src/Specific/Karatsuba.v +++ b/src/Specific/Karatsuba.v @@ -6,6 +6,7 @@ Require Import (*Crypto.Util.Tactics*) Crypto.Util.Decidable. Require Import Crypto.Util.LetIn Crypto.Util.ZUtil Crypto.Util.Tactics. Require Import Crypto.Arithmetic.Karatsuba. Require Crypto.Util.Tuple. +Require Import Crypto.Util.IdfunWithAlt. Local Notation tuple := Tuple.tuple. Local Open Scope list_scope. Local Open Scope Z_scope. @@ -207,9 +208,22 @@ Section Ops51. cbv [mod_eq]; apply f_equal2; [ | reflexivity ]; apply f_equal. cbv [goldilocks_mul]. - transitivity (id_with_alt_bounds_and_proof (pf := goldilocks_mul_sig_equiv a b) ((proj1_sig goldilocks_mul_sig) a b) ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b)). - { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig]. reflexivity. } - { cbv [id_with_alt_bounds_and_proof id_with_alt_bounds]. + transitivity + (Tuple.eta_tuple + (fun a + => Tuple.eta_tuple + (fun b + => id_tuple_with_alt_proof + (pf := goldilocks_mul_sig_equiv a b) + ((proj1_sig goldilocks_mul_sig) a b) + ((proj1_sig goldilocks_mul_for_bounds_checker_sig) a b)) + b) + a). + { cbv [proj1_sig goldilocks_mul_for_bounds_checker_sig goldilocks_mul_sig Tuple.eta_tuple Tuple.eta_tuple_dep sz Tuple.eta_tuple'_dep id_tuple_with_alt_proof id_tuple'_with_alt_proof]; + cbn [fst snd]. + cbv [id_with_alt_proof]. + reflexivity. } + { rewrite !Tuple.strip_eta_tuple, !unfold_id_tuple_with_alt_proof. rewrite (proj2_sig goldilocks_mul_sig). reflexivity. } Defined. |