diff options
-rw-r--r-- | src/PushButtonSynthesis/BarrettReduction.v | 152 | ||||
-rw-r--r-- | src/PushButtonSynthesis/MontgomeryReduction.v | 11 |
2 files changed, 101 insertions, 62 deletions
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v index 224584c4a..265958c09 100644 --- a/src/PushButtonSynthesis/BarrettReduction.v +++ b/src/PushButtonSynthesis/BarrettReduction.v @@ -3,6 +3,7 @@ Require Import Coq.Strings.String. Require Import Coq.ZArith.ZArith. Require Import Coq.Lists.List. Require Import Coq.derive.Derive. +Require Import Coq.micromega.Lia. Require Import Crypto.Util.ErrorT. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.ZRange. @@ -11,6 +12,7 @@ Require Import Crypto.Language. Require Import Crypto.CStringification. Require Import Crypto.Arithmetic. Require Import Crypto.BoundsPipeline. +Require Import Crypto.Fancy.Compiler. Require Import Crypto.COperationSpecifications. Require Import Crypto.PushButtonSynthesis.ReificationCache. Require Import Crypto.PushButtonSynthesis.Primitives. @@ -26,6 +28,7 @@ Import Import Compilers.defaults. Import COperationSpecifications.Primitives. +Import COperationSpecifications.BarrettReduction. Import Associational Positional Arithmetic.BarrettReduction. @@ -34,7 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *) Section rbarrett_red. - Context (M : Z) + Context (k M : Z) (n nout : nat) (machine_wordsize : Z). Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. @@ -48,19 +51,6 @@ Section rbarrett_red. := [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. Let possible_values := possible_values_of_machine_wordsize. - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := fold_right - (fun '(b, e) k => if b:bool then Error e else k) - res - [((mu / (2 ^ machine_wordsize) =? 0), Pipeline.Values_not_provably_distinctZ "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0); - ((machine_wordsize <? 2), Pipeline.Value_not_leZ "~ (2 <=k)" 2 machine_wordsize); - (negb (Z.log2 M + 1 =? machine_wordsize), Pipeline.Values_not_provably_equalZ "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize); - ((2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M)), - Pipeline.Value_not_leZ "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)" - (2 * (2 ^ (2*machine_wordsize) mod M)) - (2^(machine_wordsize + 1) - mu))]. - Let fancy_args := (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list; Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list; @@ -78,6 +68,79 @@ Section rbarrett_red. cbv [fancy_args invert_low invert_high constant_to_scalar constant_to_scalar_single consts_list fold_right]; split; intros; break_innermost_match_hyps; Z.ltb_to_lt; subst; congruence. Qed. + Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *) + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := fold_right + (fun '(b, e) k => if b:bool then Error e else k) + res + [ + ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k); + ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0); + ((negb (0 <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 0" 0 M); + (negb (muLow + 2 ^ k =? 2 ^ (2 * k) / M)%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^k ≠ 2 ^ (2 * k) / M" (muLow + 2^k) (2 ^ (2 * k) / M)); + ((negb (0 <=? muLow))%Z, Pipeline.Value_not_leZ "muLow < 0" 0 muLow); + ((negb (muLow <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ muLow" muLow (2^k)); + ((negb (2 ^ (k-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(k-1)" (2^(k-1)) M); + ((negb (M <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ M" M (2^k)); + (negb ((2 * (2 ^ (2 * k) mod M) <=? 2 ^ (k + 1) - (muLow + 2 ^ k)))%Z, Pipeline.Value_not_leZ ("(2 * (2 ^ (2 * k) mod M) 2 ^ (k + 1) - (muLow + 2 ^ k)") (2 * (2 ^ (2 * k) mod M)) (2 ^ (k + 1) - (muLow + 2 ^ k))); + (negb (Z.of_nat n <=? k)%Z, Pipeline.Value_not_leZ "k < n" (Z.of_nat n) k); + (negb (k =? machine_wordsize)%Z, Pipeline.Values_not_provably_equalZ "k ≠ machine_wordsize" k machine_wordsize); + (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2); + (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2)]. + + Local Arguments Z.mul !_ !_. + Local Ltac use_curve_good_t := + repeat first [ assumption + | progress cbv [EquivModulo.Z.equiv_modulo] + | progress rewrite ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in * + | reflexivity + | lia + | progress cbn in * + | progress intros + | solve [ auto with zarith ] + | rewrite Z.log2_pow2 by use_curve_good_t ]. + + Context (curve_good : check_args (Success tt) = Success tt). + + Lemma use_curve_good + : 2 <= k + /\ 0 < M + /\ muLow + 2 ^ k = 2 ^ (2 * k) / M + /\ 0 <= muLow < 2 ^ k + /\ 2 ^ (k - 1) < M < 2 ^ k + /\ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2 ^ k) + /\ n <> 0%nat + /\ Z.of_nat n <= k + /\ k = machine_wordsize + /\ n = 2%nat + /\ nout = 2%nat. + Proof using curve_good. + clear -curve_good. + cbv [check_args fold_right] in curve_good. + break_innermost_match_hyps; try discriminate. + rewrite Bool.negb_false_iff in *. + Z.ltb_to_lt. + rewrite NPeano.Nat.eqb_neq in *. + intros. + repeat apply conj. + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + Qed. Definition barrett_red := Pipeline.BoundsPipeline @@ -96,42 +159,27 @@ Section rbarrett_red. prefix "barrett_red" barrett_red (fun _ _ _ => @nil string). - Local Strategy -100 [barrett_red]. (* Probably needed to make Qed not take forever *) - (* TODO: Replace the following lemmas with a new-glue-style correctness lemma, like -<< -Lemma barrett_red_correct res - (Hres : barrett_red = Success res) - : barrett_red_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res). - Proof using curve_good. prove_correctness (). Qed. ->> *) - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - false (* subst01 *) - fancy_args - fancy_args_good - possible_values - _ - rop - in_bounds - out_bounds - _ - op - Hrop rv) - (only parsing). - - Definition rbarrett_red_correct - := BoundsPipeline_correct - (bound, (bound, tt)) - bound - (barrett_reduce machine_wordsize M muLow 2 2). - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - Definition rbarrett_red_correctT rv : Prop - := type_of_strip_3arrow (@rbarrett_red_correct rv). + Local Ltac solve_barrett_red_preconditions := + repeat first [ lia + | assumption + | apply use_curve_good + | progress autorewrite with zsimplify + | progress intros + | progress cbv [weight] + | rewrite Z.pow_mul_r by lia ]. + + Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *) + Lemma barrett_red_correct res (Hres : barrett_red = Success res) + : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res). + Proof using k M curve_good. + cbv [barrett_red_correct]; intros. + assert (2 <= k) by apply use_curve_good. + rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions. + prove_correctness' ltac:(fun _ => idtac) use_curve_good. + { congruence. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. + subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. + subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + Qed. End rbarrett_red. - -(* TODO: After moving to new-glue-style, remove these tactics *) -Ltac solve_rbarrett_red := solve_rop rbarrett_red_correct. -Ltac solve_rbarrett_red_nocache := solve_rop_nocache rbarrett_red_correct. diff --git a/src/PushButtonSynthesis/MontgomeryReduction.v b/src/PushButtonSynthesis/MontgomeryReduction.v index d4399a743..a682aa227 100644 --- a/src/PushButtonSynthesis/MontgomeryReduction.v +++ b/src/PushButtonSynthesis/MontgomeryReduction.v @@ -168,13 +168,6 @@ Section rmontred. prefix "montred" montred (fun _ _ _ => @nil string). - (* TODO: Replace the following lemmas with a new-glue-style correctness lemma, like -<< -Lemma montred_correct res - (Hres : montred = Success res) - : montred_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res). - Proof using curve_good. prove_correctness (). Qed. ->> *) Local Ltac solve_montred_preconditions := repeat first [ lia | apply use_curve_good @@ -197,6 +190,4 @@ Lemma montred_correct res { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } Qed. -End rmontred. - -(* TODO: get Barrett to this same point, and then use these lemmas in the specific files *)
\ No newline at end of file +End rmontred.
\ No newline at end of file |