aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jadep@mit.edu>2019-02-19 10:40:18 -0500
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2019-02-21 11:10:12 -0500
commit582dd629eab4c1a051c6f3426c3953dbe45e8efc (patch)
tree8af130f5e478cbb83cb567c81ea67daac2df8da8 /src
parent9025014cf7b64082d0bfed93dd76ba99ab9bb72b (diff)
adapt barrett to new glue code
Diffstat (limited to 'src')
-rw-r--r--src/PushButtonSynthesis/BarrettReduction.v152
-rw-r--r--src/PushButtonSynthesis/MontgomeryReduction.v11
2 files changed, 101 insertions, 62 deletions
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v
index 224584c4a..265958c09 100644
--- a/src/PushButtonSynthesis/BarrettReduction.v
+++ b/src/PushButtonSynthesis/BarrettReduction.v
@@ -3,6 +3,7 @@ Require Import Coq.Strings.String.
Require Import Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Coq.derive.Derive.
+Require Import Coq.micromega.Lia.
Require Import Crypto.Util.ErrorT.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.ZRange.
@@ -11,6 +12,7 @@ Require Import Crypto.Language.
Require Import Crypto.CStringification.
Require Import Crypto.Arithmetic.
Require Import Crypto.BoundsPipeline.
+Require Import Crypto.Fancy.Compiler.
Require Import Crypto.COperationSpecifications.
Require Import Crypto.PushButtonSynthesis.ReificationCache.
Require Import Crypto.PushButtonSynthesis.Primitives.
@@ -26,6 +28,7 @@ Import
Import Compilers.defaults.
Import COperationSpecifications.Primitives.
+Import COperationSpecifications.BarrettReduction.
Import Associational Positional Arithmetic.BarrettReduction.
@@ -34,7 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU
Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *)
Section rbarrett_red.
- Context (M : Z)
+ Context (k M : Z) (n nout : nat)
(machine_wordsize : Z).
Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
@@ -48,19 +51,6 @@ Section rbarrett_red.
:= [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize]%Z.
Let possible_values := possible_values_of_machine_wordsize.
- Definition check_args {T} (res : Pipeline.ErrorT T)
- : Pipeline.ErrorT T
- := fold_right
- (fun '(b, e) k => if b:bool then Error e else k)
- res
- [((mu / (2 ^ machine_wordsize) =? 0), Pipeline.Values_not_provably_distinctZ "mu / 2 ^ k ≠ 0" (mu / 2 ^ machine_wordsize) 0);
- ((machine_wordsize <? 2), Pipeline.Value_not_leZ "~ (2 <=k)" 2 machine_wordsize);
- (negb (Z.log2 M + 1 =? machine_wordsize), Pipeline.Values_not_provably_equalZ "log2(M)+1 != k" (Z.log2 M + 1) machine_wordsize);
- ((2 ^ (machine_wordsize + 1) - mu <? 2 * (2 ^ (2 * machine_wordsize) mod M)),
- Pipeline.Value_not_leZ "~ (2 * (2 ^ (2*k) mod M) <= 2^(k + 1) - mu)"
- (2 * (2 ^ (2*machine_wordsize) mod M))
- (2^(machine_wordsize + 1) - mu))].
-
Let fancy_args
:= (Some {| Pipeline.invert_low log2wordsize := invert_low log2wordsize consts_list;
Pipeline.invert_high log2wordsize := invert_high log2wordsize consts_list;
@@ -78,6 +68,79 @@ Section rbarrett_red.
cbv [fancy_args invert_low invert_high constant_to_scalar constant_to_scalar_single consts_list fold_right];
split; intros; break_innermost_match_hyps; Z.ltb_to_lt; subst; congruence.
Qed.
+ Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *)
+
+ (** Note: If you change the name or type signature of this
+ function, you will need to update the code in CLI.v *)
+ Definition check_args {T} (res : Pipeline.ErrorT T)
+ : Pipeline.ErrorT T
+ := fold_right
+ (fun '(b, e) k => if b:bool then Error e else k)
+ res
+ [
+ ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k);
+ ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0);
+ ((negb (0 <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 0" 0 M);
+ (negb (muLow + 2 ^ k =? 2 ^ (2 * k) / M)%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^k ≠ 2 ^ (2 * k) / M" (muLow + 2^k) (2 ^ (2 * k) / M));
+ ((negb (0 <=? muLow))%Z, Pipeline.Value_not_leZ "muLow < 0" 0 muLow);
+ ((negb (muLow <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ muLow" muLow (2^k));
+ ((negb (2 ^ (k-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(k-1)" (2^(k-1)) M);
+ ((negb (M <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ M" M (2^k));
+ (negb ((2 * (2 ^ (2 * k) mod M) <=? 2 ^ (k + 1) - (muLow + 2 ^ k)))%Z, Pipeline.Value_not_leZ ("(2 * (2 ^ (2 * k) mod M) 2 ^ (k + 1) - (muLow + 2 ^ k)") (2 * (2 ^ (2 * k) mod M)) (2 ^ (k + 1) - (muLow + 2 ^ k)));
+ (negb (Z.of_nat n <=? k)%Z, Pipeline.Value_not_leZ "k < n" (Z.of_nat n) k);
+ (negb (k =? machine_wordsize)%Z, Pipeline.Values_not_provably_equalZ "k ≠ machine_wordsize" k machine_wordsize);
+ (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2);
+ (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2)].
+
+ Local Arguments Z.mul !_ !_.
+ Local Ltac use_curve_good_t :=
+ repeat first [ assumption
+ | progress cbv [EquivModulo.Z.equiv_modulo]
+ | progress rewrite ?Z.mul_0_r, ?Pos.mul_1_r, ?Z.mul_1_r in *
+ | reflexivity
+ | lia
+ | progress cbn in *
+ | progress intros
+ | solve [ auto with zarith ]
+ | rewrite Z.log2_pow2 by use_curve_good_t ].
+
+ Context (curve_good : check_args (Success tt) = Success tt).
+
+ Lemma use_curve_good
+ : 2 <= k
+ /\ 0 < M
+ /\ muLow + 2 ^ k = 2 ^ (2 * k) / M
+ /\ 0 <= muLow < 2 ^ k
+ /\ 2 ^ (k - 1) < M < 2 ^ k
+ /\ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2 ^ k)
+ /\ n <> 0%nat
+ /\ Z.of_nat n <= k
+ /\ k = machine_wordsize
+ /\ n = 2%nat
+ /\ nout = 2%nat.
+ Proof using curve_good.
+ clear -curve_good.
+ cbv [check_args fold_right] in curve_good.
+ break_innermost_match_hyps; try discriminate.
+ rewrite Bool.negb_false_iff in *.
+ Z.ltb_to_lt.
+ rewrite NPeano.Nat.eqb_neq in *.
+ intros.
+ repeat apply conj.
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ { use_curve_good_t. }
+ Qed.
Definition barrett_red
:= Pipeline.BoundsPipeline
@@ -96,42 +159,27 @@ Section rbarrett_red.
prefix "barrett_red" barrett_red
(fun _ _ _ => @nil string).
- Local Strategy -100 [barrett_red]. (* Probably needed to make Qed not take forever *)
- (* TODO: Replace the following lemmas with a new-glue-style correctness lemma, like
-<<
-Lemma barrett_red_correct res
- (Hres : barrett_red = Success res)
- : barrett_red_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res).
- Proof using curve_good. prove_correctness (). Qed.
->> *)
-
- Notation BoundsPipeline_correct in_bounds out_bounds op
- := (fun rv (rop : Expr (reify_type_of op)) Hrop
- => @Pipeline.BoundsPipeline_correct_trans
- false (* subst01 *)
- fancy_args
- fancy_args_good
- possible_values
- _
- rop
- in_bounds
- out_bounds
- _
- op
- Hrop rv)
- (only parsing).
-
- Definition rbarrett_red_correct
- := BoundsPipeline_correct
- (bound, (bound, tt))
- bound
- (barrett_reduce machine_wordsize M muLow 2 2).
-
- Notation type_of_strip_3arrow := ((fun (d : Prop) (_ : forall A B C, d) => d) _).
- Definition rbarrett_red_correctT rv : Prop
- := type_of_strip_3arrow (@rbarrett_red_correct rv).
+ Local Ltac solve_barrett_red_preconditions :=
+ repeat first [ lia
+ | assumption
+ | apply use_curve_good
+ | progress autorewrite with zsimplify
+ | progress intros
+ | progress cbv [weight]
+ | rewrite Z.pow_mul_r by lia ].
+
+ Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *)
+ Lemma barrett_red_correct res (Hres : barrett_red = Success res)
+ : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res).
+ Proof using k M curve_good.
+ cbv [barrett_red_correct]; intros.
+ assert (2 <= k) by apply use_curve_good.
+ rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions.
+ prove_correctness' ltac:(fun _ => idtac) use_curve_good.
+ { congruence. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
+ subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
+ subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ Qed.
End rbarrett_red.
-
-(* TODO: After moving to new-glue-style, remove these tactics *)
-Ltac solve_rbarrett_red := solve_rop rbarrett_red_correct.
-Ltac solve_rbarrett_red_nocache := solve_rop_nocache rbarrett_red_correct.
diff --git a/src/PushButtonSynthesis/MontgomeryReduction.v b/src/PushButtonSynthesis/MontgomeryReduction.v
index d4399a743..a682aa227 100644
--- a/src/PushButtonSynthesis/MontgomeryReduction.v
+++ b/src/PushButtonSynthesis/MontgomeryReduction.v
@@ -168,13 +168,6 @@ Section rmontred.
prefix "montred" montred
(fun _ _ _ => @nil string).
- (* TODO: Replace the following lemmas with a new-glue-style correctness lemma, like
-<<
-Lemma montred_correct res
- (Hres : montred = Success res)
- : montred_correct (weight (Qnum limbwidth) (QDen limbwidth)) n m tight_bounds loose_bounds (Interp res).
- Proof using curve_good. prove_correctness (). Qed.
->> *)
Local Ltac solve_montred_preconditions :=
repeat first [ lia
| apply use_curve_good
@@ -197,6 +190,4 @@ Lemma montred_correct res
{ cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
Qed.
-End rmontred.
-
-(* TODO: get Barrett to this same point, and then use these lemmas in the specific files *) \ No newline at end of file
+End rmontred. \ No newline at end of file