diff options
Diffstat (limited to 'src/PushButtonSynthesis/BarrettReduction.v')
-rw-r--r-- | src/PushButtonSynthesis/BarrettReduction.v | 97 |
1 files changed, 54 insertions, 43 deletions
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v index 265958c09..c0078b117 100644 --- a/src/PushButtonSynthesis/BarrettReduction.v +++ b/src/PushButtonSynthesis/BarrettReduction.v @@ -37,8 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *) Section rbarrett_red. - Context (k M : Z) (n nout : nat) - (machine_wordsize : Z). + Context (M machine_wordsize : Z). Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Let flag_range := r[0 ~> 1]%zrange. @@ -70,6 +69,31 @@ Section rbarrett_red. Qed. Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *) + Lemma mut_correct : + 0 < machine_wordsize -> + Partition.partition (UniformWeight.uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1]. + Proof. + intros; cbn. subst muLow. + assert (0 < 2^machine_wordsize) by ZeroBounds.Z.zero_bounds. + pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)). + rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia. + autorewrite with zsimplify. + Modulo.push_Zmod. autorewrite with zsimplify. Modulo.pull_Zmod. + rewrite <-Modulo.Z.mod_pull_div by lia. + autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small. + reflexivity. + Qed. + Lemma Mt_correct : + 0 < machine_wordsize -> + 2^(machine_wordsize - 1) < M < 2^machine_wordsize -> + Partition.partition (UniformWeight.uweight machine_wordsize) 1 M = [M]. + Proof. + intros; cbn. assert (0 < 2^(machine_wordsize-1)) by ZeroBounds.Z.zero_bounds. + rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia. + autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small. + reflexivity. + Qed. + (** Note: If you change the name or type signature of this function, you will need to update the code in CLI.v *) Definition check_args {T} (res : Pipeline.ErrorT T) @@ -78,19 +102,11 @@ Section rbarrett_red. (fun '(b, e) k => if b:bool then Error e else k) res [ - ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k); - ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0); - ((negb (0 <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 0" 0 M); - (negb (muLow + 2 ^ k =? 2 ^ (2 * k) / M)%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^k ≠ 2 ^ (2 * k) / M" (muLow + 2^k) (2 ^ (2 * k) / M)); - ((negb (0 <=? muLow))%Z, Pipeline.Value_not_leZ "muLow < 0" 0 muLow); - ((negb (muLow <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ muLow" muLow (2^k)); - ((negb (2 ^ (k-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(k-1)" (2^(k-1)) M); - ((negb (M <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ M" M (2^k)); - (negb ((2 * (2 ^ (2 * k) mod M) <=? 2 ^ (k + 1) - (muLow + 2 ^ k)))%Z, Pipeline.Value_not_leZ ("(2 * (2 ^ (2 * k) mod M) 2 ^ (k + 1) - (muLow + 2 ^ k)") (2 * (2 ^ (2 * k) mod M)) (2 ^ (k + 1) - (muLow + 2 ^ k))); - (negb (Z.of_nat n <=? k)%Z, Pipeline.Value_not_leZ "k < n" (Z.of_nat n) k); - (negb (k =? machine_wordsize)%Z, Pipeline.Values_not_provably_equalZ "k ≠ machine_wordsize" k machine_wordsize); - (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2); - (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2)]. + ((negb (1 <? machine_wordsize))%Z, Pipeline.Value_not_ltZ "machine_wordsize ≤ 1" 1 machine_wordsize); + ((negb (2 ^ (machine_wordsize-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(machine_wordsize-1)" (2^(machine_wordsize-1)) M); + ((negb (M <? 2 ^ machine_wordsize))%Z, Pipeline.Value_not_ltZ "2 ^ machine_wordsize ≤ M" M (2^machine_wordsize)); + ((negb (muLow + 2 ^ machine_wordsize =? ((2 ^ 2) ^ machine_wordsize) / M))%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^machine_wordsize ≠ (2 ^ 2) ^ machine_wordsize) / M" (muLow + 2^machine_wordsize) (((2 ^ 2) ^ machine_wordsize) / M)); + (negb ((2 * (((2 ^ 2) ^ machine_wordsize) mod M) <=? 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)))%Z, Pipeline.Value_not_leZ ("(2 * ((2 ^ 2) ^ machine_wordsize) mod M) 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)") (2 * (((2 ^ 2) ^ machine_wordsize) mod M)) (2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize))) ]. Local Arguments Z.mul !_ !_. Local Ltac use_curve_good_t := @@ -107,24 +123,17 @@ Section rbarrett_red. Context (curve_good : check_args (Success tt) = Success tt). Lemma use_curve_good - : 2 <= k - /\ 0 < M - /\ muLow + 2 ^ k = 2 ^ (2 * k) / M - /\ 0 <= muLow < 2 ^ k - /\ 2 ^ (k - 1) < M < 2 ^ k - /\ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2 ^ k) - /\ n <> 0%nat - /\ Z.of_nat n <= k - /\ k = machine_wordsize - /\ n = 2%nat - /\ nout = 2%nat. + : 1 < machine_wordsize + /\ 2 ^ (machine_wordsize - 1) <= M < 2 ^ machine_wordsize + /\ muLow + 2 ^ machine_wordsize = (2 ^ 2) ^ machine_wordsize / M + /\ 2 ^ (machine_wordsize - 1) < M < 2 ^ machine_wordsize + /\ 2 * ((2 ^ 2) ^ machine_wordsize mod M) <= 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize). Proof using curve_good. clear -curve_good. cbv [check_args fold_right] in curve_good. break_innermost_match_hyps; try discriminate. rewrite Bool.negb_false_iff in *. Z.ltb_to_lt. - rewrite NPeano.Nat.eqb_neq in *. intros. repeat apply conj. { use_curve_good_t. } @@ -134,12 +143,6 @@ Section rbarrett_red. { use_curve_good_t. } { use_curve_good_t. } { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } - { use_curve_good_t. } Qed. Definition barrett_red @@ -148,7 +151,12 @@ Section rbarrett_red. fancy_args (* fancy *) possible_values (reified_barrett_red_gen - @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify M @ GallinaReify.Reify muLow @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat) + @ GallinaReify.Reify M + @ GallinaReify.Reify machine_wordsize + @ GallinaReify.Reify machine_wordsize + @ GallinaReify.Reify 1%nat + @ GallinaReify.Reify [muLow;1] + @ GallinaReify.Reify [M]) (bound, (bound, tt)) bound. @@ -162,24 +170,27 @@ Section rbarrett_red. Local Ltac solve_barrett_red_preconditions := repeat first [ lia | assumption + | match goal with |- ?x = ?x => reflexivity end | apply use_curve_good | progress autorewrite with zsimplify | progress intros | progress cbv [weight] + | rewrite mut_correct + | rewrite Mt_correct | rewrite Z.pow_mul_r by lia ]. Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *) Lemma barrett_red_correct res (Hres : barrett_red = Success res) - : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res). - Proof using k M curve_good. + : barrett_red_correct machine_wordsize M (expr.Interp (@ident.gen_interp cast_oor) res). + Proof using M curve_good. cbv [barrett_red_correct]; intros. - assert (2 <= k) by apply use_curve_good. - rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions. + assert (1 < machine_wordsize) by apply use_curve_good. + rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions. prove_correctness' ltac:(fun _ => idtac) use_curve_good. - { congruence. } - { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. - subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } - { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. - subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. } + { cbn. econstructor. } + { cbn. econstructor. } + { cbn. econstructor. } Qed. End rbarrett_red. |