aboutsummaryrefslogtreecommitdiff
path: root/src/PushButtonSynthesis/BarrettReduction.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/PushButtonSynthesis/BarrettReduction.v')
-rw-r--r--src/PushButtonSynthesis/BarrettReduction.v97
1 files changed, 54 insertions, 43 deletions
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v
index 265958c09..c0078b117 100644
--- a/src/PushButtonSynthesis/BarrettReduction.v
+++ b/src/PushButtonSynthesis/BarrettReduction.v
@@ -37,8 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU
Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *)
Section rbarrett_red.
- Context (k M : Z) (n nout : nat)
- (machine_wordsize : Z).
+ Context (M machine_wordsize : Z).
Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Let flag_range := r[0 ~> 1]%zrange.
@@ -70,6 +69,31 @@ Section rbarrett_red.
Qed.
Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *)
+ Lemma mut_correct :
+ 0 < machine_wordsize ->
+ Partition.partition (UniformWeight.uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1].
+ Proof.
+ intros; cbn. subst muLow.
+ assert (0 < 2^machine_wordsize) by ZeroBounds.Z.zero_bounds.
+ pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)).
+ rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia.
+ autorewrite with zsimplify.
+ Modulo.push_Zmod. autorewrite with zsimplify. Modulo.pull_Zmod.
+ rewrite <-Modulo.Z.mod_pull_div by lia.
+ autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small.
+ reflexivity.
+ Qed.
+ Lemma Mt_correct :
+ 0 < machine_wordsize ->
+ 2^(machine_wordsize - 1) < M < 2^machine_wordsize ->
+ Partition.partition (UniformWeight.uweight machine_wordsize) 1 M = [M].
+ Proof.
+ intros; cbn. assert (0 < 2^(machine_wordsize-1)) by ZeroBounds.Z.zero_bounds.
+ rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia.
+ autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small.
+ reflexivity.
+ Qed.
+
(** Note: If you change the name or type signature of this
function, you will need to update the code in CLI.v *)
Definition check_args {T} (res : Pipeline.ErrorT T)
@@ -78,19 +102,11 @@ Section rbarrett_red.
(fun '(b, e) k => if b:bool then Error e else k)
res
[
- ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k);
- ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0);
- ((negb (0 <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 0" 0 M);
- (negb (muLow + 2 ^ k =? 2 ^ (2 * k) / M)%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^k ≠ 2 ^ (2 * k) / M" (muLow + 2^k) (2 ^ (2 * k) / M));
- ((negb (0 <=? muLow))%Z, Pipeline.Value_not_leZ "muLow < 0" 0 muLow);
- ((negb (muLow <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ muLow" muLow (2^k));
- ((negb (2 ^ (k-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(k-1)" (2^(k-1)) M);
- ((negb (M <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ M" M (2^k));
- (negb ((2 * (2 ^ (2 * k) mod M) <=? 2 ^ (k + 1) - (muLow + 2 ^ k)))%Z, Pipeline.Value_not_leZ ("(2 * (2 ^ (2 * k) mod M) 2 ^ (k + 1) - (muLow + 2 ^ k)") (2 * (2 ^ (2 * k) mod M)) (2 ^ (k + 1) - (muLow + 2 ^ k)));
- (negb (Z.of_nat n <=? k)%Z, Pipeline.Value_not_leZ "k < n" (Z.of_nat n) k);
- (negb (k =? machine_wordsize)%Z, Pipeline.Values_not_provably_equalZ "k ≠ machine_wordsize" k machine_wordsize);
- (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2);
- (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2)].
+ ((negb (1 <? machine_wordsize))%Z, Pipeline.Value_not_ltZ "machine_wordsize ≤ 1" 1 machine_wordsize);
+ ((negb (2 ^ (machine_wordsize-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(machine_wordsize-1)" (2^(machine_wordsize-1)) M);
+ ((negb (M <? 2 ^ machine_wordsize))%Z, Pipeline.Value_not_ltZ "2 ^ machine_wordsize ≤ M" M (2^machine_wordsize));
+ ((negb (muLow + 2 ^ machine_wordsize =? ((2 ^ 2) ^ machine_wordsize) / M))%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^machine_wordsize ≠ (2 ^ 2) ^ machine_wordsize) / M" (muLow + 2^machine_wordsize) (((2 ^ 2) ^ machine_wordsize) / M));
+ (negb ((2 * (((2 ^ 2) ^ machine_wordsize) mod M) <=? 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)))%Z, Pipeline.Value_not_leZ ("(2 * ((2 ^ 2) ^ machine_wordsize) mod M) 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)") (2 * (((2 ^ 2) ^ machine_wordsize) mod M)) (2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize))) ].
Local Arguments Z.mul !_ !_.
Local Ltac use_curve_good_t :=
@@ -107,24 +123,17 @@ Section rbarrett_red.
Context (curve_good : check_args (Success tt) = Success tt).
Lemma use_curve_good
- : 2 <= k
- /\ 0 < M
- /\ muLow + 2 ^ k = 2 ^ (2 * k) / M
- /\ 0 <= muLow < 2 ^ k
- /\ 2 ^ (k - 1) < M < 2 ^ k
- /\ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2 ^ k)
- /\ n <> 0%nat
- /\ Z.of_nat n <= k
- /\ k = machine_wordsize
- /\ n = 2%nat
- /\ nout = 2%nat.
+ : 1 < machine_wordsize
+ /\ 2 ^ (machine_wordsize - 1) <= M < 2 ^ machine_wordsize
+ /\ muLow + 2 ^ machine_wordsize = (2 ^ 2) ^ machine_wordsize / M
+ /\ 2 ^ (machine_wordsize - 1) < M < 2 ^ machine_wordsize
+ /\ 2 * ((2 ^ 2) ^ machine_wordsize mod M) <= 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize).
Proof using curve_good.
clear -curve_good.
cbv [check_args fold_right] in curve_good.
break_innermost_match_hyps; try discriminate.
rewrite Bool.negb_false_iff in *.
Z.ltb_to_lt.
- rewrite NPeano.Nat.eqb_neq in *.
intros.
repeat apply conj.
{ use_curve_good_t. }
@@ -134,12 +143,6 @@ Section rbarrett_red.
{ use_curve_good_t. }
{ use_curve_good_t. }
{ use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
Qed.
Definition barrett_red
@@ -148,7 +151,12 @@ Section rbarrett_red.
fancy_args (* fancy *)
possible_values
(reified_barrett_red_gen
- @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify M @ GallinaReify.Reify muLow @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat)
+ @ GallinaReify.Reify M
+ @ GallinaReify.Reify machine_wordsize
+ @ GallinaReify.Reify machine_wordsize
+ @ GallinaReify.Reify 1%nat
+ @ GallinaReify.Reify [muLow;1]
+ @ GallinaReify.Reify [M])
(bound, (bound, tt))
bound.
@@ -162,24 +170,27 @@ Section rbarrett_red.
Local Ltac solve_barrett_red_preconditions :=
repeat first [ lia
| assumption
+ | match goal with |- ?x = ?x => reflexivity end
| apply use_curve_good
| progress autorewrite with zsimplify
| progress intros
| progress cbv [weight]
+ | rewrite mut_correct
+ | rewrite Mt_correct
| rewrite Z.pow_mul_r by lia ].
Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *)
Lemma barrett_red_correct res (Hres : barrett_red = Success res)
- : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res).
- Proof using k M curve_good.
+ : barrett_red_correct machine_wordsize M (expr.Interp (@ident.gen_interp cast_oor) res).
+ Proof using M curve_good.
cbv [barrett_red_correct]; intros.
- assert (2 <= k) by apply use_curve_good.
- rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions.
+ assert (1 < machine_wordsize) by apply use_curve_good.
+ rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions.
prove_correctness' ltac:(fun _ => idtac) use_curve_good.
- { congruence. }
- { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
- subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
- { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
- subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbn. econstructor. }
+ { cbn. econstructor. }
+ { cbn. econstructor. }
Qed.
End rbarrett_red.