aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
author Jason Gross <jgross@mit.edu> 2017-10-19 14:44:48 -0400
committer Jason Gross <jgross@mit.edu> 2017-10-19 15:40:23 -0400
commit 7e939cd63236d0a6a492ddff5015daf3f706a3bc (patch)
tree fa19e772dc624eb7899017b55e527de184e7bf8f /src/Arithmetic
parent 79b586e4589f56d081301de92b305569c1077ed2 (diff)
Switch arithmetic to cps for Z * Z under the hood
This is in preparation for writing a ~compiler for the arithmetic things to expression trees.

I'm not sure what's up with femul in the table below; I ran it again and got:

After: src/Specific/NISTP256/AMD64/femul (real: 115.70, user: 115.25, sys: 0.44, mem: 3571448 ko)
Before: src/Specific/NISTP256/AMD64/femul (real: 118.49, user: 117.99, sys: 0.43, mem: 3581612 ko)

After | File Name | Before || Change
---------------------------------------------------------------------------------------------
17m02.82s | Total | 16m36.20s || +0m26.61s
---------------------------------------------------------------------------------------------
2m27.04s | Specific/NISTP256/AMD64/femul | 2m04.60s || +0m22.43s
1m38.55s | Specific/X2448/Karatsuba/C64/femul | 1m41.44s || -0m02.89s
0m12.46s | Arithmetic/Saturated/AddSub | 0m09.77s || +0m02.69s
3m22.38s | Specific/X25519/C64/ladderstep | 3m23.49s || -0m01.11s
0m54.40s | Specific/X25519/C32/fesquare | 0m52.68s || +0m01.71s
0m28.70s | Arithmetic/Karatsuba | 0m27.59s || +0m01.10s
0m10.00s | Arithmetic/Saturated/MontgomeryAPI | 0m08.95s || +0m01.05s
0m08.15s | Specific/X2448/Karatsuba/C64/Synthesis | 0m09.47s || -0m01.32s
0m05.62s | Arithmetic/Saturated/MulSplit | 0m04.28s || +0m01.33s
1m29.44s | Specific/X25519/C32/femul | 1m28.55s || +0m00.89s
0m39.38s | Specific/X25519/C32/freeze | 0m38.62s || +0m00.76s
0m31.54s | Specific/NISTP256/AMD128/femul | 0m31.60s || -0m00.06s
0m24.80s | Specific/X25519/C64/femul | 0m24.10s || +0m00.69s
0m23.82s | Specific/NISTP256/AMD64/fesub | 0m23.52s || +0m00.30s
0m21.81s | Specific/NISTP256/AMD64/feadd | 0m21.90s || -0m00.08s
0m20.30s | Specific/X25519/C64/freeze | 0m20.26s || +0m00.03s
0m20.12s | Specific/X25519/C32/Synthesis | 0m20.77s || -0m00.64s
0m19.12s | Specific/X25519/C64/fesquare | 0m19.02s || +0m00.10s
0m17.28s | Specific/NISTP256/AMD64/feopp | 0m17.68s || -0m00.39s
0m15.99s | Specific/NISTP256/AMD128/fesub | 0m16.03s || -0m00.04s
0m15.88s | Specific/NISTP256/AMD128/feadd | 0m16.56s || -0m00.67s
0m15.03s | Specific/NISTP256/AMD64/fenz | 0m15.00s || +0m00.02s
0m14.18s | Specific/NISTP256/AMD128/fenz | 0m14.12s || +0m00.06s
0m13.46s | Specific/NISTP256/AMD128/feopp | 0m12.88s || +0m00.58s
0m12.15s | Arithmetic/Core | 0m12.03s || +0m00.12s
0m07.82s | Arithmetic/Saturated/Core | 0m07.05s || +0m00.77s
0m07.13s | Specific/NISTP256/AMD64/Synthesis | 0m08.05s || -0m00.92s
0m05.48s | Specific/X25519/C64/Synthesis | 0m05.68s || -0m00.19s
0m04.02s | Specific/Framework/ArithmeticSynthesis/Montgomery | 0m03.89s || +0m00.12s
0m03.52s | Arithmetic/MontgomeryReduction/WordByWord/Proofs | 0m03.34s || +0m00.18s
0m03.32s | Specific/NISTP256/AMD128/Synthesis | 0m03.46s || -0m00.14s
0m02.30s | Specific/Framework/ArithmeticSynthesis/Defaults | 0m02.31s || -0m00.01s
0m02.08s | Arithmetic/Saturated/Freeze | 0m01.94s || +0m00.14s
0m01.66s | Specific/Framework/OutputType | 0m01.66s || +0m00.00s
0m01.54s | Arithmetic/CoreUnfolder | 0m01.43s || +0m00.11s
0m01.35s | Specific/Framework/ArithmeticSynthesis/Karatsuba | 0m01.28s || +0m00.07s
0m01.13s | Arithmetic/Saturated/CoreUnfolder | 0m01.16s || -0m00.03s
0m01.06s | Arithmetic/Saturated/WrappersUnfolder | 0m01.04s || +0m00.02s
0m01.04s | Arithmetic/Saturated/UniformWeight | 0m00.95s || +0m00.09s
0m01.03s | Specific/Framework/ArithmeticSynthesis/Base | 0m01.14s || -0m00.10s
0m01.02s | Specific/Framework/SynthesisFramework | 0m01.04s || -0m00.02s
0m00.97s | Specific/Framework/ArithmeticSynthesis/HelperTactics | 0m01.01s || -0m00.04s
0m00.92s | Specific/Framework/ReificationTypes | 0m00.90s || +0m00.02s
0m00.92s | Specific/Framework/ArithmeticSynthesis/Freeze | 0m00.93s || -0m00.01s
0m00.90s | Arithmetic/Saturated/MulSplitUnfolder | 0m00.83s || +0m00.07s
0m00.83s | Specific/Framework/ReificationTypesPackage | 0m00.79s || +0m00.03s
0m00.83s | Arithmetic/Saturated/FreezeUnfolder | 0m00.86s || -0m00.03s
0m00.82s | Specific/Framework/ArithmeticSynthesis/BasePackage | 0m00.77s || +0m00.04s
0m00.81s | Specific/Framework/ArithmeticSynthesis/SquareFromMul | 0m00.72s || +0m00.09s
0m00.81s | Specific/Framework/ArithmeticSynthesis/LadderstepPackage | 0m00.82s || -0m00.00s
0m00.80s | Specific/Framework/MontgomeryReificationTypesPackage | 0m00.82s || -0m00.01s
0m00.78s | Specific/Framework/ArithmeticSynthesis/MontgomeryPackage | 0m00.79s || -0m00.01s
0m00.78s | Arithmetic/Saturated/Wrappers | 0m00.78s || +0m00.00s
0m00.76s | Specific/Framework/ArithmeticSynthesis/FreezePackage | 0m00.80s || -0m00.04s
0m00.76s | Specific/Framework/ArithmeticSynthesis/DefaultsPackage | 0m00.75s || +0m00.01s
0m00.75s | Specific/Framework/MontgomeryReificationTypes | 0m00.78s || -0m00.03s
0m00.73s | Specific/Framework/ArithmeticSynthesis/Ladderstep | 0m00.77s || -0m00.04s
0m00.73s | Arithmetic/MontgomeryReduction/WordByWord/Definition | 0m00.80s || -0m00.07s
0m00.72s | Arithmetic/Saturated/UniformWeightInstances | 0m00.78s || -0m00.06s
0m00.68s | Specific/Framework/ArithmeticSynthesis/KaratsubaPackage | 0m00.76s || -0m00.07s
0m00.43s | Util/ZUtil/CPS | 0m00.42s || +0m00.01s
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- src/Arithmetic/Core.v 73
-rw-r--r-- src/Arithmetic/CoreUnfolder.v 15
-rw-r--r-- src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v 6
-rw-r--r-- src/Arithmetic/Saturated/AddSub.v 70
-rw-r--r-- src/Arithmetic/Saturated/Core.v 43
-rw-r--r-- src/Arithmetic/Saturated/Freeze.v 2
-rw-r--r-- src/Arithmetic/Saturated/MontgomeryAPI.v 22
-rw-r--r-- src/Arithmetic/Saturated/MulSplit.v 48
-rw-r--r-- src/Arithmetic/Saturated/Wrappers.v 9
9 files changed, 189 insertions, 99 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index 430e1c19a..f2d9ee00b 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -585,9 +585,10 @@ Module B.
Context {T : Type}.
Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) :=
- if dec (fst t mod weight i = 0)
+ Z.eqb_cps (fst t mod weight i) 0 (fun eqb =>
+ if eqb
then f (i, let c := fst t / weight i in (c * snd t)%RT)
- else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+ else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end).
End place_cps.
Definition place t i := place_cps t i id.
@@ -599,12 +600,13 @@ Module B.
Lemma place_cps_in_range (t:limb) (n:nat)
: (fst (place_cps t n id) < S n)%nat.
- Proof using Type. induction n; simpl; break_match; simpl; omega. Qed.
+ Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed.
Lemma weight_place_cps t i
: weight (fst (place_cps t i id)) * snd (place_cps t i id)
= fst t * snd t.
Proof using Type*.
- induction i; cbv [id]; simpl place_cps; break_match;
+ induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match;
+ Z.ltb_to_lt;
autorewrite with cancel_pair;
try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end;
nsatz || auto.
@@ -962,38 +964,59 @@ End B.
(* Modulo and div that do shifts if possible, otherwise normal mod/div *)
Section DivMod.
- Definition modulo (a b : Z) : Z :=
- if dec (2 ^ (Z.log2 b) = b)
- then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT
- else Z.modulo a b.
-
- Definition div (a b : Z) : Z :=
- if dec (2 ^ (Z.log2 b) = b)
- then let x := Z.log2 b in (a >> x)%RT
- else Z.div a b.
-
- Lemma div_correct a b : div a b = Z.div a b.
+ Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T :=
+ Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb =>
+ if eqb
+ then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT
+ else f (Z.modulo a b)).
+
+ Definition div_cps {T} (a b : Z) (f : Z -> T) : T :=
+ Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb =>
+ if eqb
+ then let x := Z.log2 b in f ((a >> x)%RT)
+ else f (Z.div a b)).
+
+ Definition modulo (a b : Z) : Z := modulo_cps a b id.
+ Definition div (a b : Z) : Z := div_cps a b id.
+
+ Lemma modulo_id {T} a b f
+ : @modulo_cps T a b f = f (modulo a b).
+ Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed.
+ Hint Opaque modulo : uncps.
+ Hint Rewrite @modulo_id : uncps.
+
+ Lemma div_id {T} a b f
+ : @div_cps T a b f = f (div a b).
+ Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed.
+ Hint Opaque div : uncps.
+ Hint Rewrite @div_id : uncps.
+
+ Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b).
Proof.
- cbv [div]; intros. break_match; try reflexivity.
+ cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity.
rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg.
- congruence.
+ Z.ltb_to_lt; congruence.
Qed.
- Lemma modulo_correct a b : modulo a b = Z.modulo a b.
+ Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b).
Proof.
- cbv [modulo]; intros. break_match; try reflexivity.
+ cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity.
rewrite Z.land_ones by apply Z.log2_nonneg.
- congruence.
+ Z.ltb_to_lt; congruence.
Qed.
+ Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id.
+ Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id.
+
Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b.
Proof.
- cbv [div modulo]; intros. break_match; auto using Z.div_mod.
- rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg.
- pose proof (Z.div_mod a b H). congruence.
+ rewrite div_correct, modulo_correct; auto using Z.div_mod.
Qed.
End DivMod.
+Hint Opaque div modulo : uncps.
+Hint Rewrite @div_id @modulo_id : uncps.
+
Import B.
Create HintDb basesystem_partial_evaluation_unfolder.
@@ -1045,7 +1068,7 @@ Hint Unfold
Positional.eval_from
Positional.select_cps
Positional.select
- modulo div
+ modulo div modulo_cps div_cps
id_tuple_with_alt id_tuple'_with_alt
Z.add_get_carry_full Z.add_get_carry_full_cps
: basesystem_partial_evaluation_unfolder.
@@ -1055,7 +1078,7 @@ Hint Unfold
CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2
Decidable.dec Decidable.dec_eq_Z
id_tuple_with_alt id_tuple'_with_alt
- Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps
+ Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps'
: basesystem_partial_evaluation_unfolder.
diff --git a/src/Arithmetic/CoreUnfolder.v b/src/Arithmetic/CoreUnfolder.v
index df005e630..cb943e80a 100644
--- a/src/Arithmetic/CoreUnfolder.v
+++ b/src/Arithmetic/CoreUnfolder.v
@@ -12,7 +12,7 @@ Ltac make_parameterized_sig t :=
Decidable.dec Decidable.dec_eq_Z
id_tuple_with_alt id_tuple'_with_alt
Z.add_get_carry_full Z.mul_split
- Z.add_get_carry_full_cps Z.mul_split_cps
+ Z.add_get_carry_full_cps Z.mul_split_cps Z.mul_split_cps'
Z.add_get_carry_cps];
repeat autorewrite with pattern_runtime;
reflexivity.
@@ -64,7 +64,7 @@ done
echo " End Positional."
echo "End B."
echo ""
-for i in modulo div; do
+for i in modulo_cps div_cps modulo div; do
echo "Definition ${i}_sig := parameterize_sig (@Core.${i}).";
echo "Definition ${i} := parameterize_from_sig ${i}_sig.";
echo "Definition ${i}_eq := parameterize_eq ${i} ${i}_sig.";
@@ -140,6 +140,7 @@ done
Definition carry := parameterize_from_sig carry_sig.
Definition carry_eq := parameterize_eq carry carry_sig.
Hint Rewrite <- carry_eq : pattern_runtime.
+
End Associational.
Module Positional.
Definition to_associational_cps_sig := parameterize_sig (@Core.B.Positional.to_associational_cps).
@@ -300,6 +301,16 @@ done
End Positional.
End B.
+Definition modulo_cps_sig := parameterize_sig (@Core.modulo_cps).
+Definition modulo_cps := parameterize_from_sig modulo_cps_sig.
+Definition modulo_cps_eq := parameterize_eq modulo_cps modulo_cps_sig.
+Hint Rewrite <- modulo_cps_eq : pattern_runtime.
+
+Definition div_cps_sig := parameterize_sig (@Core.div_cps).
+Definition div_cps := parameterize_from_sig div_cps_sig.
+Definition div_cps_eq := parameterize_eq div_cps div_cps_sig.
+Hint Rewrite <- div_cps_eq : pattern_runtime.
+
Definition modulo_sig := parameterize_sig (@Core.modulo).
Definition modulo := parameterize_from_sig modulo_sig.
Definition modulo_eq := parameterize_eq modulo modulo_sig.
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
index 9affa82fa..fd4869f23 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
@@ -10,6 +10,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Depende
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CPS.
Local Open Scope Z_scope.
@@ -47,10 +48,11 @@ Section WordByWordMontgomery.
:= divmod_cps A (fun '(A, a) =>
@scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 =>
divmod_cps S1 (fun '(_, s) =>
- dlet q := fst (Z.mul_split r s k) in
+ Z.mul_split_cps' r s k (fun mul_split_r_s_k =>
+ dlet q := fst mul_split_r_s_k in
@scmul_cps r _ q N _ (fun qN => @add_S1_cps r _ S1 qN _ (fun S2 =>
divmod_cps S2 (fun '(S3, _) =>
- @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))).
+ @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4)))))))))).
Section loop.
Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}.
diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v
index e34230904..e886c36de 100644
--- a/src/Arithmetic/Saturated/AddSub.v
+++ b/src/Arithmetic/Saturated/AddSub.v
@@ -6,8 +6,10 @@ Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.UniformWeight.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
+Require Import Crypto.Util.Tactics.BreakMatch.
Local Notation "A ^ n" := (tuple A n) : type_scope.
Module B.
@@ -17,8 +19,15 @@ Module B.
Let small {n} := @small s n.
Section GenericOp.
Context {op : Z -> Z -> Z}
- {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *)
- {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *)
+ {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *)
+ {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *)
+ Let op_get_carry x y := op_get_carry_cps _ x y id.
+ Let op_with_carry x y z := op_with_carry_cps _ x y z id.
+ Context {op_get_carry_id : forall {T} x y f,
+ @op_get_carry_cps T x y f = f (op_get_carry x y)}
+ {op_with_carry_id : forall {T} x y z f,
+ @op_with_carry_cps T x y z f = f (op_with_carry x y z)}.
+ Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps.
Section chain_op'_cps.
Context (T : Type).
@@ -32,14 +41,15 @@ Module B.
| S n' =>
fun c p q f =>
(* for the first call, use op_get_carry, then op_with_carry *)
- let op' := match c with
- | None => op_get_carry
- | Some x => op_with_carry x end in
- dlet carry_result := op' (hd p) (hd q) in
+ let op'_cps := match c with
+ | None => op_get_carry_cps _
+ | Some x => op_with_carry_cps _ x end in
+ op'_cps (hd p) (hd q) (fun carry_result =>
+ dlet carry_result := carry_result in
chain_op'_cps (Some (snd carry_result)) (tl p) (tl q)
(fun carry_pq =>
f (fst carry_pq,
- append (fst carry_result) (snd carry_pq)))
+ append (fst carry_result) (snd carry_pq))))
end c p q.
End chain_op'_cps.
Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id.
@@ -50,7 +60,8 @@ Module B.
@chain_op'_cps T n c p q f = f (chain_op' c p q).
Proof.
cbv [chain_op']; induction n; intros; destruct c;
- simpl chain_op'_cps; cbv [Let_In]; try reflexivity.
+ simpl chain_op'_cps; cbv [Let_In]; try reflexivity;
+ autorewrite with uncps.
{ etransitivity; rewrite IHn; reflexivity. }
{ etransitivity; rewrite IHn; reflexivity. }
Qed.
@@ -60,7 +71,7 @@ Module B.
Proof. apply (@chain_op'_id n None). Qed.
End GenericOp.
Hint Opaque chain_op chain_op' : uncps.
- Hint Rewrite @chain_op_id @chain_op'_id : uncps.
+ Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps.
Section AddSub.
Create HintDb divmod discriminated.
@@ -77,14 +88,14 @@ Module B.
Let eval {n} := B.Positional.eval (n:=n) (uweight s).
Definition sat_add_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.add_get_carry_full s)
- (op_with_carry := Z.add_with_get_carry_full s)
+ chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s)
+ (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s)
p q f.
Definition sat_add {n} p q := @sat_add_cps n p q _ id.
Lemma sat_add_id n p q T f :
@sat_add_cps n p q T f = f (sat_add p q).
- Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed.
+ Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed.
Lemma sat_add_mod_step n c d :
c mod s + s * ((d + c / s) mod (uweight s n))
@@ -156,23 +167,25 @@ Module B.
simpl In in H
| H : _ \/ _ |- _ => destruct H
| _ => contradiction
- end.
- { subst x.
- destruct c; rewrite ?Z.add_with_get_carry_full_mod,
- ?Z.add_get_carry_full_mod;
- apply Z.mod_pos_bound; omega. }
- { apply IHn in H. assumption. }
+ | _ => break_innermost_match_hyps_step
+ | _ => progress subst
+ | [ H : In _ (to_list _ (snd _)) |- _ ]
+ => apply IHn in H; assumption
+ end;
+ try solve [ rewrite ?Z.add_with_get_carry_full_mod,
+ ?Z.add_get_carry_full_mod;
+ apply Z.mod_pos_bound; omega ].
Qed.
Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) :=
- chain_op_cps (op_get_carry := Z.sub_get_borrow_full s)
- (op_with_carry := Z.sub_with_get_borrow_full s)
+ chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s)
+ (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s)
p q f.
Definition sat_sub {n} p q := @sat_sub_cps n p q _ id.
Lemma sat_sub_id n p q T f :
@sat_sub_cps n p q T f = f (sat_sub p q).
- Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed.
+ Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed.
Lemma sat_sub_divmod n p q :
eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n)
/\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)).
@@ -224,19 +237,22 @@ Module B.
simpl In in H
| H : _ \/ _ |- _ => destruct H
| _ => contradiction
- end.
- { subst x.
- destruct c; rewrite ?Z.sub_with_get_borrow_full_mod,
+ | _ => break_innermost_match_hyps_step
+ | _ => progress subst
+ | [ H : In _ (to_list _ (snd _)) |- _ ]
+ => apply IHn in H; assumption
+ end;
+ try solve [ rewrite ?Z.sub_with_get_borrow_full_mod,
?Z.sub_get_borrow_full_mod;
- apply Z.mod_pos_bound; omega. }
- { apply IHn in H. assumption. }
+ apply Z.mod_pos_bound; omega ].
Qed.
End AddSub.
End Positional.
End Positional.
End B.
Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps.
-Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps.
+Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps.
+Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps.
Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval.
Hint Unfold
diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v
index e34a37501..8af1cce6f 100644
--- a/src/Arithmetic/Saturated/Core.v
+++ b/src/Arithmetic/Saturated/Core.v
@@ -112,16 +112,20 @@ Module Columns.
{weight_multiples : forall i, weight (S i) mod weight i = 0}
{weight_divides : forall i : nat, weight (S i) / weight i > 0}
(* add_get_carry takes in a number at which to split output *)
- {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {add_get_carry_cps: forall {T}, Z ->Z -> Z -> (Z * Z -> T) -> T}
+ {add_get_carry_cps_id : forall {T} s x y f,
+ @add_get_carry_cps T s x y f = f (@add_get_carry_cps _ s x y id)}
{add_get_carry_mod : forall s x y,
- fst (add_get_carry s x y) = (x + y) mod s}
+ fst (add_get_carry_cps s x y id) = (x + y) mod s}
{add_get_carry_div : forall s x y,
- snd (add_get_carry s x y) = (x + y) / s}
+ snd (add_get_carry_cps s x y id) = (x + y) / s}
{div modulo : Z -> Z -> Z}
{div_correct : forall a b, div a b = a / b}
{modulo_correct : forall a b, modulo a b = a mod b}
.
Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod.
+ Let add_get_carry s x y := add_get_carry_cps _ s x y id.
+ Hint Rewrite (add_get_carry_cps_id : forall T s x y f, _ = f (@add_get_carry s x y)) : uncps.
Definition eval {n} (x : (list Z)^n) : Z :=
B.Positional.eval weight (Tuple.map sum x).
@@ -164,25 +168,27 @@ Module Columns.
| nil => f (0, 0)
| x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n))
| x :: y :: nil =>
- dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in
+ add_get_carry_cps _ (weight (S n) / weight n) x y (fun sum_carry =>
+ dlet sum_carry := sum_carry in
dlet carry := snd sum_carry in
- f (carry, fst sum_carry)
+ f (carry, fst sum_carry))
| x :: tl =>
compact_digit_cps tl
(fun rec =>
- dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
+ add_get_carry_cps _ (weight (S n) / weight n) x (snd rec) (fun sum_carry =>
+ dlet sum_carry := sum_carry in
dlet carry' := (fst rec + snd sum_carry)%RT in
- f (carry', fst sum_carry))
+ f (carry', fst sum_carry)))
end.
End compact_digit_cps.
Definition compact_digit n digit := compact_digit_cps n digit id.
Lemma compact_digit_id n digit: forall {T} f,
@compact_digit_cps n T digit f = f (compact_digit n digit).
- Proof using Type.
- induction digit; intros; cbv [compact_digit]; [reflexivity|];
- simpl compact_digit_cps; break_match; rewrite ?IHdigit;
- reflexivity.
+ Proof using add_get_carry_cps_id.
+ induction digit; intros; cbv [compact_digit]; [reflexivity|].
+ simpl compact_digit_cps; break_match; rewrite ?IHdigit; clear IHdigit;
+ cbv [Let_In]; autorewrite with uncps; reflexivity.
Qed.
Hint Opaque compact_digit : uncps.
Hint Rewrite compact_digit_id : uncps.
@@ -194,7 +200,7 @@ Module Columns.
Definition compact_step i c d := compact_step_cps i c d id.
Lemma compact_step_id i c d T f :
@compact_step_cps i c d T f = f (compact_step i c d).
- Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
+ Proof using add_get_carry_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
Hint Opaque compact_step : uncps.
Hint Rewrite compact_step_id : uncps.
@@ -203,15 +209,15 @@ Module Columns.
Definition compact {n} xs := @compact_cps n xs _ id.
Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs).
- Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
+ Proof using add_get_carry_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
Lemma compact_digit_mod i (xs : list Z) :
snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i).
- Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct.
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id.
induction xs; cbv [compact_digit]; simpl compact_digit_cps;
cbv [Let_In];
repeat match goal with
- | _ => progress autorewrite with div_mod
+ | _ => cbv [add_get_carry]; progress autorewrite with div_mod
| _ => rewrite IHxs, <-Z.add_mod_r
| _ => progress (rewrite ?sum_cons, ?sum_nil in * )
| _ => progress (autorewrite with uncps push_id cancel_pair in * )
@@ -223,11 +229,11 @@ Module Columns.
Lemma compact_digit_div i (xs : list Z) :
fst (compact_digit i xs) = sum xs / (weight (S i) / weight i).
- Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides.
+ Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id.
induction xs; cbv [compact_digit]; simpl compact_digit_cps;
cbv [Let_In];
repeat match goal with
- | _ => progress autorewrite with div_mod
+ | _ => cbv [add_get_carry]; progress autorewrite with div_mod
| _ => rewrite IHxs
| _ => progress (rewrite ?sum_cons, ?sum_nil in * )
| _ => progress (autorewrite with uncps push_id cancel_pair in * )
@@ -421,6 +427,9 @@ Hint Rewrite
@Columns.compact_digit_id
@Columns.compact_step_id
@Columns.compact_id
+ using (assumption || (intros; autorewrite with uncps; reflexivity))
+ : uncps.
+Hint Rewrite
@Columns.cons_to_nth_id
@Columns.from_associational_id
: uncps.
diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v
index 65b8ee55d..b56a69a3d 100644
--- a/src/Arithmetic/Saturated/Freeze.v
+++ b/src/Arithmetic/Saturated/Freeze.v
@@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.Wrappers.
Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
@@ -101,6 +102,7 @@ Section Freeze.
pose proof Z.add_get_carry_full_mod.
pose proof Z.add_get_carry_full_div.
pose proof div_correct. pose proof modulo_correct.
+ pose proof @Z.add_get_carry_full_cps_correct.
autorewrite with uncps push_id push_basesystem_eval.
pose proof (weight_nonzero n).
diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v
index 7ff86258c..fb896749b 100644
--- a/src/Arithmetic/Saturated/MontgomeryAPI.v
+++ b/src/Arithmetic/Saturated/MontgomeryAPI.v
@@ -13,6 +13,7 @@ Require Import Crypto.Util.Tuple Crypto.Util.LetIn.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil.
Require Import Crypto.Util.ZUtil.Definitions.
+Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.ZUtil.Zselect.
Require Import Crypto.Util.ZUtil.AddGetCarry.
Require Import Crypto.Util.ZUtil.MulSplit.
@@ -106,7 +107,9 @@ Section API.
Section CPSProofs.
Local Ltac prove_id :=
- repeat autounfold; autorewrite with uncps; reflexivity.
+ repeat autounfold;
+ repeat (intros; autorewrite with uncps push_id);
+ reflexivity.
Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p).
Proof. cbv [nonzero nonzero_cps]. prove_id. Qed.
@@ -281,7 +284,10 @@ Section API.
pose proof Z.add_get_carry_full_div;
pose proof Z.add_get_carry_full_mod;
pose proof Z.mul_split_div; pose proof Z.mul_split_mod;
- pose proof div_correct; pose proof modulo_correct.
+ pose proof div_correct; pose proof modulo_correct;
+ pose proof @Z.add_get_carry_full_cps_correct;
+ pose proof @Z.mul_split_cps_correct;
+ pose proof @Z.mul_split_cps'_correct.
Lemma eval_add n p q :
eval (@add n p q) = eval p + eval q.
@@ -308,8 +314,8 @@ Section API.
Qed.
Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval.
- Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
- Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound).
+ Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound).
+ Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound).
Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)).
Proof.
pose_all.
@@ -329,7 +335,8 @@ Section API.
match goal with H : _ /\ _ |- _ => destruct H end.
destruct n0; subst f.
{ cbv [compact_digit uweight to_list to_list' In].
- rewrite Columns.compact_digit_mod by assumption.
+ rewrite Columns.compact_digit_mod
+ by (assumption || (intros; autorewrite with uncps push_id; auto)).
rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?.
match goal with
H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end.
@@ -340,7 +347,8 @@ Section API.
[solve[auto]| cbv [In] in H; destruct H;
[|exfalso; assumption] ].
subst x. cbv [compact_digit].
- rewrite Columns.compact_digit_mod by assumption.
+ rewrite Columns.compact_digit_mod
+ by (assumption || (intros; autorewrite with uncps push_id; auto)).
rewrite !uweight_succ, Z.div_mul by
(apply Z.neq_mul_0; split; auto; omega).
apply Z.mod_pos_bound, Z.gt_lt, bound_pos. }
@@ -554,6 +562,8 @@ Section API.
Proof.
intro Hsmall. pose_all. apply eval_small in Hsmall.
intros. cbv [scmul scmul_cps eval] in *. repeat autounfold.
+ autorewrite with uncps.
+ autorewrite with push_basesystem_eval.
autorewrite with uncps push_id push_basesystem_eval.
rewrite uweight_0, Z.mul_1_l. apply Z.mod_small.
split; [solve[Z.zero_bounds]|]. cbv [uweight] in *.
diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v
index 98d8d0e0c..4947f422a 100644
--- a/src/Arithmetic/Saturated/MulSplit.v
+++ b/src/Arithmetic/Saturated/MulSplit.v
@@ -9,21 +9,34 @@ Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
Module B.
Module Associational.
Section Associational.
- Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *)
+ Context {mul_split_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *)
+ {mul_split_cps_id : forall {T} s x y f,
+ @mul_split_cps T s x y f = f (@mul_split_cps _ s x y id)}
{mul_split_mod : forall s x y,
- fst (mul_split s x y) = (x * y) mod s}
+ fst (mul_split_cps s x y id) = (x * y) mod s}
{mul_split_div : forall s x y,
- snd (mul_split s x y) = (x * y) / s}
- .
+ snd (mul_split_cps s x y id) = (x * y) / s}
+ .
+
+ Local Lemma mul_split_cps_correct {T} s x y f
+ : @mul_split_cps T s x y f = f ((x * y) mod s, (x * y) / s).
+ Proof.
+ now rewrite mul_split_cps_id, <- mul_split_mod, <- mul_split_div, <- surjective_pairing.
+ Qed.
+ Hint Rewrite @mul_split_cps_correct : uncps.
Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) :=
- dlet xy := mul_split s (snd t) (snd t') in
- f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil).
+ mul_split_cps _ s (snd t) (snd t') (fun xy =>
+ dlet xy := xy in
+ f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil)).
Definition sat_multerm s t t' := sat_multerm_cps s t t' id.
Lemma sat_multerm_id s t t' T f :
@sat_multerm_cps s t t' T f = f (sat_multerm s t t').
- Proof. reflexivity. Qed.
+ Proof.
+ unfold sat_multerm, sat_multerm_cps;
+ etransitivity; rewrite mul_split_cps_id; reflexivity.
+ Qed.
Hint Opaque sat_multerm : uncps.
Hint Rewrite sat_multerm_id : uncps.
@@ -39,14 +52,17 @@ Module B.
Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0):
B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q.
Proof.
- cbv [sat_multerm sat_multerm_cps Let_In]; induction q;
- repeat match goal with
- | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * )
- | _ => progress simpl flat_map
- | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div
- | _ => rewrite Z.mod_eq by assumption
- | _ => ring_simplify; omega
- end.
+ cbv [sat_multerm sat_multerm_cps Let_In]; induction q;
+ repeat match goal with
+ | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval in *
+ | _ => progress simpl flat_map
+ | _ => progress unfold id in *
+ | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div
+ | _ => rewrite Z.mod_eq by assumption
+ | _ => rewrite B.Associational.eval_nil
+ | _ => progress change (Z * Z)%type with B.limb
+ | _ => ring_simplify; omega
+ end.
Qed.
Hint Rewrite eval_map_sat_multerm using (omega || assumption)
: push_basesystem_eval.
@@ -68,7 +84,7 @@ Module B.
End Associational.
End B.
Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps.
-Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id : uncps.
+Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps.
Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval.
Hint Unfold
diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v
index 6fe466967..6bb3893d5 100644
--- a/src/Arithmetic/Saturated/Wrappers.v
+++ b/src/Arithmetic/Saturated/Wrappers.v
@@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core.
Require Import Crypto.Arithmetic.Saturated.MulSplit.
Require Import Crypto.Util.ZUtil.Definitions.
Require Import Crypto.Util.ZUtil.MulSplit.
+Require Import Crypto.Util.ZUtil.CPS.
Require Import Crypto.Util.Tuple.
Local Notation "A ^ n" := (tuple A n) : type_scope.
@@ -21,7 +22,7 @@ Module Columns.
B.Positional.to_associational_cps weight p
(fun P => B.Positional.to_associational_cps weight q
(fun Q => Columns.from_associational_cps weight n3 (P++Q)
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))).
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))).
Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2)
{T} (f : (Z*Z^n3)->T) :=
@@ -29,15 +30,15 @@ Module Columns.
(fun P => B.Positional.negate_snd_cps weight q
(fun nq => B.Positional.to_associational_cps weight nq
(fun Q => Columns.from_associational_cps weight n3 (P++Q)
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2)
{T} (f : (Z*Z^n3)->T) :=
B.Positional.to_associational_cps weight p
(fun P => B.Positional.to_associational_cps weight q
- (fun Q => B.Associational.sat_mul_cps (mul_split := Z.mul_split) s P Q
+ (fun Q => B.Associational.sat_mul_cps (mul_split_cps := @Z.mul_split_cps') s P Q
(fun PQ => Columns.from_associational_cps weight n3 PQ
- (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))).
+ (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))).
Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2)
{T} (f:_->T) :=