From a1f9d9ee2c662790b43bf56df58121a390efbe7c Mon Sep 17 00:00:00 2001 From: jadep Date: Fri, 15 Feb 2019 14:41:47 -0500 Subject: start adapting Montgomery to new glue code --- src/PushButtonSynthesis/MontgomeryReduction.v | 144 +++++++++++++++++++------- src/PushButtonSynthesis/Primitives.v | 6 +- 2 files changed, 113 insertions(+), 37 deletions(-) (limited to 'src/PushButtonSynthesis') diff --git a/src/PushButtonSynthesis/MontgomeryReduction.v b/src/PushButtonSynthesis/MontgomeryReduction.v index a452a047c..1387cad36 100644 --- a/src/PushButtonSynthesis/MontgomeryReduction.v +++ b/src/PushButtonSynthesis/MontgomeryReduction.v @@ -1,17 +1,23 @@ (** * Push-Button Synthesis of Montgomery Reduction *) Require Import Coq.Strings.String. Require Import Coq.ZArith.ZArith. +Require Import Coq.micromega.Lia. Require Import Coq.Lists.List. Require Import Coq.derive.Derive. Require Import Crypto.Util.ErrorT. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ZUtil.Div. +Require Import Crypto.Util.ZUtil.ModInv. Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. Require Import Crypto.Language. Require Import Crypto.CStringification. Require Import Crypto.Arithmetic. Require Import Crypto.BoundsPipeline. Require Import Crypto.COperationSpecifications. +Require Import Crypto.Fancy.Compiler. Require Import Crypto.PushButtonSynthesis.ReificationCache. Require Import Crypto.PushButtonSynthesis.Primitives. Require Import Crypto.PushButtonSynthesis.MontgomeryReductionReificationCache. @@ -27,6 +33,8 @@ Import Compilers.defaults. Import COperationSpecifications.Primitives. +Import COperationSpecifications.MontgomeryReduction. + Import Associational Positional Arithmetic.MontgomeryReduction. Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBUG(https://github.com/coq/coq/issues/9283) *) @@ -34,13 +42,17 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU Local Opaque reified_montred_gen. (* needed for making [autorewrite] not take a very long time *) Section rmontred. - Context (N R N' : Z) + Context (N R N' : Z) (n nout : nat) (machine_wordsize : Z). Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Let flag_range := r[0 ~> 1]%zrange. Let bound := Some value_range. Let consts_list := [N; N']. + Let R' := match Z.modinv R N with + | Some R' => R' + | None => 0 + end. Definition possible_values_of_machine_wordsize := [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z. @@ -48,10 +60,6 @@ Section rmontred. Let possible_values := possible_values_of_machine_wordsize. - Definition check_args {T} (res : Pipeline.ErrorT T) - : Pipeline.ErrorT T - := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *) - Let fancy_args := (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list; Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list; @@ -69,6 +77,79 @@ Section rmontred. cbv [fancy_args invert_low invert_high constant_to_scalar constant_to_scalar_single consts_list fold_right]; split; intros; break_innermost_match_hyps; Z.ltb_to_lt; subst; congruence. Qed. + Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *) + + (** Note: If you change the name or type signature of this + function, you will need to update the code in CLI.v *) + Definition check_args {T} (res : Pipeline.ErrorT T) + : Pipeline.ErrorT T + := fold_right + (fun '(b, e) k => if b:bool then Error e else k) + res + [ + ((negb (1 0 + /\ R > 1 + /\ EquivModulo.Z.equiv_modulo R (N * N') (-1) + /\ EquivModulo.Z.equiv_modulo N (R * R') 1 + /\ n <> 0%nat + /\ Z.of_nat n <= machine_wordsize + /\ 2 ^ machine_wordsize = R + /\ n = 2%nat + /\ nout = 2%nat. + Proof using curve_good. + clear -curve_good. + cbv [check_args fold_right] in curve_good. + break_innermost_match_hyps; try discriminate. + rewrite Bool.negb_false_iff in *. + Z.ltb_to_lt. + rewrite NPeano.Nat.eqb_neq in *. + intros. + repeat apply conj. + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + { use_curve_good_t. } + Qed. Definition montred := Pipeline.BoundsPipeline @@ -76,7 +157,7 @@ Section rmontred. fancy_args (* fancy *) possible_values (reified_montred_gen - @ GallinaReify.Reify N @ GallinaReify.Reify R @ GallinaReify.Reify N' @ GallinaReify.Reify (Z.log2 R) @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat) + @ GallinaReify.Reify N @ GallinaReify.Reify R @ GallinaReify.Reify N' @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat) (bound, (bound, tt)) bound. @@ -94,34 +175,27 @@ Lemma montred_correct res : montred_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res). Proof using curve_good. prove_correctness (). Qed. >> *) - - Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun rv (rop : Expr (reify_type_of op)) Hrop - => @Pipeline.BoundsPipeline_correct_trans - false (* subst01 *) - fancy_args - fancy_args_good - possible_values - _ - rop - in_bounds - out_bounds - _ - op - Hrop rv) - (only parsing). - - Definition rmontred_correct - := BoundsPipeline_correct - (bound, (bound, tt)) - bound - (montred' N R N' (Z.log2 R) 2 2). - - Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _). - Definition rmontred_correctT rv : Prop - := type_of_strip_3arrow (@rmontred_correct rv). + Local Ltac solve_montred_preconditions := + repeat first [ lia + | apply use_curve_good + | progress (push_Zmod; pull_Zmod) + | progress autorewrite with zsimplify_fast + | rewrite Z.div_add' by lia + | rewrite Z.div_small by lia + | progress Z.rewrite_mod_small ]. + + Lemma montred_correct res (Hres : montred = Success res) + : montred_correct N R R' (expr.Interp (@ident.gen_interp cast_oor) res). + Proof using n nout curve_good. + cbv [montred_correct]; intros. + rewrite <- MontgomeryReduction.montred'_correct with (R:=R) (N':=N') (Zlog2R:=machine_wordsize) (n:=n) (nout:=nout) (lo:=lo) (hi:=hi) by solve_montred_preconditions. + prove_correctness' ltac:(fun _ => idtac) use_curve_good. + { congruence. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. + rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. + rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + Admitted. End rmontred. -(* TODO: After moving to new-glue-style, remove these tactics *) -Ltac solve_rmontred := solve_rop rmontred_correct. -Ltac solve_rmontred_nocache := solve_rop_nocache rmontred_correct. +(* TODO: get Barrett to this same point, even if Qed slow, and then use these lemmas in the specific files *) \ No newline at end of file diff --git a/src/PushButtonSynthesis/Primitives.v b/src/PushButtonSynthesis/Primitives.v index c72b5a50d..24a5592ed 100644 --- a/src/PushButtonSynthesis/Primitives.v +++ b/src/PushButtonSynthesis/Primitives.v @@ -135,14 +135,14 @@ Local Notation out_bounds_of_pipeline result Notation FromPipelineToString prefix name result := (Pipeline.FromPipelineToString prefix name result). -Ltac prove_correctness use_curve_good := +Ltac prove_correctness' should_not_clear use_curve_good := let Hres := match goal with H : _ = Success _ |- _ => H end in let H := fresh in pose proof use_curve_good as H; (* I want to just use [clear -H Hres], but then I can't use any lemmas in the section because of COQBUG(https://github.com/coq/coq/issues/8153) *) repeat match goal with | [ H' : _ |- _ ] - => tryif first [ has_body H' | constr_eq H' H | constr_eq H' Hres ] + => tryif first [ has_body H' | constr_eq H' H | constr_eq H' Hres | should_not_clear H' ] then fail else clear H' end; @@ -163,6 +163,8 @@ Ltac prove_correctness use_curve_good := | progress autorewrite with distr_length in * ] | .. ]. +Ltac prove_correctness use_curve_good := prove_correctness' ltac:(fun _ => fail) use_curve_good. + Module CorrectnessStringification. Module dyn_context. Inductive list := -- cgit v1.2.3