diff options
author | Jason Gross <jgross@mit.edu> | 2017-10-19 14:44:48 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-10-19 15:40:23 -0400 |
commit | 7e939cd63236d0a6a492ddff5015daf3f706a3bc (patch) | |
tree | fa19e772dc624eb7899017b55e527de184e7bf8f /src/Arithmetic | |
parent | 79b586e4589f56d081301de92b305569c1077ed2 (diff) |
Switch arithmetic to cps for Z * Z under the hood
This is in preparation for writing a ~compiler for the arithmetic things
to expression trees.
I'm not sure what's up with femul in the table below; I ran it again and
got:
After:
src/Specific/NISTP256/AMD64/femul (real: 115.70, user: 115.25, sys: 0.44, mem: 3571448 ko)
Before:
src/Specific/NISTP256/AMD64/femul (real: 118.49, user: 117.99, sys: 0.43, mem: 3581612 ko)
After | File Name | Before || Change
---------------------------------------------------------------------------------------------
17m02.82s | Total | 16m36.20s || +0m26.61s
---------------------------------------------------------------------------------------------
2m27.04s | Specific/NISTP256/AMD64/femul | 2m04.60s || +0m22.43s
1m38.55s | Specific/X2448/Karatsuba/C64/femul | 1m41.44s || -0m02.89s
0m12.46s | Arithmetic/Saturated/AddSub | 0m09.77s || +0m02.69s
3m22.38s | Specific/X25519/C64/ladderstep | 3m23.49s || -0m01.11s
0m54.40s | Specific/X25519/C32/fesquare | 0m52.68s || +0m01.71s
0m28.70s | Arithmetic/Karatsuba | 0m27.59s || +0m01.10s
0m10.00s | Arithmetic/Saturated/MontgomeryAPI | 0m08.95s || +0m01.05s
0m08.15s | Specific/X2448/Karatsuba/C64/Synthesis | 0m09.47s || -0m01.32s
0m05.62s | Arithmetic/Saturated/MulSplit | 0m04.28s || +0m01.33s
1m29.44s | Specific/X25519/C32/femul | 1m28.55s || +0m00.89s
0m39.38s | Specific/X25519/C32/freeze | 0m38.62s || +0m00.76s
0m31.54s | Specific/NISTP256/AMD128/femul | 0m31.60s || -0m00.06s
0m24.80s | Specific/X25519/C64/femul | 0m24.10s || +0m00.69s
0m23.82s | Specific/NISTP256/AMD64/fesub | 0m23.52s || +0m00.30s
0m21.81s | Specific/NISTP256/AMD64/feadd | 0m21.90s || -0m00.08s
0m20.30s | Specific/X25519/C64/freeze | 0m20.26s || +0m00.03s
0m20.12s | Specific/X25519/C32/Synthesis | 0m20.77s || -0m00.64s
0m19.12s | Specific/X25519/C64/fesquare | 0m19.02s || +0m00.10s
0m17.28s | Specific/NISTP256/AMD64/feopp | 0m17.68s || -0m00.39s
0m15.99s | Specific/NISTP256/AMD128/fesub | 0m16.03s || -0m00.04s
0m15.88s | Specific/NISTP256/AMD128/feadd | 0m16.56s || -0m00.67s
0m15.03s | Specific/NISTP256/AMD64/fenz | 0m15.00s || +0m00.02s
0m14.18s | Specific/NISTP256/AMD128/fenz | 0m14.12s || +0m00.06s
0m13.46s | Specific/NISTP256/AMD128/feopp | 0m12.88s || +0m00.58s
0m12.15s | Arithmetic/Core | 0m12.03s || +0m00.12s
0m07.82s | Arithmetic/Saturated/Core | 0m07.05s || +0m00.77s
0m07.13s | Specific/NISTP256/AMD64/Synthesis | 0m08.05s || -0m00.92s
0m05.48s | Specific/X25519/C64/Synthesis | 0m05.68s || -0m00.19s
0m04.02s | Specific/Framework/ArithmeticSynthesis/Montgomery | 0m03.89s || +0m00.12s
0m03.52s | Arithmetic/MontgomeryReduction/WordByWord/Proofs | 0m03.34s || +0m00.18s
0m03.32s | Specific/NISTP256/AMD128/Synthesis | 0m03.46s || -0m00.14s
0m02.30s | Specific/Framework/ArithmeticSynthesis/Defaults | 0m02.31s || -0m00.01s
0m02.08s | Arithmetic/Saturated/Freeze | 0m01.94s || +0m00.14s
0m01.66s | Specific/Framework/OutputType | 0m01.66s || +0m00.00s
0m01.54s | Arithmetic/CoreUnfolder | 0m01.43s || +0m00.11s
0m01.35s | Specific/Framework/ArithmeticSynthesis/Karatsuba | 0m01.28s || +0m00.07s
0m01.13s | Arithmetic/Saturated/CoreUnfolder | 0m01.16s || -0m00.03s
0m01.06s | Arithmetic/Saturated/WrappersUnfolder | 0m01.04s || +0m00.02s
0m01.04s | Arithmetic/Saturated/UniformWeight | 0m00.95s || +0m00.09s
0m01.03s | Specific/Framework/ArithmeticSynthesis/Base | 0m01.14s || -0m00.10s
0m01.02s | Specific/Framework/SynthesisFramework | 0m01.04s || -0m00.02s
0m00.97s | Specific/Framework/ArithmeticSynthesis/HelperTactics | 0m01.01s || -0m00.04s
0m00.92s | Specific/Framework/ReificationTypes | 0m00.90s || +0m00.02s
0m00.92s | Specific/Framework/ArithmeticSynthesis/Freeze | 0m00.93s || -0m00.01s
0m00.90s | Arithmetic/Saturated/MulSplitUnfolder | 0m00.83s || +0m00.07s
0m00.83s | Specific/Framework/ReificationTypesPackage | 0m00.79s || +0m00.03s
0m00.83s | Arithmetic/Saturated/FreezeUnfolder | 0m00.86s || -0m00.03s
0m00.82s | Specific/Framework/ArithmeticSynthesis/BasePackage | 0m00.77s || +0m00.04s
0m00.81s | Specific/Framework/ArithmeticSynthesis/SquareFromMul | 0m00.72s || +0m00.09s
0m00.81s | Specific/Framework/ArithmeticSynthesis/LadderstepPackage | 0m00.82s || -0m00.00s
0m00.80s | Specific/Framework/MontgomeryReificationTypesPackage | 0m00.82s || -0m00.01s
0m00.78s | Specific/Framework/ArithmeticSynthesis/MontgomeryPackage | 0m00.79s || -0m00.01s
0m00.78s | Arithmetic/Saturated/Wrappers | 0m00.78s || +0m00.00s
0m00.76s | Specific/Framework/ArithmeticSynthesis/FreezePackage | 0m00.80s || -0m00.04s
0m00.76s | Specific/Framework/ArithmeticSynthesis/DefaultsPackage | 0m00.75s || +0m00.01s
0m00.75s | Specific/Framework/MontgomeryReificationTypes | 0m00.78s || -0m00.03s
0m00.73s | Specific/Framework/ArithmeticSynthesis/Ladderstep | 0m00.77s || -0m00.04s
0m00.73s | Arithmetic/MontgomeryReduction/WordByWord/Definition | 0m00.80s || -0m00.07s
0m00.72s | Arithmetic/Saturated/UniformWeightInstances | 0m00.78s || -0m00.06s
0m00.68s | Specific/Framework/ArithmeticSynthesis/KaratsubaPackage | 0m00.76s || -0m00.07s
0m00.43s | Util/ZUtil/CPS | 0m00.42s || +0m00.01s
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/Core.v | 73 | ||||
-rw-r--r-- | src/Arithmetic/CoreUnfolder.v | 15 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v | 6 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/AddSub.v | 70 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Core.v | 43 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Freeze.v | 2 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MontgomeryAPI.v | 22 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MulSplit.v | 48 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Wrappers.v | 9 |
9 files changed, 189 insertions, 99 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 430e1c19a..f2d9ee00b 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -585,9 +585,10 @@ Module B. Context {T : Type}. Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) := - if dec (fst t mod weight i = 0) + Z.eqb_cps (fst t mod weight i) 0 (fun eqb => + if eqb then f (i, let c := fst t / weight i in (c * snd t)%RT) - else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end). End place_cps. Definition place t i := place_cps t i id. @@ -599,12 +600,13 @@ Module B. Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat. - Proof using Type. induction n; simpl; break_match; simpl; omega. Qed. + Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed. Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t. Proof using Type*. - induction i; cbv [id]; simpl place_cps; break_match; + induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match; + Z.ltb_to_lt; autorewrite with cancel_pair; try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end; nsatz || auto. @@ -962,38 +964,59 @@ End B. (* Modulo and div that do shifts if possible, otherwise normal mod/div *) Section DivMod. - Definition modulo (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT - else Z.modulo a b. - - Definition div (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := Z.log2 b in (a >> x)%RT - else Z.div a b. - - Lemma div_correct a b : div a b = Z.div a b. + Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT + else f (Z.modulo a b)). + + Definition div_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := Z.log2 b in f ((a >> x)%RT) + else f (Z.div a b)). + + Definition modulo (a b : Z) : Z := modulo_cps a b id. + Definition div (a b : Z) : Z := div_cps a b id. + + Lemma modulo_id {T} a b f + : @modulo_cps T a b f = f (modulo a b). + Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque modulo : uncps. + Hint Rewrite @modulo_id : uncps. + + Lemma div_id {T} a b f + : @div_cps T a b f = f (div a b). + Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque div : uncps. + Hint Rewrite @div_id : uncps. + + Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b). Proof. - cbv [div]; intros. break_match; try reflexivity. + cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. - Lemma modulo_correct a b : modulo a b = Z.modulo a b. + Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b). Proof. - cbv [modulo]; intros. break_match; try reflexivity. + cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.land_ones by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. + Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id. + Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id. + Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b. Proof. - cbv [div modulo]; intros. break_match; auto using Z.div_mod. - rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg. - pose proof (Z.div_mod a b H). congruence. + rewrite div_correct, modulo_correct; auto using Z.div_mod. Qed. End DivMod. +Hint Opaque div modulo : uncps. +Hint Rewrite @div_id @modulo_id : uncps. + Import B. Create HintDb basesystem_partial_evaluation_unfolder. @@ -1045,7 +1068,7 @@ Hint Unfold Positional.eval_from Positional.select_cps Positional.select - modulo div + modulo div modulo_cps div_cps id_tuple_with_alt id_tuple'_with_alt Z.add_get_carry_full Z.add_get_carry_full_cps : basesystem_partial_evaluation_unfolder. @@ -1055,7 +1078,7 @@ Hint Unfold CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2 Decidable.dec Decidable.dec_eq_Z id_tuple_with_alt id_tuple'_with_alt - Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps + Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps' : basesystem_partial_evaluation_unfolder. diff --git a/src/Arithmetic/CoreUnfolder.v b/src/Arithmetic/CoreUnfolder.v index df005e630..cb943e80a 100644 --- a/src/Arithmetic/CoreUnfolder.v +++ b/src/Arithmetic/CoreUnfolder.v @@ -12,7 +12,7 @@ Ltac make_parameterized_sig t := Decidable.dec Decidable.dec_eq_Z id_tuple_with_alt id_tuple'_with_alt Z.add_get_carry_full Z.mul_split - Z.add_get_carry_full_cps Z.mul_split_cps + Z.add_get_carry_full_cps Z.mul_split_cps Z.mul_split_cps' Z.add_get_carry_cps]; repeat autorewrite with pattern_runtime; reflexivity. @@ -64,7 +64,7 @@ done echo " End Positional." echo "End B." echo "" -for i in modulo div; do +for i in modulo_cps div_cps modulo div; do echo "Definition ${i}_sig := parameterize_sig (@Core.${i})."; echo "Definition ${i} := parameterize_from_sig ${i}_sig."; echo "Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; @@ -140,6 +140,7 @@ done Definition carry := parameterize_from_sig carry_sig. Definition carry_eq := parameterize_eq carry carry_sig. Hint Rewrite <- carry_eq : pattern_runtime. + End Associational. Module Positional. Definition to_associational_cps_sig := parameterize_sig (@Core.B.Positional.to_associational_cps). @@ -300,6 +301,16 @@ done End Positional. End B. +Definition modulo_cps_sig := parameterize_sig (@Core.modulo_cps). +Definition modulo_cps := parameterize_from_sig modulo_cps_sig. +Definition modulo_cps_eq := parameterize_eq modulo_cps modulo_cps_sig. +Hint Rewrite <- modulo_cps_eq : pattern_runtime. + +Definition div_cps_sig := parameterize_sig (@Core.div_cps). +Definition div_cps := parameterize_from_sig div_cps_sig. +Definition div_cps_eq := parameterize_eq div_cps div_cps_sig. +Hint Rewrite <- div_cps_eq : pattern_runtime. + Definition modulo_sig := parameterize_sig (@Core.modulo). Definition modulo := parameterize_from_sig modulo_sig. Definition modulo_eq := parameterize_eq modulo modulo_sig. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index 9affa82fa..fd4869f23 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -10,6 +10,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Depende Require Import Crypto.Util.Notations. Require Import Crypto.Util.LetIn. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Local Open Scope Z_scope. @@ -47,10 +48,11 @@ Section WordByWordMontgomery. := divmod_cps A (fun '(A, a) => @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 => divmod_cps S1 (fun '(_, s) => - dlet q := fst (Z.mul_split r s k) in + Z.mul_split_cps' r s k (fun mul_split_r_s_k => + dlet q := fst mul_split_r_s_k in @scmul_cps r _ q N _ (fun qN => @add_S1_cps r _ S1 qN _ (fun S2 => divmod_cps S2 (fun '(S3, _) => - @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))). + @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4)))))))))). Section loop. Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}. diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v index e34230904..e886c36de 100644 --- a/src/Arithmetic/Saturated/AddSub.v +++ b/src/Arithmetic/Saturated/AddSub.v @@ -6,8 +6,10 @@ Require Import Crypto.Arithmetic.Core. Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.UniformWeight. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Require Import Crypto.Util.Tactics.BreakMatch. Local Notation "A ^ n" := (tuple A n) : type_scope. Module B. @@ -17,8 +19,15 @@ Module B. Let small {n} := @small s n. Section GenericOp. Context {op : Z -> Z -> Z} - {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) - {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) + {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *) + {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *) + Let op_get_carry x y := op_get_carry_cps _ x y id. + Let op_with_carry x y z := op_with_carry_cps _ x y z id. + Context {op_get_carry_id : forall {T} x y f, + @op_get_carry_cps T x y f = f (op_get_carry x y)} + {op_with_carry_id : forall {T} x y z f, + @op_with_carry_cps T x y z f = f (op_with_carry x y z)}. + Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps. Section chain_op'_cps. Context (T : Type). @@ -32,14 +41,15 @@ Module B. | S n' => fun c p q f => (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in + let op'_cps := match c with + | None => op_get_carry_cps _ + | Some x => op_with_carry_cps _ x end in + op'_cps (hd p) (hd q) (fun carry_result => + dlet carry_result := carry_result in chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) (fun carry_pq => f (fst carry_pq, - append (fst carry_result) (snd carry_pq))) + append (fst carry_result) (snd carry_pq)))) end c p q. End chain_op'_cps. Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id. @@ -50,7 +60,8 @@ Module B. @chain_op'_cps T n c p q f = f (chain_op' c p q). Proof. cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity. + simpl chain_op'_cps; cbv [Let_In]; try reflexivity; + autorewrite with uncps. { etransitivity; rewrite IHn; reflexivity. } { etransitivity; rewrite IHn; reflexivity. } Qed. @@ -60,7 +71,7 @@ Module B. Proof. apply (@chain_op'_id n None). Qed. End GenericOp. Hint Opaque chain_op chain_op' : uncps. - Hint Rewrite @chain_op_id @chain_op'_id : uncps. + Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Section AddSub. Create HintDb divmod discriminated. @@ -77,14 +88,14 @@ Module B. Let eval {n} := B.Positional.eval (n:=n) (uweight s). Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.add_get_carry_full s) - (op_with_carry := Z.add_with_get_carry_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s) + (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s) p q f. Definition sat_add {n} p q := @sat_add_cps n p q _ id. Lemma sat_add_id n p q T f : @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_add_mod_step n c d : c mod s + s * ((d + c / s) mod (uweight s n)) @@ -156,23 +167,25 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.add_with_get_carry_full_mod, - ?Z.add_get_carry_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.add_with_get_carry_full_mod, + ?Z.add_get_carry_full_mod; + apply Z.mod_pos_bound; omega ]. Qed. Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) - (op_with_carry := Z.sub_with_get_borrow_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s) + (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s) p q f. Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. Lemma sat_sub_id n p q T f : @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_sub_divmod n p q : eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n) /\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)). @@ -224,19 +237,22 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.sub_with_get_borrow_full_mod, + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.sub_with_get_borrow_full_mod, ?Z.sub_get_borrow_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + apply Z.mod_pos_bound; omega ]. Qed. End AddSub. End Positional. End Positional. End B. Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. -Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps. +Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps. +Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. Hint Unfold diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v index e34a37501..8af1cce6f 100644 --- a/src/Arithmetic/Saturated/Core.v +++ b/src/Arithmetic/Saturated/Core.v @@ -112,16 +112,20 @@ Module Columns. {weight_multiples : forall i, weight (S i) mod weight i = 0} {weight_divides : forall i : nat, weight (S i) / weight i > 0} (* add_get_carry takes in a number at which to split output *) - {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {add_get_carry_cps: forall {T}, Z ->Z -> Z -> (Z * Z -> T) -> T} + {add_get_carry_cps_id : forall {T} s x y f, + @add_get_carry_cps T s x y f = f (@add_get_carry_cps _ s x y id)} {add_get_carry_mod : forall s x y, - fst (add_get_carry s x y) = (x + y) mod s} + fst (add_get_carry_cps s x y id) = (x + y) mod s} {add_get_carry_div : forall s x y, - snd (add_get_carry s x y) = (x + y) / s} + snd (add_get_carry_cps s x y id) = (x + y) / s} {div modulo : Z -> Z -> Z} {div_correct : forall a b, div a b = a / b} {modulo_correct : forall a b, modulo a b = a mod b} . Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod. + Let add_get_carry s x y := add_get_carry_cps _ s x y id. + Hint Rewrite (add_get_carry_cps_id : forall T s x y f, _ = f (@add_get_carry s x y)) : uncps. Definition eval {n} (x : (list Z)^n) : Z := B.Positional.eval weight (Tuple.map sum x). @@ -164,25 +168,27 @@ Module Columns. | nil => f (0, 0) | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n)) | x :: y :: nil => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in + add_get_carry_cps _ (weight (S n) / weight n) x y (fun sum_carry => + dlet sum_carry := sum_carry in dlet carry := snd sum_carry in - f (carry, fst sum_carry) + f (carry, fst sum_carry)) | x :: tl => compact_digit_cps tl (fun rec => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in + add_get_carry_cps _ (weight (S n) / weight n) x (snd rec) (fun sum_carry => + dlet sum_carry := sum_carry in dlet carry' := (fst rec + snd sum_carry)%RT in - f (carry', fst sum_carry)) + f (carry', fst sum_carry))) end. End compact_digit_cps. Definition compact_digit n digit := compact_digit_cps n digit id. Lemma compact_digit_id n digit: forall {T} f, @compact_digit_cps n T digit f = f (compact_digit n digit). - Proof using Type. - induction digit; intros; cbv [compact_digit]; [reflexivity|]; - simpl compact_digit_cps; break_match; rewrite ?IHdigit; - reflexivity. + Proof using add_get_carry_cps_id. + induction digit; intros; cbv [compact_digit]; [reflexivity|]. + simpl compact_digit_cps; break_match; rewrite ?IHdigit; clear IHdigit; + cbv [Let_In]; autorewrite with uncps; reflexivity. Qed. Hint Opaque compact_digit : uncps. Hint Rewrite compact_digit_id : uncps. @@ -194,7 +200,7 @@ Module Columns. Definition compact_step i c d := compact_step_cps i c d id. Lemma compact_step_id i c d T f : @compact_step_cps i c d T f = f (compact_step i c d). - Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. + Proof using add_get_carry_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. Hint Opaque compact_step : uncps. Hint Rewrite compact_step_id : uncps. @@ -203,15 +209,15 @@ Module Columns. Definition compact {n} xs := @compact_cps n xs _ id. Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs). - Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. + Proof using add_get_carry_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. Lemma compact_digit_mod i (xs : list Z) : snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct. + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id. induction xs; cbv [compact_digit]; simpl compact_digit_cps; cbv [Let_In]; repeat match goal with - | _ => progress autorewrite with div_mod + | _ => cbv [add_get_carry]; progress autorewrite with div_mod | _ => rewrite IHxs, <-Z.add_mod_r | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) | _ => progress (autorewrite with uncps push_id cancel_pair in * ) @@ -223,11 +229,11 @@ Module Columns. Lemma compact_digit_div i (xs : list Z) : fst (compact_digit i xs) = sum xs / (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides. + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id. induction xs; cbv [compact_digit]; simpl compact_digit_cps; cbv [Let_In]; repeat match goal with - | _ => progress autorewrite with div_mod + | _ => cbv [add_get_carry]; progress autorewrite with div_mod | _ => rewrite IHxs | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) | _ => progress (autorewrite with uncps push_id cancel_pair in * ) @@ -421,6 +427,9 @@ Hint Rewrite @Columns.compact_digit_id @Columns.compact_step_id @Columns.compact_id + using (assumption || (intros; autorewrite with uncps; reflexivity)) + : uncps. +Hint Rewrite @Columns.cons_to_nth_id @Columns.from_associational_id : uncps. diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v index 65b8ee55d..b56a69a3d 100644 --- a/src/Arithmetic/Saturated/Freeze.v +++ b/src/Arithmetic/Saturated/Freeze.v @@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.Wrappers. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple Crypto.Util.LetIn. @@ -101,6 +102,7 @@ Section Freeze. pose proof Z.add_get_carry_full_mod. pose proof Z.add_get_carry_full_div. pose proof div_correct. pose proof modulo_correct. + pose proof @Z.add_get_carry_full_cps_correct. autorewrite with uncps push_id push_basesystem_eval. pose proof (weight_nonzero n). diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v index 7ff86258c..fb896749b 100644 --- a/src/Arithmetic/Saturated/MontgomeryAPI.v +++ b/src/Arithmetic/Saturated/MontgomeryAPI.v @@ -13,6 +13,7 @@ Require Import Crypto.Util.Tuple Crypto.Util.LetIn. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.ZUtil.Zselect. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.ZUtil.MulSplit. @@ -106,7 +107,9 @@ Section API. Section CPSProofs. Local Ltac prove_id := - repeat autounfold; autorewrite with uncps; reflexivity. + repeat autounfold; + repeat (intros; autorewrite with uncps push_id); + reflexivity. Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p). Proof. cbv [nonzero nonzero_cps]. prove_id. Qed. @@ -281,7 +284,10 @@ Section API. pose proof Z.add_get_carry_full_div; pose proof Z.add_get_carry_full_mod; pose proof Z.mul_split_div; pose proof Z.mul_split_mod; - pose proof div_correct; pose proof modulo_correct. + pose proof div_correct; pose proof modulo_correct; + pose proof @Z.add_get_carry_full_cps_correct; + pose proof @Z.mul_split_cps_correct; + pose proof @Z.mul_split_cps'_correct. Lemma eval_add n p q : eval (@add n p q) = eval p + eval q. @@ -308,8 +314,8 @@ Section API. Qed. Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. - Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). - Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). + Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound). + Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound). Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)). Proof. pose_all. @@ -329,7 +335,8 @@ Section API. match goal with H : _ /\ _ |- _ => destruct H end. destruct n0; subst f. { cbv [compact_digit uweight to_list to_list' In]. - rewrite Columns.compact_digit_mod by assumption. + rewrite Columns.compact_digit_mod + by (assumption || (intros; autorewrite with uncps push_id; auto)). rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?. match goal with H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end. @@ -340,7 +347,8 @@ Section API. [solve[auto]| cbv [In] in H; destruct H; [|exfalso; assumption] ]. subst x. cbv [compact_digit]. - rewrite Columns.compact_digit_mod by assumption. + rewrite Columns.compact_digit_mod + by (assumption || (intros; autorewrite with uncps push_id; auto)). rewrite !uweight_succ, Z.div_mul by (apply Z.neq_mul_0; split; auto; omega). apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } @@ -554,6 +562,8 @@ Section API. Proof. intro Hsmall. pose_all. apply eval_small in Hsmall. intros. cbv [scmul scmul_cps eval] in *. repeat autounfold. + autorewrite with uncps. + autorewrite with push_basesystem_eval. autorewrite with uncps push_id push_basesystem_eval. rewrite uweight_0, Z.mul_1_l. apply Z.mod_small. split; [solve[Z.zero_bounds]|]. cbv [uweight] in *. diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v index 98d8d0e0c..4947f422a 100644 --- a/src/Arithmetic/Saturated/MulSplit.v +++ b/src/Arithmetic/Saturated/MulSplit.v @@ -9,21 +9,34 @@ Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. Module B. Module Associational. Section Associational. - Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + Context {mul_split_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + {mul_split_cps_id : forall {T} s x y f, + @mul_split_cps T s x y f = f (@mul_split_cps _ s x y id)} {mul_split_mod : forall s x y, - fst (mul_split s x y) = (x * y) mod s} + fst (mul_split_cps s x y id) = (x * y) mod s} {mul_split_div : forall s x y, - snd (mul_split s x y) = (x * y) / s} - . + snd (mul_split_cps s x y id) = (x * y) / s} + . + + Local Lemma mul_split_cps_correct {T} s x y f + : @mul_split_cps T s x y f = f ((x * y) mod s, (x * y) / s). + Proof. + now rewrite mul_split_cps_id, <- mul_split_mod, <- mul_split_div, <- surjective_pairing. + Qed. + Hint Rewrite @mul_split_cps_correct : uncps. Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := - dlet xy := mul_split s (snd t) (snd t') in - f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil). + mul_split_cps _ s (snd t) (snd t') (fun xy => + dlet xy := xy in + f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil)). Definition sat_multerm s t t' := sat_multerm_cps s t t' id. Lemma sat_multerm_id s t t' T f : @sat_multerm_cps s t t' T f = f (sat_multerm s t t'). - Proof. reflexivity. Qed. + Proof. + unfold sat_multerm, sat_multerm_cps; + etransitivity; rewrite mul_split_cps_id; reflexivity. + Qed. Hint Opaque sat_multerm : uncps. Hint Rewrite sat_multerm_id : uncps. @@ -39,14 +52,17 @@ Module B. Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0): B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q. Proof. - cbv [sat_multerm sat_multerm_cps Let_In]; induction q; - repeat match goal with - | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * ) - | _ => progress simpl flat_map - | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div - | _ => rewrite Z.mod_eq by assumption - | _ => ring_simplify; omega - end. + cbv [sat_multerm sat_multerm_cps Let_In]; induction q; + repeat match goal with + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval in * + | _ => progress simpl flat_map + | _ => progress unfold id in * + | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div + | _ => rewrite Z.mod_eq by assumption + | _ => rewrite B.Associational.eval_nil + | _ => progress change (Z * Z)%type with B.limb + | _ => ring_simplify; omega + end. Qed. Hint Rewrite eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. @@ -68,7 +84,7 @@ Module B. End Associational. End B. Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps. -Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id : uncps. +Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. Hint Unfold diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v index 6fe466967..6bb3893d5 100644 --- a/src/Arithmetic/Saturated/Wrappers.v +++ b/src/Arithmetic/Saturated/Wrappers.v @@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.MulSplit. Require Import Crypto.Util.ZUtil.Definitions. Require Import Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.Tuple. Local Notation "A ^ n" := (tuple A n) : type_scope. @@ -21,7 +22,7 @@ Module Columns. B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))). Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2) {T} (f : (Z*Z^n3)->T) := @@ -29,15 +30,15 @@ Module Columns. (fun P => B.Positional.negate_snd_cps weight q (fun nq => B.Positional.to_associational_cps weight nq (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2) {T} (f : (Z*Z^n3)->T) := B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q - (fun Q => B.Associational.sat_mul_cps (mul_split := Z.mul_split) s P Q + (fun Q => B.Associational.sat_mul_cps (mul_split_cps := @Z.mul_split_cps') s P Q (fun PQ => Columns.from_associational_cps weight n3 PQ - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) {T} (f:_->T) := |