aboutsummaryrefslogtreecommitdiff
path: root/src/PushButtonSynthesis
diff options
context:
space:
mode:
authorGravatar jadep <jadep@mit.edu>2019-02-15 14:41:47 -0500
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2019-02-21 11:10:12 -0500
commita1f9d9ee2c662790b43bf56df58121a390efbe7c (patch)
tree7398e0136b4d0c5b314d9d9943ed9d5d404b8e85 /src/PushButtonSynthesis
parent47c0533d8640af625d6f403a0784edaa6cc26fac (diff)
start adapting Montgomery to new glue code
Diffstat (limited to 'src/PushButtonSynthesis')
-rw-r--r--src/PushButtonSynthesis/MontgomeryReduction.v144
-rw-r--r--src/PushButtonSynthesis/Primitives.v6
2 files changed, 113 insertions, 37 deletions
diff --git a/src/PushButtonSynthesis/MontgomeryReduction.v b/src/PushButtonSynthesis/MontgomeryReduction.v
index a452a047c..1387cad36 100644
--- a/src/PushButtonSynthesis/MontgomeryReduction.v
+++ b/src/PushButtonSynthesis/MontgomeryReduction.v
@@ -1,17 +1,23 @@
(** * Push-Button Synthesis of Montgomery Reduction *)
Require Import Coq.Strings.String.
Require Import Coq.ZArith.ZArith.
+Require Import Coq.micromega.Lia.
Require Import Coq.Lists.List.
Require Import Coq.derive.Derive.
Require Import Crypto.Util.ErrorT.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.ZRange.
+Require Import Crypto.Util.ZUtil.Div.
+Require Import Crypto.Util.ZUtil.ModInv.
Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
+Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
+Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall.
Require Import Crypto.Language.
Require Import Crypto.CStringification.
Require Import Crypto.Arithmetic.
Require Import Crypto.BoundsPipeline.
Require Import Crypto.COperationSpecifications.
+Require Import Crypto.Fancy.Compiler.
Require Import Crypto.PushButtonSynthesis.ReificationCache.
Require Import Crypto.PushButtonSynthesis.Primitives.
Require Import Crypto.PushButtonSynthesis.MontgomeryReductionReificationCache.
@@ -27,6 +33,8 @@ Import Compilers.defaults.
Import COperationSpecifications.Primitives.
+Import COperationSpecifications.MontgomeryReduction.
+
Import Associational Positional Arithmetic.MontgomeryReduction.
Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBUG(https://github.com/coq/coq/issues/9283) *)
@@ -34,13 +42,17 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU
Local Opaque reified_montred_gen. (* needed for making [autorewrite] not take a very long time *)
Section rmontred.
- Context (N R N' : Z)
+ Context (N R N' : Z) (n nout : nat)
(machine_wordsize : Z).
Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Let flag_range := r[0 ~> 1]%zrange.
Let bound := Some value_range.
Let consts_list := [N; N'].
+ Let R' := match Z.modinv R N with
+ | Some R' => R'
+ | None => 0
+ end.
Definition possible_values_of_machine_wordsize
:= [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z.
@@ -48,10 +60,6 @@ Section rmontred.
Let possible_values := possible_values_of_machine_wordsize.
- Definition check_args {T} (res : Pipeline.ErrorT T)
- : Pipeline.ErrorT T
- := res. (* TODO: this should actually check stuff that corresponds with preconditions of montred'_correct *)
-
Let fancy_args
:= (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list;
Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list;
@@ -69,6 +77,79 @@ Section rmontred.
cbv [fancy_args invert_low invert_high constant_to_scalar constant_to_scalar_single consts_list fold_right];
split; intros; break_innermost_match_hyps; Z.ltb_to_lt; subst; congruence.
Qed.
+ Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *)
+
+ (** Note: If you change the name or type signature of this
+ function, you will need to update the code in CLI.v *)
+ Definition check_args {T} (res : Pipeline.ErrorT T)
+ : Pipeline.ErrorT T
+ := fold_right
+ (fun '(b, e) k => if b:bool then Error e else k)
+ res
+ [
+ ((negb (1 <? R))%Z, Pipeline.Value_not_ltZ "R ≤ 1" 1 R);
+ ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0);
+ ((R' =? 0)%Z, Pipeline.No_modular_inverse "R⁻¹ mod N" R N);
+ (negb ((R * R') mod N =? 1 mod N)%Z, Pipeline.Values_not_provably_equalZ "(R * R') mod N ≠ 1 mod N" ((R * R') mod N) (1 mod N));
+ (negb ((N * N') mod R =? (-1) mod R)%Z, Pipeline.Values_not_provably_equalZ "(N * N') mod R ≠ (-1) mod R" ((N * N') mod R) ((-1) mod R));
+ (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2);
+ (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2);
+ (negb (2 ^ machine_wordsize =? R)%Z, Pipeline.Values_not_provably_equalZ "2^machine_wordsize ≠ R" (2^machine_wordsize) R);
+ ((negb (0 <? N))%Z, Pipeline.Value_not_ltZ "N ≤ 0" 0 N);
+ ((negb (N <? R))%Z, Pipeline.Value_not_ltZ "R ≤ N" R N);
+ ((negb (0 <=? N'))%Z, Pipeline.Value_not_leZ "N' < 0" 0 N');
+ ((negb (N' <? R))%Z, Pipeline.Value_not_ltZ "R ≤ N'" R N');
+ ((negb (Z.of_nat n <=? machine_wordsize))%Z, Pipeline.Value_not_leZ "machine_wordsize < n" (Z.of_nat n) machine_wordsize)].
+
+ Local Arguments Z.mul !_ !_.
+ Local Ltac use_curve_good_t :=
+ repeat first [ assumption
+ | progress cbv [EquivModulo.Z.equiv_modulo]
+ | progress rewrite ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in *
+ | reflexivity
+ | lia
+ | progress cbn in *
+ | progress intros
+ | solve [ auto with zarith ]
+ | rewrite Z.log2_pow2 by use_curve_good_t ].
+
+ Context (curve_good : check_args (Success tt) = Success tt).
+
+ Lemma use_curve_good
+ : 0 <= N < R
+ /\ 0 <= N' < R
+ /\ N <> 0
+ /\ R > 1
+ /\ EquivModulo.Z.equiv_modulo R (N * N') (-1)
+ /\ EquivModulo.Z.equiv_modulo N (R * R') 1
+ /\ n <> 0%nat
+ /\ Z.of_nat n <= machine_wordsize
+ /\ 2 ^ machine_wordsize = R
+ /\ n = 2%nat
+ /\ nout = 2%nat.
+ Proof using curve_good.
+ clear -curve_good.
+ cbv [check_args fold_right] in curve_good.
+ break_innermost_match_hyps; try discriminate.
+ rewrite Bool.negb_false_iff in *.
+ Z.ltb_to_lt.
+ rewrite NPeano.Nat.eqb_neq in *.
+ intros.
+ repeat apply conj.
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ Qed.
Definition montred
:= Pipeline.BoundsPipeline
@@ -76,7 +157,7 @@ Section rmontred.
fancy_args (* fancy *)
possible_values
(reified_montred_gen
- @ GallinaReify.Reify N @ GallinaReify.Reify R @ GallinaReify.Reify N' @ GallinaReify.Reify (Z.log2 R) @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat)
+ @ GallinaReify.Reify N @ GallinaReify.Reify R @ GallinaReify.Reify N' @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat)
(bound, (bound, tt))
bound.
@@ -94,34 +175,27 @@ Lemma montred_correct res
: montred_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res).
Proof using curve_good. prove_correctness (). Qed.
>> *)
-
- Notation BoundsPipeline_correct in_bounds out_bounds op
- := (fun rv (rop : Expr (reify_type_of op)) Hrop
- => @Pipeline.BoundsPipeline_correct_trans
- false (* subst01 *)
- fancy_args
- fancy_args_good
- possible_values
- _
- rop
- in_bounds
- out_bounds
- _
- op
- Hrop rv)
- (only parsing).
-
- Definition rmontred_correct
- := BoundsPipeline_correct
- (bound, (bound, tt))
- bound
- (montred' N R N' (Z.log2 R) 2 2).
-
- Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
- Definition rmontred_correctT rv : Prop
- := type_of_strip_3arrow (@rmontred_correct rv).
+ Local Ltac solve_montred_preconditions :=
+ repeat first [ lia
+ | apply use_curve_good
+ | progress (push_Zmod; pull_Zmod)
+ | progress autorewrite with zsimplify_fast
+ | rewrite Z.div_add' by lia
+ | rewrite Z.div_small by lia
+ | progress Z.rewrite_mod_small ].
+
+ Lemma montred_correct res (Hres : montred = Success res)
+ : montred_correct N R R' (expr.Interp (@ident.gen_interp cast_oor) res).
+ Proof using n nout curve_good.
+ cbv [montred_correct]; intros.
+ rewrite <- MontgomeryReduction.montred'_correct with (R:=R) (N':=N') (Zlog2R:=machine_wordsize) (n:=n) (nout:=nout) (lo:=lo) (hi:=hi) by solve_montred_preconditions.
+ prove_correctness' ltac:(fun _ => idtac) use_curve_good.
+ { congruence. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
+ rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
+ rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ Admitted.
End rmontred.
-(* TODO: After moving to new-glue-style, remove these tactics *)
-Ltac solve_rmontred := solve_rop rmontred_correct.
-Ltac solve_rmontred_nocache := solve_rop_nocache rmontred_correct.
+(* TODO: get Barrett to this same point, even if Qed slow, and then use these lemmas in the specific files *) \ No newline at end of file
diff --git a/src/PushButtonSynthesis/Primitives.v b/src/PushButtonSynthesis/Primitives.v
index c72b5a50d..24a5592ed 100644
--- a/src/PushButtonSynthesis/Primitives.v
+++ b/src/PushButtonSynthesis/Primitives.v
@@ -135,14 +135,14 @@ Local Notation out_bounds_of_pipeline result
Notation FromPipelineToString prefix name result
:= (Pipeline.FromPipelineToString prefix name result).
-Ltac prove_correctness use_curve_good :=
+Ltac prove_correctness' should_not_clear use_curve_good :=
let Hres := match goal with H : _ = Success _ |- _ => H end in
let H := fresh in
pose proof use_curve_good as H;
(* I want to just use [clear -H Hres], but then I can't use any lemmas in the section because of COQBUG(https://github.com/coq/coq/issues/8153) *)
repeat match goal with
| [ H' : _ |- _ ]
- => tryif first [ has_body H' | constr_eq H' H | constr_eq H' Hres ]
+ => tryif first [ has_body H' | constr_eq H' H | constr_eq H' Hres | should_not_clear H' ]
then fail
else clear H'
end;
@@ -163,6 +163,8 @@ Ltac prove_correctness use_curve_good :=
| progress autorewrite with distr_length in * ]
| .. ].
+Ltac prove_correctness use_curve_good := prove_correctness' ltac:(fun _ => fail) use_curve_good.
+
Module CorrectnessStringification.
Module dyn_context.
Inductive list :=