diff options
-rw-r--r-- | src/Arithmetic/Core.v | 73 | ||||
-rw-r--r-- | src/Arithmetic/CoreUnfolder.v | 15 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v | 6 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/AddSub.v | 70 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Core.v | 43 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Freeze.v | 2 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MontgomeryAPI.v | 22 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/MulSplit.v | 48 | ||||
-rw-r--r-- | src/Arithmetic/Saturated/Wrappers.v | 9 | ||||
-rw-r--r-- | src/Util/ZUtil/CPS.v | 12 |
10 files changed, 195 insertions, 105 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 430e1c19a..f2d9ee00b 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -585,9 +585,10 @@ Module B. Context {T : Type}. Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) := - if dec (fst t mod weight i = 0) + Z.eqb_cps (fst t mod weight i) 0 (fun eqb => + if eqb then f (i, let c := fst t / weight i in (c * snd t)%RT) - else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end). End place_cps. Definition place t i := place_cps t i id. @@ -599,12 +600,13 @@ Module B. Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat. - Proof using Type. induction n; simpl; break_match; simpl; omega. Qed. + Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed. Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t. Proof using Type*. - induction i; cbv [id]; simpl place_cps; break_match; + induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match; + Z.ltb_to_lt; autorewrite with cancel_pair; try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end; nsatz || auto. @@ -962,38 +964,59 @@ End B. (* Modulo and div that do shifts if possible, otherwise normal mod/div *) Section DivMod. - Definition modulo (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT - else Z.modulo a b. - - Definition div (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := Z.log2 b in (a >> x)%RT - else Z.div a b. - - Lemma div_correct a b : div a b = Z.div a b. + Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT + else f (Z.modulo a b)). + + Definition div_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := Z.log2 b in f ((a >> x)%RT) + else f (Z.div a b)). + + Definition modulo (a b : Z) : Z := modulo_cps a b id. + Definition div (a b : Z) : Z := div_cps a b id. + + Lemma modulo_id {T} a b f + : @modulo_cps T a b f = f (modulo a b). + Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque modulo : uncps. + Hint Rewrite @modulo_id : uncps. + + Lemma div_id {T} a b f + : @div_cps T a b f = f (div a b). + Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque div : uncps. + Hint Rewrite @div_id : uncps. + + Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b). Proof. - cbv [div]; intros. break_match; try reflexivity. + cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. - Lemma modulo_correct a b : modulo a b = Z.modulo a b. + Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b). Proof. - cbv [modulo]; intros. break_match; try reflexivity. + cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.land_ones by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. + Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id. + Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id. + Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b. Proof. - cbv [div modulo]; intros. break_match; auto using Z.div_mod. - rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg. - pose proof (Z.div_mod a b H). congruence. + rewrite div_correct, modulo_correct; auto using Z.div_mod. Qed. End DivMod. +Hint Opaque div modulo : uncps. +Hint Rewrite @div_id @modulo_id : uncps. + Import B. Create HintDb basesystem_partial_evaluation_unfolder. @@ -1045,7 +1068,7 @@ Hint Unfold Positional.eval_from Positional.select_cps Positional.select - modulo div + modulo div modulo_cps div_cps id_tuple_with_alt id_tuple'_with_alt Z.add_get_carry_full Z.add_get_carry_full_cps : basesystem_partial_evaluation_unfolder. @@ -1055,7 +1078,7 @@ Hint Unfold CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2 Decidable.dec Decidable.dec_eq_Z id_tuple_with_alt id_tuple'_with_alt - Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps + Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps' : basesystem_partial_evaluation_unfolder. diff --git a/src/Arithmetic/CoreUnfolder.v b/src/Arithmetic/CoreUnfolder.v index df005e630..cb943e80a 100644 --- a/src/Arithmetic/CoreUnfolder.v +++ b/src/Arithmetic/CoreUnfolder.v @@ -12,7 +12,7 @@ Ltac make_parameterized_sig t := Decidable.dec Decidable.dec_eq_Z id_tuple_with_alt id_tuple'_with_alt Z.add_get_carry_full Z.mul_split - Z.add_get_carry_full_cps Z.mul_split_cps + Z.add_get_carry_full_cps Z.mul_split_cps Z.mul_split_cps' Z.add_get_carry_cps]; repeat autorewrite with pattern_runtime; reflexivity. @@ -64,7 +64,7 @@ done echo " End Positional." echo "End B." echo "" -for i in modulo div; do +for i in modulo_cps div_cps modulo div; do echo "Definition ${i}_sig := parameterize_sig (@Core.${i})."; echo "Definition ${i} := parameterize_from_sig ${i}_sig."; echo "Definition ${i}_eq := parameterize_eq ${i} ${i}_sig."; @@ -140,6 +140,7 @@ done Definition carry := parameterize_from_sig carry_sig. Definition carry_eq := parameterize_eq carry carry_sig. Hint Rewrite <- carry_eq : pattern_runtime. + End Associational. Module Positional. Definition to_associational_cps_sig := parameterize_sig (@Core.B.Positional.to_associational_cps). @@ -300,6 +301,16 @@ done End Positional. End B. +Definition modulo_cps_sig := parameterize_sig (@Core.modulo_cps). +Definition modulo_cps := parameterize_from_sig modulo_cps_sig. +Definition modulo_cps_eq := parameterize_eq modulo_cps modulo_cps_sig. +Hint Rewrite <- modulo_cps_eq : pattern_runtime. + +Definition div_cps_sig := parameterize_sig (@Core.div_cps). +Definition div_cps := parameterize_from_sig div_cps_sig. +Definition div_cps_eq := parameterize_eq div_cps div_cps_sig. +Hint Rewrite <- div_cps_eq : pattern_runtime. + Definition modulo_sig := parameterize_sig (@Core.modulo). Definition modulo := parameterize_from_sig modulo_sig. Definition modulo_eq := parameterize_eq modulo modulo_sig. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index 9affa82fa..fd4869f23 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -10,6 +10,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Depende Require Import Crypto.Util.Notations. Require Import Crypto.Util.LetIn. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Local Open Scope Z_scope. @@ -47,10 +48,11 @@ Section WordByWordMontgomery. := divmod_cps A (fun '(A, a) => @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 => divmod_cps S1 (fun '(_, s) => - dlet q := fst (Z.mul_split r s k) in + Z.mul_split_cps' r s k (fun mul_split_r_s_k => + dlet q := fst mul_split_r_s_k in @scmul_cps r _ q N _ (fun qN => @add_S1_cps r _ S1 qN _ (fun S2 => divmod_cps S2 (fun '(S3, _) => - @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))). + @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4)))))))))). Section loop. Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}. diff --git a/src/Arithmetic/Saturated/AddSub.v b/src/Arithmetic/Saturated/AddSub.v index e34230904..e886c36de 100644 --- a/src/Arithmetic/Saturated/AddSub.v +++ b/src/Arithmetic/Saturated/AddSub.v @@ -6,8 +6,10 @@ Require Import Crypto.Arithmetic.Core. Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.UniformWeight. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.Tuple Crypto.Util.LetIn. +Require Import Crypto.Util.Tactics.BreakMatch. Local Notation "A ^ n" := (tuple A n) : type_scope. Module B. @@ -17,8 +19,15 @@ Module B. Let small {n} := @small s n. Section GenericOp. Context {op : Z -> Z -> Z} - {op_get_carry : Z -> Z -> Z * Z} (* no carry in, carry out *) - {op_with_carry : Z -> Z -> Z -> Z * Z}. (* carry in, carry out *) + {op_get_carry_cps : forall {T}, Z -> Z -> (Z * Z -> T) -> T} (* no carry in, carry out *) + {op_with_carry_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T}. (* carry in, carry out *) + Let op_get_carry x y := op_get_carry_cps _ x y id. + Let op_with_carry x y z := op_with_carry_cps _ x y z id. + Context {op_get_carry_id : forall {T} x y f, + @op_get_carry_cps T x y f = f (op_get_carry x y)} + {op_with_carry_id : forall {T} x y z f, + @op_with_carry_cps T x y z f = f (op_with_carry x y z)}. + Hint Rewrite @op_get_carry_id @op_with_carry_id : uncps. Section chain_op'_cps. Context (T : Type). @@ -32,14 +41,15 @@ Module B. | S n' => fun c p q f => (* for the first call, use op_get_carry, then op_with_carry *) - let op' := match c with - | None => op_get_carry - | Some x => op_with_carry x end in - dlet carry_result := op' (hd p) (hd q) in + let op'_cps := match c with + | None => op_get_carry_cps _ + | Some x => op_with_carry_cps _ x end in + op'_cps (hd p) (hd q) (fun carry_result => + dlet carry_result := carry_result in chain_op'_cps (Some (snd carry_result)) (tl p) (tl q) (fun carry_pq => f (fst carry_pq, - append (fst carry_result) (snd carry_pq))) + append (fst carry_result) (snd carry_pq)))) end c p q. End chain_op'_cps. Definition chain_op' {n} c p q := @chain_op'_cps _ n c p q id. @@ -50,7 +60,8 @@ Module B. @chain_op'_cps T n c p q f = f (chain_op' c p q). Proof. cbv [chain_op']; induction n; intros; destruct c; - simpl chain_op'_cps; cbv [Let_In]; try reflexivity. + simpl chain_op'_cps; cbv [Let_In]; try reflexivity; + autorewrite with uncps. { etransitivity; rewrite IHn; reflexivity. } { etransitivity; rewrite IHn; reflexivity. } Qed. @@ -60,7 +71,7 @@ Module B. Proof. apply (@chain_op'_id n None). Qed. End GenericOp. Hint Opaque chain_op chain_op' : uncps. - Hint Rewrite @chain_op_id @chain_op'_id : uncps. + Hint Rewrite @chain_op_id @chain_op'_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Section AddSub. Create HintDb divmod discriminated. @@ -77,14 +88,14 @@ Module B. Let eval {n} := B.Positional.eval (n:=n) (uweight s). Definition sat_add_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.add_get_carry_full s) - (op_with_carry := Z.add_with_get_carry_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.add_get_carry_full_cps s) + (op_with_carry_cps := fun T => Z.add_with_get_carry_full_cps s) p q f. Definition sat_add {n} p q := @sat_add_cps n p q _ id. Lemma sat_add_id n p q T f : @sat_add_cps n p q T f = f (sat_add p q). - Proof. cbv [sat_add sat_add_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_add sat_add_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_add_mod_step n c d : c mod s + s * ((d + c / s) mod (uweight s n)) @@ -156,23 +167,25 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.add_with_get_carry_full_mod, - ?Z.add_get_carry_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.add_with_get_carry_full_mod, + ?Z.add_get_carry_full_mod; + apply Z.mod_pos_bound; omega ]. Qed. Definition sat_sub_cps {n} p q T (f:Z*Z^n->T) := - chain_op_cps (op_get_carry := Z.sub_get_borrow_full s) - (op_with_carry := Z.sub_with_get_borrow_full s) + chain_op_cps (op_get_carry_cps := fun T => Z.sub_get_borrow_full_cps s) + (op_with_carry_cps := fun T => Z.sub_with_get_borrow_full_cps s) p q f. Definition sat_sub {n} p q := @sat_sub_cps n p q _ id. Lemma sat_sub_id n p q T f : @sat_sub_cps n p q T f = f (sat_sub p q). - Proof. cbv [sat_sub sat_sub_cps]. rewrite !chain_op_id. reflexivity. Qed. + Proof. cbv [sat_sub sat_sub_cps]. autorewrite with uncps. reflexivity. Qed. Lemma sat_sub_divmod n p q : eval (snd (@sat_sub n p q)) = (eval p - eval q) mod (uweight s n) /\ fst (@sat_sub n p q) = - ((eval p - eval q) / (uweight s n)). @@ -224,19 +237,22 @@ Module B. simpl In in H | H : _ \/ _ |- _ => destruct H | _ => contradiction - end. - { subst x. - destruct c; rewrite ?Z.sub_with_get_borrow_full_mod, + | _ => break_innermost_match_hyps_step + | _ => progress subst + | [ H : In _ (to_list _ (snd _)) |- _ ] + => apply IHn in H; assumption + end; + try solve [ rewrite ?Z.sub_with_get_borrow_full_mod, ?Z.sub_get_borrow_full_mod; - apply Z.mod_pos_bound; omega. } - { apply IHn in H. assumption. } + apply Z.mod_pos_bound; omega ]. Qed. End AddSub. End Positional. End Positional. End B. Hint Opaque B.Positional.sat_sub B.Positional.sat_add B.Positional.chain_op B.Positional.chain_op' : uncps. -Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id @B.Positional.chain_op_id @B.Positional.chain_op' : uncps. +Hint Rewrite @B.Positional.sat_sub_id @B.Positional.sat_add_id : uncps. +Hint Rewrite @B.Positional.chain_op_id @B.Positional.chain_op' using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Hint Rewrite @B.Positional.sat_sub_mod @B.Positional.sat_sub_div @B.Positional.sat_add_mod @B.Positional.sat_add_div using (omega || assumption) : push_basesystem_eval. Hint Unfold diff --git a/src/Arithmetic/Saturated/Core.v b/src/Arithmetic/Saturated/Core.v index e34a37501..8af1cce6f 100644 --- a/src/Arithmetic/Saturated/Core.v +++ b/src/Arithmetic/Saturated/Core.v @@ -112,16 +112,20 @@ Module Columns. {weight_multiples : forall i, weight (S i) mod weight i = 0} {weight_divides : forall i : nat, weight (S i) / weight i > 0} (* add_get_carry takes in a number at which to split output *) - {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {add_get_carry_cps: forall {T}, Z ->Z -> Z -> (Z * Z -> T) -> T} + {add_get_carry_cps_id : forall {T} s x y f, + @add_get_carry_cps T s x y f = f (@add_get_carry_cps _ s x y id)} {add_get_carry_mod : forall s x y, - fst (add_get_carry s x y) = (x + y) mod s} + fst (add_get_carry_cps s x y id) = (x + y) mod s} {add_get_carry_div : forall s x y, - snd (add_get_carry s x y) = (x + y) / s} + snd (add_get_carry_cps s x y id) = (x + y) / s} {div modulo : Z -> Z -> Z} {div_correct : forall a b, div a b = a / b} {modulo_correct : forall a b, modulo a b = a mod b} . Hint Rewrite div_correct modulo_correct add_get_carry_mod add_get_carry_div : div_mod. + Let add_get_carry s x y := add_get_carry_cps _ s x y id. + Hint Rewrite (add_get_carry_cps_id : forall T s x y f, _ = f (@add_get_carry s x y)) : uncps. Definition eval {n} (x : (list Z)^n) : Z := B.Positional.eval weight (Tuple.map sum x). @@ -164,25 +168,27 @@ Module Columns. | nil => f (0, 0) | x :: nil => f (div x (weight (S n) / weight n), modulo x (weight (S n) / weight n)) | x :: y :: nil => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x y in + add_get_carry_cps _ (weight (S n) / weight n) x y (fun sum_carry => + dlet sum_carry := sum_carry in dlet carry := snd sum_carry in - f (carry, fst sum_carry) + f (carry, fst sum_carry)) | x :: tl => compact_digit_cps tl (fun rec => - dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in + add_get_carry_cps _ (weight (S n) / weight n) x (snd rec) (fun sum_carry => + dlet sum_carry := sum_carry in dlet carry' := (fst rec + snd sum_carry)%RT in - f (carry', fst sum_carry)) + f (carry', fst sum_carry))) end. End compact_digit_cps. Definition compact_digit n digit := compact_digit_cps n digit id. Lemma compact_digit_id n digit: forall {T} f, @compact_digit_cps n T digit f = f (compact_digit n digit). - Proof using Type. - induction digit; intros; cbv [compact_digit]; [reflexivity|]; - simpl compact_digit_cps; break_match; rewrite ?IHdigit; - reflexivity. + Proof using add_get_carry_cps_id. + induction digit; intros; cbv [compact_digit]; [reflexivity|]. + simpl compact_digit_cps; break_match; rewrite ?IHdigit; clear IHdigit; + cbv [Let_In]; autorewrite with uncps; reflexivity. Qed. Hint Opaque compact_digit : uncps. Hint Rewrite compact_digit_id : uncps. @@ -194,7 +200,7 @@ Module Columns. Definition compact_step i c d := compact_step_cps i c d id. Lemma compact_step_id i c d T f : @compact_step_cps i c d T f = f (compact_step i c d). - Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. + Proof using add_get_carry_cps_id. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. Hint Opaque compact_step : uncps. Hint Rewrite compact_step_id : uncps. @@ -203,15 +209,15 @@ Module Columns. Definition compact {n} xs := @compact_cps n xs _ id. Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs). - Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. + Proof using add_get_carry_cps_id. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. Lemma compact_digit_mod i (xs : list Z) : snd (compact_digit i xs) = sum xs mod (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct. + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct add_get_carry_cps_id. induction xs; cbv [compact_digit]; simpl compact_digit_cps; cbv [Let_In]; repeat match goal with - | _ => progress autorewrite with div_mod + | _ => cbv [add_get_carry]; progress autorewrite with div_mod | _ => rewrite IHxs, <-Z.add_mod_r | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) | _ => progress (autorewrite with uncps push_id cancel_pair in * ) @@ -223,11 +229,11 @@ Module Columns. Lemma compact_digit_div i (xs : list Z) : fst (compact_digit i xs) = sum xs / (weight (S i) / weight i). - Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides. + Proof using add_get_carry_div add_get_carry_mod div_correct modulo_correct weight_0 weight_divides add_get_carry_cps_id. induction xs; cbv [compact_digit]; simpl compact_digit_cps; cbv [Let_In]; repeat match goal with - | _ => progress autorewrite with div_mod + | _ => cbv [add_get_carry]; progress autorewrite with div_mod | _ => rewrite IHxs | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) | _ => progress (autorewrite with uncps push_id cancel_pair in * ) @@ -421,6 +427,9 @@ Hint Rewrite @Columns.compact_digit_id @Columns.compact_step_id @Columns.compact_id + using (assumption || (intros; autorewrite with uncps; reflexivity)) + : uncps. +Hint Rewrite @Columns.cons_to_nth_id @Columns.from_associational_id : uncps. diff --git a/src/Arithmetic/Saturated/Freeze.v b/src/Arithmetic/Saturated/Freeze.v index 65b8ee55d..b56a69a3d 100644 --- a/src/Arithmetic/Saturated/Freeze.v +++ b/src/Arithmetic/Saturated/Freeze.v @@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.Wrappers. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple Crypto.Util.LetIn. @@ -101,6 +102,7 @@ Section Freeze. pose proof Z.add_get_carry_full_mod. pose proof Z.add_get_carry_full_div. pose proof div_correct. pose proof modulo_correct. + pose proof @Z.add_get_carry_full_cps_correct. autorewrite with uncps push_id push_basesystem_eval. pose proof (weight_nonzero n). diff --git a/src/Arithmetic/Saturated/MontgomeryAPI.v b/src/Arithmetic/Saturated/MontgomeryAPI.v index 7ff86258c..fb896749b 100644 --- a/src/Arithmetic/Saturated/MontgomeryAPI.v +++ b/src/Arithmetic/Saturated/MontgomeryAPI.v @@ -13,6 +13,7 @@ Require Import Crypto.Util.Tuple Crypto.Util.LetIn. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil. Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.ZUtil.Zselect. Require Import Crypto.Util.ZUtil.AddGetCarry. Require Import Crypto.Util.ZUtil.MulSplit. @@ -106,7 +107,9 @@ Section API. Section CPSProofs. Local Ltac prove_id := - repeat autounfold; autorewrite with uncps; reflexivity. + repeat autounfold; + repeat (intros; autorewrite with uncps push_id); + reflexivity. Lemma nonzero_id n p {cpsT} f : @nonzero_cps n p cpsT f = f (@nonzero n p). Proof. cbv [nonzero nonzero_cps]. prove_id. Qed. @@ -281,7 +284,10 @@ Section API. pose proof Z.add_get_carry_full_div; pose proof Z.add_get_carry_full_mod; pose proof Z.mul_split_div; pose proof Z.mul_split_mod; - pose proof div_correct; pose proof modulo_correct. + pose proof div_correct; pose proof modulo_correct; + pose proof @Z.add_get_carry_full_cps_correct; + pose proof @Z.mul_split_cps_correct; + pose proof @Z.mul_split_cps'_correct. Lemma eval_add n p q : eval (@add n p q) = eval p + eval q. @@ -308,8 +314,8 @@ Section API. Qed. Hint Rewrite eval_add_same eval_add_S1 eval_add_S2 using (omega || assumption): push_basesystem_eval. - Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). - Local Definition compact_digit := Columns.compact_digit (add_get_carry:=Z.add_get_carry_full) (div:=div) (modulo:=modulo) (uweight bound). + Local Definition compact {n} := Columns.compact (n:=n) (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound). + Local Definition compact_digit := Columns.compact_digit (add_get_carry_cps:=@Z.add_get_carry_full_cps) (div:=div) (modulo:=modulo) (uweight bound). Lemma small_compact {n} (p:(list Z)^n) : small (snd (compact p)). Proof. pose_all. @@ -329,7 +335,8 @@ Section API. match goal with H : _ /\ _ |- _ => destruct H end. destruct n0; subst f. { cbv [compact_digit uweight to_list to_list' In]. - rewrite Columns.compact_digit_mod by assumption. + rewrite Columns.compact_digit_mod + by (assumption || (intros; autorewrite with uncps push_id; auto)). rewrite Z.pow_0_r, Z.pow_1_r, Z.div_1_r. intros x ?. match goal with H : _ \/ False |- _ => destruct H; [|exfalso; assumption] end. @@ -340,7 +347,8 @@ Section API. [solve[auto]| cbv [In] in H; destruct H; [|exfalso; assumption] ]. subst x. cbv [compact_digit]. - rewrite Columns.compact_digit_mod by assumption. + rewrite Columns.compact_digit_mod + by (assumption || (intros; autorewrite with uncps push_id; auto)). rewrite !uweight_succ, Z.div_mul by (apply Z.neq_mul_0; split; auto; omega). apply Z.mod_pos_bound, Z.gt_lt, bound_pos. } @@ -554,6 +562,8 @@ Section API. Proof. intro Hsmall. pose_all. apply eval_small in Hsmall. intros. cbv [scmul scmul_cps eval] in *. repeat autounfold. + autorewrite with uncps. + autorewrite with push_basesystem_eval. autorewrite with uncps push_id push_basesystem_eval. rewrite uweight_0, Z.mul_1_l. apply Z.mod_small. split; [solve[Z.zero_bounds]|]. cbv [uweight] in *. diff --git a/src/Arithmetic/Saturated/MulSplit.v b/src/Arithmetic/Saturated/MulSplit.v index 98d8d0e0c..4947f422a 100644 --- a/src/Arithmetic/Saturated/MulSplit.v +++ b/src/Arithmetic/Saturated/MulSplit.v @@ -9,21 +9,34 @@ Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. Module B. Module Associational. Section Associational. - Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + Context {mul_split_cps : forall {T}, Z -> Z -> Z -> (Z * Z -> T) -> T} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + {mul_split_cps_id : forall {T} s x y f, + @mul_split_cps T s x y f = f (@mul_split_cps _ s x y id)} {mul_split_mod : forall s x y, - fst (mul_split s x y) = (x * y) mod s} + fst (mul_split_cps s x y id) = (x * y) mod s} {mul_split_div : forall s x y, - snd (mul_split s x y) = (x * y) / s} - . + snd (mul_split_cps s x y id) = (x * y) / s} + . + + Local Lemma mul_split_cps_correct {T} s x y f + : @mul_split_cps T s x y f = f ((x * y) mod s, (x * y) / s). + Proof. + now rewrite mul_split_cps_id, <- mul_split_mod, <- mul_split_div, <- surjective_pairing. + Qed. + Hint Rewrite @mul_split_cps_correct : uncps. Definition sat_multerm_cps s (t t' : B.limb) {T} (f:list B.limb ->T) := - dlet xy := mul_split s (snd t) (snd t') in - f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil). + mul_split_cps _ s (snd t) (snd t') (fun xy => + dlet xy := xy in + f ((fst t * fst t', fst xy) :: (fst t * fst t' * s, snd xy) :: nil)). Definition sat_multerm s t t' := sat_multerm_cps s t t' id. Lemma sat_multerm_id s t t' T f : @sat_multerm_cps s t t' T f = f (sat_multerm s t t'). - Proof. reflexivity. Qed. + Proof. + unfold sat_multerm, sat_multerm_cps; + etransitivity; rewrite mul_split_cps_id; reflexivity. + Qed. Hint Opaque sat_multerm : uncps. Hint Rewrite sat_multerm_id : uncps. @@ -39,14 +52,17 @@ Module B. Lemma eval_map_sat_multerm s a q (s_nonzero:s<>0): B.Associational.eval (flat_map (sat_multerm s a) q) = fst a * snd a * B.Associational.eval q. Proof. - cbv [sat_multerm sat_multerm_cps Let_In]; induction q; - repeat match goal with - | _ => progress (autorewrite with uncps push_id cancel_pair push_basesystem_eval in * ) - | _ => progress simpl flat_map - | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div - | _ => rewrite Z.mod_eq by assumption - | _ => ring_simplify; omega - end. + cbv [sat_multerm sat_multerm_cps Let_In]; induction q; + repeat match goal with + | _ => progress autorewrite with uncps push_id cancel_pair push_basesystem_eval in * + | _ => progress simpl flat_map + | _ => progress unfold id in * + | _ => progress rewrite ?IHq, ?mul_split_mod, ?mul_split_div + | _ => rewrite Z.mod_eq by assumption + | _ => rewrite B.Associational.eval_nil + | _ => progress change (Z * Z)%type with B.limb + | _ => ring_simplify; omega + end. Qed. Hint Rewrite eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. @@ -68,7 +84,7 @@ Module B. End Associational. End B. Hint Opaque B.Associational.sat_mul B.Associational.sat_multerm : uncps. -Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id : uncps. +Hint Rewrite @B.Associational.sat_mul_id @B.Associational.sat_multerm_id using (assumption || (intros; autorewrite with uncps; reflexivity)) : uncps. Hint Rewrite @B.Associational.eval_sat_mul @B.Associational.eval_map_sat_multerm using (omega || assumption) : push_basesystem_eval. Hint Unfold diff --git a/src/Arithmetic/Saturated/Wrappers.v b/src/Arithmetic/Saturated/Wrappers.v index 6fe466967..6bb3893d5 100644 --- a/src/Arithmetic/Saturated/Wrappers.v +++ b/src/Arithmetic/Saturated/Wrappers.v @@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.Saturated.Core. Require Import Crypto.Arithmetic.Saturated.MulSplit. Require Import Crypto.Util.ZUtil.Definitions. Require Import Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.CPS. Require Import Crypto.Util.Tuple. Local Notation "A ^ n" := (tuple A n) : type_scope. @@ -21,7 +22,7 @@ Module Columns. B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f))). Definition unbalanced_sub_cps {n1 n2 n3} (p : Z^n1) (q:Z^n2) {T} (f : (Z*Z^n3)->T) := @@ -29,15 +30,15 @@ Module Columns. (fun P => B.Positional.negate_snd_cps weight q (fun nq => B.Positional.to_associational_cps weight nq (fun Q => Columns.from_associational_cps weight n3 (P++Q) - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). Definition mul_cps {n1 n2 n3} s (p : Z^n1) (q : Z^n2) {T} (f : (Z*Z^n3)->T) := B.Positional.to_associational_cps weight p (fun P => B.Positional.to_associational_cps weight q - (fun Q => B.Associational.sat_mul_cps (mul_split := Z.mul_split) s P Q + (fun Q => B.Associational.sat_mul_cps (mul_split_cps := @Z.mul_split_cps') s P Q (fun PQ => Columns.from_associational_cps weight n3 PQ - (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry:=Z.add_get_carry_full) weight R f)))). + (fun R => Columns.compact_cps (div:=div) (modulo:=modulo) (add_get_carry_cps:=@Z.add_get_carry_full_cps) weight R f)))). Definition conditional_add_cps {n1 n2 n3} mask cond (p:Z^n1) (q:Z^n2) {T} (f:_->T) := diff --git a/src/Util/ZUtil/CPS.v b/src/Util/ZUtil/CPS.v index 754f6cf56..3c0007c88 100644 --- a/src/Util/ZUtil/CPS.v +++ b/src/Util/ZUtil/CPS.v @@ -35,38 +35,38 @@ Module Z. | break_innermost_match_step ]. Definition get_carry_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T) : T - := let '(v, c) := Z.get_carry bitwidth v in f (v, c). + := f (Z.get_carry bitwidth v). Definition get_carry_cps_correct {T} bitwidth v f : @get_carry_cps T bitwidth v f = f (Z.get_carry bitwidth v) := eq_refl. Hint Rewrite @get_carry_cps_correct : uncps. Definition add_with_get_carry_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T - := let '(v, c) := Z.add_with_get_carry bitwidth c x y in f (v, c). + := f (Z.add_with_get_carry bitwidth c x y). Definition add_with_get_carry_cps_correct {T} bitwidth c x y f : @add_with_get_carry_cps T bitwidth c x y f = f (Z.add_with_get_carry bitwidth c x y) := eq_refl. Hint Rewrite @add_with_get_carry_cps_correct : uncps. Definition add_get_carry_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T - := let '(v, c) := Z.add_get_carry bitwidth x y in f (v, c). + := f (Z.add_get_carry bitwidth x y). Definition add_get_carry_cps_correct {T} bitwidth x y f : @add_get_carry_cps T bitwidth x y f = f (Z.add_get_carry bitwidth x y) := eq_refl. Hint Rewrite @add_get_carry_cps_correct : uncps. Definition get_borrow_cps {T} (bitwidth : Z) (v : Z) (f : Z * Z -> T) - := let '(v, c) := Z.get_borrow bitwidth v in f (v, c). + := f (Z.get_borrow bitwidth v). Definition get_borrow_cps_correct {T} bitwidth v f : @get_borrow_cps T bitwidth v f = f (Z.get_borrow bitwidth v) := eq_refl. Hint Rewrite @get_borrow_cps_correct : uncps. Definition sub_with_get_borrow_cps {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : T - := let '(v, c) := Z.sub_with_get_borrow bitwidth c x y in f (v, c). + := f (Z.sub_with_get_borrow bitwidth c x y). Definition sub_with_get_borrow_cps_correct {T} (bitwidth : Z) (c : Z) (x y : Z) (f : Z * Z -> T) : @sub_with_get_borrow_cps T bitwidth c x y f = f (Z.sub_with_get_borrow bitwidth c x y) := eq_refl. Hint Rewrite @sub_with_get_borrow_cps_correct : uncps. Definition sub_get_borrow_cps {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : T - := let '(v, c) := Z.sub_get_borrow bitwidth x y in f (v, c). + := f (Z.sub_get_borrow bitwidth x y). Definition sub_get_borrow_cps_correct {T} (bitwidth : Z) (x y : Z) (f : Z * Z -> T) : @sub_get_borrow_cps T bitwidth x y f = f (Z.sub_get_borrow bitwidth x y) := eq_refl. |