diff options
author | Jason Gross <jgross@mit.edu> | 2017-10-15 22:36:36 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-10-18 23:01:29 -0400 |
commit | 53e79344a3bc607a634433664ec2c43337a6aad9 (patch) | |
tree | a01b788b78f7f209b844a3d9102a0271ece90482 /src/Specific/Framework | |
parent | a17f3ea44a09638f0d78428290a96ecce613ad65 (diff) |
Karatsuba in gallina
After | File Name | Before || Change
------------------------------------------------------------------------------------------
8m08.69s | Total | 8m07.12s || +0m01.57s
------------------------------------------------------------------------------------------
2m02.96s | Specific/NISTP256/AMD64/femul | 1m57.80s || +0m05.15s
3m25.28s | Specific/X25519/C64/ladderstep | 3m28.68s || -0m03.40s
0m25.02s | Specific/X25519/C64/femul | 0m25.02s || +0m00.00s
0m23.68s | Specific/NISTP256/AMD64/fesub | 0m24.02s || -0m00.33s
0m21.80s | Specific/NISTP256/AMD64/feadd | 0m22.21s || -0m00.41s
0m20.38s | Specific/X25519/C64/freeze | 0m20.25s || +0m00.12s
0m19.19s | Specific/X25519/C64/fesquare | 0m19.60s || -0m00.41s
0m17.95s | Specific/NISTP256/AMD64/feopp | 0m18.02s || -0m00.07s
0m15.15s | Specific/NISTP256/AMD64/fenz | 0m15.14s || +0m00.00s
0m08.31s | Specific/NISTP256/AMD64/Synthesis | 0m08.21s || +0m00.09s
0m05.94s | Specific/X25519/C64/Synthesis | 0m05.70s || +0m00.24s
0m01.28s | Specific/Framework/ArithmeticSynthesis/Karatsuba | 0m00.74s || +0m00.54s
0m01.05s | Specific/Framework/SynthesisFramework | 0m01.05s || +0m00.00s
0m00.71s | Specific/Framework/ArithmeticSynthesis/KaratsubaPackage | 0m00.68s || +0m00.02s
Diffstat (limited to 'src/Specific/Framework')
-rw-r--r-- | src/Specific/Framework/ArithmeticSynthesis/Karatsuba.v | 166 | ||||
-rw-r--r-- | src/Specific/Framework/ArithmeticSynthesis/KaratsubaPackage.v | 8 |
2 files changed, 65 insertions, 109 deletions
diff --git a/src/Specific/Framework/ArithmeticSynthesis/Karatsuba.v b/src/Specific/Framework/ArithmeticSynthesis/Karatsuba.v index 78cc57d68..4dc83ce1a 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/Karatsuba.v +++ b/src/Specific/Framework/ArithmeticSynthesis/Karatsuba.v @@ -1,5 +1,7 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef. +Require Import Coq.QArith.QArith_base. Require Import Coq.Lists.List. Import ListNotations. +Require Import Crypto.Arithmetic.CoreUnfolder. Require Import Crypto.Arithmetic.Core. Import B. Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Crypto.Specific.Framework.CurveParameters. @@ -14,127 +16,81 @@ Require Import Crypto.Util.QUtil. Require Import Crypto.Util.ZUtil.ModInv. Require Import Crypto.Specific.Framework.ArithmeticSynthesis.SquareFromMul. +Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Base. +Require Import Crypto.Specific.Framework.ArithmeticSynthesis.HelperTactics. Require Import Crypto.Util.Tactics.PoseTermWithName. Require Import Crypto.Util.Tactics.CacheTerm. -Local Notation tuple := Tuple.tuple. -Local Open Scope list_scope. Local Open Scope Z_scope. -Local Coercion Z.of_nat : nat >-> Z. Local Infix "^" := Tuple.tuple : type_scope. -(** XXX TODO(jadep) FIXME: Is sqrt(s) the right thing to pass to goldilocks_mul_cps (the original code hard-coded 2^224 *) +(** XXX TODO(jadep) FIXME: Should we sanity-check that we have 2^2k - 2^k - 1 / the right form of prime? *) Ltac internal_pose_sqrt_s s sqrt_s := let v := (eval vm_compute in (Z.log2 s / 2)) in cache_term (2^v) sqrt_s. -Ltac basesystem_partial_evaluation_RHS := - let t0 := (match goal with - | |- _ _ ?t => t - end) in - let t := - eval - cbv - delta [Positional.to_associational_cps Positional.to_associational - Positional.eval Positional.zeros Positional.add_to_nth_cps - Positional.add_to_nth Positional.place_cps Positional.place - Positional.from_associational_cps Positional.from_associational - Positional.carry_cps Positional.carry - Positional.chained_carries_cps Positional.chained_carries - Positional.sub_cps Positional.sub Positional.split_cps - Positional.scmul_cps Positional.unbalanced_sub_cps - Positional.negate_snd_cps Positional.add_cps Positional.opp_cps - Associational.eval Associational.multerm Associational.mul_cps - Associational.mul Associational.split_cps Associational.split - Associational.reduce_cps Associational.reduce - Associational.carryterm_cps Associational.carryterm - Associational.carry_cps Associational.carry - Associational.negate_snd_cps Associational.negate_snd div modulo - id_tuple_with_alt id_tuple'_with_alt - ] - in t0 - in - let t := eval pattern @runtime_mul in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @runtime_add in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @runtime_opp in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @runtime_shr in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @runtime_and in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @Let_In in t in - let t := (match t with - | ?t _ => t - end) in - let t := eval pattern @id_with_alt in t in - let t := (match t with - | ?t _ => t - end) in - let t1 := fresh "t1" in - pose (t1 := t); - transitivity - (t1 (@id_with_alt) (@Let_In) (@runtime_and) (@runtime_shr) (@runtime_opp) (@runtime_add) - (@runtime_mul)); - [ replace_with_vm_compute t1; clear t1 | reflexivity ]. +Section gen. + Context (m : positive) + (base : Q) + (sz : nat) + (s : Z) + (c : list limb) + (half_sz : nat) + (sqrt_s : Z) + (base_pos : (1 <= base)%Q) + (sz_nonzero : sz <> 0%nat) + (half_sz_nonzero : half_sz <> 0%nat) + (s_nonzero : s <> 0%Z) + (m_correct : Z.pos m = s - Associational.eval c) + (sqrt_s_nonzero : sqrt_s <> 0) + (sqrt_s_mod_m : sqrt_s ^ 2 mod Z.pos m = (sqrt_s + 1) mod Z.pos m). -Ltac internal_pose_goldilocks_mul_sig sz wt s c half_sz sqrt_s goldilocks_mul_sig := - cache_term_with_type_by - {mul : (Z^sz -> Z^sz -> Z^sz)%type | - forall a b : Z^sz, - mul a b = goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt sqrt_s a b (fun ab => Positional.reduce_cps (n:=sz) wt s c ab id)} - ltac:(eexists; cbv beta zeta; intros; - cbv [goldilocks_mul_cps]; - repeat autounfold; - basesystem_partial_evaluation_RHS; - do_replace_match_with_destructuring_match_in_goal; - reflexivity) - goldilocks_mul_sig. + Local Notation wt := (wt_gen base). + Local Notation wt_divides' := (wt_gen_divides' base base_pos). + Local Notation wt_nonzero := (wt_gen_nonzero base base_pos). -Ltac internal_pose_mul_sig_from_goldilocks_mul_sig sz m wt s c half_sz sqrt_s goldilocks_mul_sig wt_nonzero mul_sig := - cache_term_with_type_by - {mul : (Z^sz -> Z^sz -> Z^sz)%type | - forall a b : Z^sz, - let eval := Positional.Fdecode (m := m) wt in - Positional.Fdecode (m := m) wt (mul a b) = (eval a * eval b)%F} - ltac:(idtac; - let a := fresh "a" in - let b := fresh "b" in - eexists; cbv beta zeta; intros a b; - pose proof wt_nonzero; - let x := constr:( - goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt sqrt_s a b (fun ab => Positional.reduce_cps (n:=sz) wt s c ab id)) in - F_mod_eq; - transitivity (Positional.eval wt x); repeat autounfold; + Definition goldilocks_mul_sig' + : { mul : (Z^sz -> Z^sz -> Z^sz)%type + | forall a b : Z^sz, + mul a b = goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt sqrt_s a b (fun ab => Positional.reduce_cps (n:=sz) wt s c ab id) }. + Proof. + eexists; cbv beta zeta; intros. + cbv [goldilocks_mul_cps]. + autorewrite with pattern_runtime. + reflexivity. + Defined. - [ - | autorewrite with uncps push_id push_basesystem_eval; - apply goldilocks_mul_correct; try assumption; cbv; congruence ]; - cbv [mod_eq]; apply f_equal2; - [ | reflexivity ]; - apply f_equal; - etransitivity; [|apply (proj2_sig goldilocks_mul_sig)]; - cbv [proj1_sig goldilocks_mul_sig]; - reflexivity) - mul_sig. + Definition mul_sig' + : { mul : (Z^sz -> Z^sz -> Z^sz)%type + | forall a b : Z^sz, + let eval := Positional.Fdecode (m := m) wt in + Positional.Fdecode (m := m) wt (mul a b) = (eval a * eval b)%F }. + Proof. + eexists; cbv beta zeta; intros a b. + pose proof wt_nonzero. + pose proof (wt_gen0_1 base). + let x := constr:( + goldilocks_mul_cps (n:=half_sz) (n2:=sz) wt sqrt_s a b (fun ab => Positional.reduce_cps (n:=sz) wt s c ab id)) in + presolve_op_F constr:(wt) x; + [ cbv [goldilocks_mul_cps]; + autorewrite with pattern_runtime; + reflexivity + | ]. + apply goldilocks_mul_correct; auto; try congruence. + Defined. +End gen. -Ltac pose_mul_sig sz m wt s c half_sz wt_nonzero mul_sig := +Ltac pose_mul_sig wt m base sz s c half_sz mul_sig := let sqrt_s := fresh "sqrt_s" in - let goldilocks_mul_sig := fresh "goldilocks_mul_sig" in let sqrt_s := internal_pose_sqrt_s s sqrt_s in - let goldilocks_mul_sig := internal_pose_goldilocks_mul_sig sz wt s c half_sz sqrt_s goldilocks_mul_sig in - internal_pose_mul_sig_from_goldilocks_mul_sig sz m wt s c half_sz sqrt_s goldilocks_mul_sig wt_nonzero mul_sig. + cache_sig_with_type_by_existing_sig_helper + ltac:(fun _ => cbv [mul_sig']) + { mul : (Z^sz -> Z^sz -> Z^sz)%type + | forall a b : Z^sz, + let eval := Positional.Fdecode (m := m) wt in + Positional.Fdecode (m := m) wt (mul a b) = (eval a * eval b)%F } + (mul_sig' m base sz s c half_sz sqrt_s) + mul_sig. Ltac pose_square_sig sz m wt mul_sig square_sig := SquareFromMul.pose_square_sig sz m wt mul_sig square_sig. diff --git a/src/Specific/Framework/ArithmeticSynthesis/KaratsubaPackage.v b/src/Specific/Framework/ArithmeticSynthesis/KaratsubaPackage.v index 264effeb1..2e7d98c1e 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/KaratsubaPackage.v +++ b/src/Specific/Framework/ArithmeticSynthesis/KaratsubaPackage.v @@ -10,15 +10,15 @@ Require Import Crypto.Util.TagList. Ltac add_mul_sig pkg := if_goldilocks pkg - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let s := Tag.get pkg TAG.s in let c := Tag.get pkg TAG.c in let half_sz := Tag.get pkg TAG.half_sz in - let wt_nonzero := Tag.get pkg TAG.wt_nonzero in let mul_sig := fresh "mul_sig" in - let mul_sig := pose_mul_sig sz m wt s c half_sz wt_nonzero mul_sig in + let mul_sig := pose_mul_sig wt m base sz s c half_sz mul_sig in Tag.update pkg TAG.mul_sig mul_sig) ltac:(fun _ => pkg) (). |