aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jadep@mit.edu>2019-03-12 12:46:00 -0400
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2019-03-25 06:13:45 -0400
commit684d356bcb81ca36314cd7864c62a1d97af4ea99 (patch)
tree26df2a46df43c15d512ced251b647b2d01115916 /src
parent3e4edb9a9b8cc15bdc02b9005e0b94561645b77b (diff)
finish proofs
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic.v456
-rw-r--r--src/PushButtonSynthesis/BarrettReduction.v2
2 files changed, 272 insertions, 186 deletions
diff --git a/src/Arithmetic.v b/src/Arithmetic.v
index 65eff76a4..75e93deaf 100644
--- a/src/Arithmetic.v
+++ b/src/Arithmetic.v
@@ -1889,6 +1889,14 @@ Module Partition.
reflexivity.
Qed.
+ Lemma partition_0 n : partition n 0 = Positional.zeros n.
+ Proof.
+ cbv [partition].
+ erewrite Positional.zeros_ext_map with (p:=seq 0 n) by distr_length.
+ apply map_ext; intros.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+
End Partition.
Hint Rewrite length_partition length_recursive_partition : distr_length.
Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval.
@@ -3035,9 +3043,6 @@ Module BaseConversion.
Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg.
-
- (* TODO: Unspecialize from 2-limb *)
- (* multi-limb version -- multiply two numbers that each have m limbs with (n*k) bits each, by converting them to numbers with (m*n) limbs of k bits each, multiplying, then converting back *)
Definition widemul a b := mul_converted sw dw m m mn mn nout (aligned_carries n nout) a b.
Lemma widemul_correct a b :
@@ -3556,6 +3561,19 @@ Module UniformWeight.
auto using uweight_recursive_partition_change_start with omega.
Qed.
+ Lemma uweight_firstn_partition lgr (Hr : 0 < lgr) n x m (Hm : (m <= n)%nat) :
+ firstn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) m x.
+ Proof.
+ cbv [Partition.partition];
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress autorewrite with push_firstn natsimplify zsimplify_fast
+ | _ => rewrite Nat.min_l by lia
+ | _ => rewrite weight_0 by auto using uwprops
+ | _ => reflexivity
+ end.
+ Qed.
+
Lemma uweight_skipn_partition lgr (Hr : 0 < lgr) n x m :
skipn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) (n - m) (x / uweight lgr m).
Proof.
@@ -3623,6 +3641,21 @@ Module UniformWeight.
m = (n + length y)%nat ->
Positional.eval (uweight lgr) m (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y.
Proof using Type. intros. subst m. apply uweight_eval_app'; lia. Qed.
+
+ Lemma uweight_partition_app lgr (Hr : 0 < lgr) n m a b :
+ Partition.partition (uweight lgr) n a ++ Partition.partition (uweight lgr) m b
+ = Partition.partition (uweight lgr) (n+m) (a mod uweight lgr n + b * uweight lgr n).
+ Proof.
+ assert (0 < uweight lgr n) by auto using uwprops.
+ match goal with |- _ = ?rhs => rewrite <-(firstn_skipn n rhs) end.
+ rewrite uweight_firstn_partition, uweight_skipn_partition by lia.
+ rewrite Z.div_add by lia.
+ rewrite (Z.div_small (_ mod _)) by auto with zarith.
+ f_equal.
+ { apply Partition.partition_eq_mod; [ auto using uwprops | ].
+ push_Zmod. autorewrite with zsimplify. reflexivity. }
+ { f_equal; lia. }
+ Qed.
End UniformWeight.
Module WordByWordMontgomery.
@@ -4885,8 +4918,10 @@ Module BarrettReduction.
Context (b k M mu width : Z) (n : nat)
(b_ok : 1 < b)
(k_pos : 0 < k)
+ (bk_eq : b^k = 2^(width * Z.of_nat n))
(M_range : b ^ (k - 1) < M < b ^ k)
(mu_eq : mu = b ^ (2 * k) / M)
+ (width_pos : 0 < width)
(strong_bound : b ^ 1 * (b ^ (2 * k) mod M) <= b ^ (k + 1) - mu).
Local Notation weight := (UniformWeight.uweight width).
Local Notation partition := (Partition.partition weight).
@@ -4903,23 +4938,19 @@ Module BarrettReduction.
q3 (partition (n*2) x) (partition (n+1) q1) = partition (n+1) ((mu * q1) / b ^ (k + 1)))
(r : list Z -> list Z -> list Z)
(r_correct :
- forall x q1 q3,
- 0 <= x < b ^ (2 * k) ->
- q1 = x / b ^ (k - 1) ->
- q3 = (mu * q1) / b ^ (k + 1) ->
- r (partition (n*2) x) (partition (n+1) q3) = partition (n*2) (x - q3 * M))
- (final_reduce : list Z -> list Z)
- (final_reduce_correct :
- forall r,
- 0 <= r < 2 * M ->
- final_reduce (partition (n*2) r) = partition n (if dec (r < M) then r else r - M)).
+ forall x q3,
+ 0 <= x < M * b ^ k ->
+ 0 <= q3 ->
+ (exists b : bool, q3 = x / M + (if b then -1 else 0)) ->
+ r (partition (n*2) x) (partition (n+1) q3) = partition n (x mod M)).
Context (x : Z) (x_range : 0 <= x < M * b ^ k)
(xt : list Z) (xt_correct : xt = partition (n*2) x).
- (* TODO: make barrett Z figure out some of this itself *)
Local Lemma M_pos : 0 < M.
Proof. assert (0 <= b ^ (k - 1)) by Z.zero_bounds. lia. Qed.
+ Local Lemma M_upper : M < weight n.
+ Proof. rewrite UniformWeight.uweight_eq_alt'. lia. Qed.
Local Lemma x_upper : x < b ^ (2 * k).
Proof.
assert (0 < b ^ k) by Z.zero_bounds.
@@ -4934,22 +4965,34 @@ Module BarrettReduction.
Definition reduce :=
dlet_nd q1t := q1 xt in
- dlet_nd q3t := q3 xt q1t in
- dlet_nd rt := r xt q3t in
- final_reduce rt.
+ dlet_nd q3t := q3 xt q1t in
+ r xt q3t.
+
+ Lemma q1_range : 0 <= x / b^(k-1) < b^(k+1).
+ Proof.
+ split; [ solve [Z.zero_bounds] | ].
+ assert (0 < b ^ (k - 1)) by Z.zero_bounds.
+ assert (0 < b ^ k) by Z.zero_bounds.
+ apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ].
+ eapply Z.lt_le_trans with (m:=b^k * b^k);
+ [ nia | autorewrite with pull_Zpow; apply Z.pow_le_mono; lia ].
+ Qed.
+
+ Lemma q3_range : 0 <= mu * (x / b ^ (k - 1)) / b ^ (k + 1).
+ Proof.
+ assert (0 < b ^ (k - 1)) by Z.zero_bounds.
+ subst mu; Z.zero_bounds.
+ Qed.
Lemma reduce_correct : reduce = partition n (x mod M).
Proof.
- cbv [reduce Let_In].
+ cbv [reduce Let_In]. pose proof q3_range.
rewrite xt_correct, q1_correct, q3_correct by auto with lia.
- erewrite r_correct by eauto with lia. rewrite final_reduce_correct; [ | ].
- { rewrite barrett_reduction_small_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) by
- auto using Z.lt_gt with zarith.
- break_innermost_match; Z.ltb_to_lt; (lia || reflexivity). }
- { split; [ Z.zero_bounds; apply qn_small | apply r_small_strong];
- auto using Z.lt_gt with zarith. }
+ assert (exists cond : bool, ((mu * (x / b^(k-1))) / b^(k+1)) = x / M + (if cond then -1 else 0)) as Hq3.
+ { destruct q_nice_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) (a:=x) (n:=M) as [cond Hcond];
+ eauto using Z.lt_gt with zarith. }
+ eauto using r_correct with lia.
Qed.
-
End Generic.
(* Non-standard implementation -- uses specialized instructions and b=2 *)
@@ -4963,15 +5006,32 @@ Module BarrettReduction.
(* sz = 1, width = k = 256 *)
Local Notation w := (UniformWeight.uweight width). Local Notation eval := (Positional.eval w).
Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M).
- Context (mu_eq : mu = 2 ^ (2 * k) / M) (M_range : 2^(k-1) < M < 2^k).
+ Context (mu_eq : mu = 2 ^ (2 * k) / M) (muHigh_one : mu / w sz = 1) (M_range : 2^(k-1) < M < 2^k).
Local Lemma wprops : @weight_properties w. Proof. apply UniformWeight.uwprops; auto with lia. Qed.
Local Hint Resolve wprops.
Hint Rewrite mut_correct Mt_correct : pull_partition.
- Lemma w_eq_2k : w sz = 2^k. Admitted. (* TODO *)
+ Lemma w_eq_2k : w sz = 2^k. Proof. rewrite UniformWeight.uweight_eq_alt' by auto. congruence. Qed.
Lemma mu_range : 2^k <= mu < 2^(k+1).
- Admitted. (* TODO *)
+ Proof.
+ rewrite mu_eq. assert (0 < 2^(k-1)) by Z.zero_bounds.
+ assert (2^k < M * 2).
+ { replace (2^k) with (2^(k-1+1)) by (f_equal; lia).
+ rewrite Z.pow_add_r, Z.pow_1_r by lia.
+ lia. }
+ replace (2 ^ (2 * k)) with (2^(k+k)) by (f_equal; lia).
+ rewrite !Z.pow_add_r, Z.pow_1_r by lia. split.
+ { apply Z.div_le_lower_bound; nia. }
+ { apply Z.div_lt_upper_bound; nia. }
+ Qed.
+ Lemma mu_range' : 0 <= mu < 2 * w sz.
+ Proof.
+ pose proof mu_range. assert (0 < 2^k) by auto with zarith.
+ assert (2^(k+1) = 2 * w sz); [ | lia].
+ rewrite k_eq, UniformWeight.uweight_eq_alt'.
+ rewrite Z.pow_add_r, Z.pow_1_r by lia. lia.
+ Qed.
Lemma M_range' : 0 <= M < w sz. (* more convenient form, especially for mod_small *)
Proof. assert (0 <= 2 ^ (k-1)) by Z.zero_bounds. pose proof w_eq_2k; lia. Qed.
@@ -5000,13 +5060,13 @@ Module BarrettReduction.
(* select based on the most significant bit of xHigh *)
Definition muSelect xt :=
let xHigh := nth_default 0 xt (sz*2 - 1) in
- Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut).
+ Positional.select (Z.cc_m (2 ^ width) xHigh) (Positional.zeros sz) (low mut).
Definition cond_sub (a y : list Z) : list Z :=
let cond := Z.cc_l (nth_default 0 (high a) 0) in (* a[k] = least significant bit of (high a) *)
- dlet_nd maybe_y := Positional.select cond (repeat 0 sz) y in
- dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *)
- fst diff.
+ dlet_nd maybe_y := Positional.select cond (Positional.zeros sz) y in
+ dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *)
+ fst diff.
Definition cond_subM x :=
if Nat.eq_dec sz 1
@@ -5017,16 +5077,14 @@ Module BarrettReduction.
Definition q3 (xt q1t : list Z) :=
dlet_nd muSelect := muSelect xt in (* make sure muSelect is not inlined in the output *)
- dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in
- shiftr (sz+1) twoq 1.
+ dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in
+ shiftr (sz+1) twoq 1.
Definition r (xt q3t : list Z) :=
dlet_nd r2 := widemul (low q3t) Mt in
- widesub xt r2.
-
- Definition final_reduce (rt : list Z) :=
+ dlet_nd rt := widesub xt r2 in
dlet_nd rt := cond_sub rt Mt in
- cond_subM rt.
+ cond_subM rt.
Section Proofs.
@@ -5120,20 +5178,20 @@ Module BarrettReduction.
Qed.
Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a.
- Admitted.
+ Proof. cbv [low]; auto using UniformWeight.uweight_firstn_partition with lia. Qed.
Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz).
- Admitted.
+ Proof. cbv [high]. rewrite UniformWeight.uweight_skipn_partition by lia. f_equal; lia. Qed.
Lemma fill_correct n m a :
(n <= m)%nat ->
fill m (partition w n a) = partition w m (a mod w n).
- Admitted.
+ Proof.
+ cbv [fill]; intros. distr_length.
+ rewrite <-partition_0 with (weight:=w).
+ rewrite UniformWeight.uweight_partition_app by lia.
+ f_equal; lia.
+ Qed.
Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition.
- (* TODO: move *)
- Lemma partition_app n m a b :
- partition w n a ++ partition w m b = partition w (n+m) (a mod w n + b * w n).
- Admitted.
-
Lemma wideadd_correct a b :
wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b).
Proof.
@@ -5192,7 +5250,7 @@ Module BarrettReduction.
repeat match goal with
| _ => progress autorewrite with pull_partition
| _ => progress rewrite ?Ha, ?Ha0b1
- | _ => rewrite partition_app by auto;
+ | _ => rewrite UniformWeight.uweight_partition_app by lia;
replace (sz+sz)%nat with (sz*2)%nat by lia
| _ => rewrite Z.mod_pull_div by auto with zarith
| _ => progress Z.rewrite_mod_small
@@ -5200,12 +5258,6 @@ Module BarrettReduction.
end.
Qed.
- Lemma muSelect_correct x :
- 0 <= x < w (sz * 2) ->
- muSelect (partition w (sz*2) x) = partition w (sz+1) (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))).
- Admitted.
- Hint Rewrite muSelect_correct using lia : pull_partition.
-
(* TODO: move *)
Lemma pow_pos_le a b : 0 < a -> 0 < b -> a <= a ^ b.
Proof.
@@ -5215,47 +5267,102 @@ Module BarrettReduction.
Qed.
Hint Resolve pow_pos_le : zarith.
+ (* TODO: move *)
+ Lemma pow_pos_lt a b : 1 < a -> 1 < b -> a < a ^ b.
+ Proof.
+ intros; eapply Z.le_lt_trans with (m:=a ^ 1).
+ { rewrite Z.pow_1_r; reflexivity. }
+ { apply Z.pow_lt_mono_r; auto with zarith. }
+ Qed.
+ Hint Resolve pow_pos_lt : zarith.
+
+ (* TODO: move *)
+ Lemma pow_div_base a b : a <> 0 -> 0 < b -> a ^ b / a = a ^ (b - 1).
+ Proof. intros; rewrite Z.pow_sub_r, Z.pow_1_r; lia. Qed.
+ Hint Rewrite pow_div_base using zutil_arith : pull_Zpow.
+
+ (* TODO: move *)
+ Lemma pow_mul_base a b : 0 <= b -> a * a ^ b = a ^ (b + 1).
+ Proof. intros; rewrite <-Z.pow_succ_r, <-Z.add_1_r by lia; reflexivity. Qed.
+ Hint Rewrite pow_mul_base using zutil_arith : pull_Zpow.
+
+
+ (* improved! *)
+ Ltac zero_bounds' :=
+ repeat match goal with
+ | |- ?a <> 0 => apply Z.positive_is_nonzero
+ | |- ?a > 0 => apply Z.lt_gt
+ | |- ?a >= 0 => apply Z.le_ge
+ end;
+ try match goal with
+ | |- 0 < ?a => Z.zero_bounds
+ | |- 0 <= ?a => Z.zero_bounds
+ end.
+
+ Ltac zutil_arith ::= solve [ omega | Psatz.lia | auto with nocore | solve [zero_bounds'] ].
+
+ Hint Rewrite UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : weight_to_pow.
+ Hint Rewrite <-UniformWeight.uweight_S UniformWeight.uweight_eq_alt' using lia : pow_to_weight.
+
+ Lemma q1_range x :
+ 0 <= x < w (sz * 2) ->
+ 0 <= x / 2 ^ (k-1) < 2 * w sz.
+ Proof.
+ intros; split; [ solve [Z.zero_bounds] | ].
+ apply Z.div_lt_upper_bound; [ solve [Z.zero_bounds] | ].
+ assert (w (sz * 2) <= 2 ^ (k-1) * (2 * w sz)); [ | lia ].
+ autorewrite with weight_to_pow pull_Zpow.
+ apply Z.pow_le_mono_r; lia.
+ Qed.
+
+ Lemma muSelect_correct x :
+ 0 <= x < w (sz * 2) ->
+ muSelect (partition w (sz*2) x) = partition w sz (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))).
+ Proof.
+ cbv [muSelect]; intros;
+ repeat match goal with
+ | _ => progress autorewrite with pull_partition natsimplify
+ | _ => progress rewrite ?Z.cc_m_eq by auto with zarith
+ | _ => erewrite Positional.select_eq by (distr_length; eauto)
+ | _ => rewrite nth_default_partition by lia
+ | _ => progress replace (S (sz * 2 - 1)) with (sz * 2)%nat by lia
+ | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by auto with zarith
+ end.
+ replace (x / (w (sz * 2 - 1)) / (2 ^ width / 2)) with (x / (2 ^ (k - 1)) / w sz) by
+ (autorewrite with weight_to_pow pull_Zpow pull_Zdiv; do 2 f_equal; nia).
+
+ rewrite Z.div_between_0_if with (a:=x / 2^(k-1)) by (zero_bounds'; auto using q1_range).
+ break_innermost_match; try lia; autorewrite with zsimplify_fast; [ | ].
+ { apply partition_eq_mod; auto with zarith. }
+ { rewrite partition_0; reflexivity. }
+ Qed.
+ Hint Rewrite muSelect_correct using lia : pull_partition.
+
+ Lemma mu_q1_range x (Hx : 0 <= x < w (sz * 2)) : mu * (x / 2^(k-1)) < w sz * w (sz * 2).
+ Proof.
+ pose proof mu_range'. pose proof q1_range x ltac:(lia).
+ replace (w (sz * 2)) with (w sz * w sz) by
+ (autorewrite with weight_to_pow pull_Zpow; f_equal; lia).
+ apply Z.lt_le_trans with (m:= 2 * w sz * (2 * w sz)); [ nia | ].
+ assert (4 <= w sz); [ | nia ]. change 4 with (Z.pow 2 2).
+ autorewrite with weight_to_pow. apply Z.pow_le_mono_r; nia.
+ Qed.
+
Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) :
q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)).
Proof.
- cbv [q3 Let_In]. intros.
- assert (0 <= q1 < w (sz+1)) by admit.
- autorewrite with pull_partition.
- rewrite ?fill_correct, ?muSelect_correct by lia.
- rewrite <-Hq1.
- rewrite mul_high_correct; [ | | ].
- 2: { (* mu mod w (sz + 1) / w sz = 1 *)
- (* highest bit of mu is 1 -- do this by rewriting mod small using bounds on mu, then take axiom for highest bit *)
- admit. }
- 2: {
- rewrite mod_mod_weight by lia.
- push_Zmod. rewrite Z.mod_pull_div by auto with zarith.
- assert (0 < w sz) by auto with zarith.
- assert (w (sz+1) <= w (sz+1) * w sz) by (apply Z.le_mul_diag_r; auto with zarith).
- Z.rewrite_mod_small.
- autorewrite with natsimplify in *.
- rewrite UniformWeight.uweight_S in * by auto with zarith.
- assert (0 <= q1 / w sz < 2^width).
- { split; [ solve [Z.zero_bounds] | ].
- apply Z.div_lt_upper_bound; auto with nia. }
- pose proof (Z.mod_pos_bound mu (w sz) ltac:(auto)).
- apply Z.mod_small. nia. }
- rewrite shiftr_correct; auto with zarith.
- 2: admit. (* (Z.to_nat (1 / width) <= sz * 2)%nat *)
- 2: admit. (* (sz + 1 <= sz * 2 - Z.to_nat (1 / width))%nat *)
- 2: admit. (* 0 <= mu mod w (sz + 1) * (q1 mod w (sz + 1)) / w sz < w (sz * 2) *)
- (* TODO: see if things are easier using div_mod_to_quot_rem *)
- Z.rewrite_mod_small.
- pose proof mu_range.
- replace (2 ^ (k+1)) with (w sz * 2) in * by admit.
- assert (0 < 2 ^ k) by (Z.zero_bounds; nia).
- assert (2 * w sz <= w (sz + 1)).
- { autorewrite with natsimplify. rewrite UniformWeight.uweight_S by auto with zarith.
- apply Z.mul_le_mono_nonneg_r; auto with zarith. }
- Z.rewrite_mod_small.
- rewrite Z.div_div by auto with zarith.
- repeat (f_equal; try ring).
- Admitted.
+ cbv [q3 Let_In]. intros. pose proof mu_q1_range x ltac:(lia).
+ pose proof mu_range'. pose proof q1_range x ltac:(lia).
+ autorewrite with pull_partition pull_Zmod.
+ assert (2 * w sz < w (sz + 1)) by (autorewrite with weight_to_pow pull_Zpow; auto with zarith lia).
+ Z.rewrite_mod_small. rewrite <-Hq1 in *.
+ rewrite mul_high_correct by
+ (try lia; rewrite Z.div_between_0_if with (a:=q1) by lia;
+ break_innermost_match; autorewrite with zsimplify; reflexivity).
+ rewrite shiftr_correct by (rewrite ?Z.div_small, ?Z2Nat.inj_0 by lia; auto with zarith lia).
+ autorewrite with weight_to_pow pull_Zpow pull_Zdiv.
+ congruence.
+ Qed.
Lemma cond_sub_correct a b :
cond_sub (partition w (sz*2) a) (partition w sz b)
@@ -5263,7 +5370,7 @@ Module BarrettReduction.
then a
else a - b).
Proof.
- cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition.
+ intros; cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition.
rewrite nth_default_partition by lia.
rewrite weight_0 by auto. autorewrite with zsimplify_fast.
rewrite UniformWeight.uweight_eq_alt' with (n:=1%nat). autorewrite with push_Zof_nat zsimplify.
@@ -5280,16 +5387,26 @@ Module BarrettReduction.
then a
else a - M).
Proof.
- Admitted.
+ cbv [cond_subM]. autorewrite with pull_partition. pose proof M_range'.
+ rewrite Rows.conditional_sub_partitions by
+ (distr_length; auto; autorewrite with push_eval; try apply partition_eq_mod; auto with zarith).
+ rewrite nth_default_partition, weight_0, Z.add_modulo_correct by auto with lia.
+ autorewrite with zsimplify_fast push_eval. Z.rewrite_mod_small.
+ pose proof Z.mod_pos_bound a (w 1) ltac:(auto).
+ break_innermost_match; Z.ltb_to_lt;
+ repeat match goal with
+ | _ => lia
+ | _ => reflexivity
+ | _ => apply partition_eq_mod; solve [auto with zarith]
+ | _ => rewrite partition_step, weight_0 by auto
+ | _ => progress autorewrite with zsimplify_fast
+ | _ => progress Z.rewrite_mod_small
+ | _ => rewrite Z.sub_mod_l with (a:=a)
+ end.
+ Qed.
Hint Rewrite cond_subM_correct : pull_partition.
- (* TODO: move this with barrett Z lemmas & prove *)
- Lemma q3_range x :
- 0 <= x < 2 ^ (2 * k) ->
- 0 <= (mu * (x / 2 ^ (k-1))) / 2 ^ (k+1) < 2^k.
- Proof using Type.
- Admitted.
-
+ (* TODO: unused?*)
Lemma w_eq_22k : w (sz * 2) = 2 ^ (2 * k).
Proof.
replace (sz * 2)%nat with (sz + sz)%nat by lia.
@@ -5297,48 +5414,66 @@ Module BarrettReduction.
f_equal; lia.
Qed.
- Lemma r_correct x q1 q3 :
- 0 <= x < w (sz * 2) ->
- q1 = x / 2 ^ (k - 1) ->
- q3 = (mu*q1) / 2 ^ (k + 1) ->
- r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w (sz*2) (x - q3 * M).
+ (* TODO: unused?*)
+ Lemma q3M_eq x q3 (b:bool) :
+ 0 <= x < M * 2 ^ k ->
+ q3 = x / M + (if b then -1 else 0) ->
+ q3 * M = x - x mod M - (if b then M else 0).
Proof.
- cbv [r Let_In]; intros. pose proof w_eq_22k.
- autorewrite with pull_partition.
- assert (0 <= q3 < w sz).
- { subst q3 q1. rewrite w_eq_2k. apply q3_range; lia. }
- pose proof M_range'. Z.rewrite_mod_small.
- autorewrite with zsimplify. reflexivity.
+ intros. assert (0 < 2^(k-1)) by Z.zero_bounds. rewrite Z.mod_eq by lia.
+ subst q3. break_innermost_match; ring.
Qed.
- Lemma final_reduce_correct r:
- 0 <= r < 2 * M ->
- final_reduce (partition w (sz*2) r) = partition w sz (if dec (r < M) then r else r - M).
+ (* TODO: move *)
+ Hint Resolve Z.positive_is_nonzero Z.lt_gt : zarith.
+
+ Lemma r_idea x q3 (b:bool) :
+ 0 <= x < M * 2 ^ k ->
+ 0 <= q3 ->
+ q3 = x / M + (if b then -1 else 0) ->
+ x - q3 mod w sz * M = x mod M + (if b then M else 0).
Proof.
- intros; pose proof M_range'.
- cbv [final_reduce Let_In]; intros; autorewrite with pull_partition.
- apply partition_eq_mod; auto with zarith.
- replace (2 ^ (k + 1)) with (2 * w sz) in *
- by (rewrite Z.pow_add_r by auto with zarith; autorewrite with zsimplify; rewrite <-w_eq_2k; lia).
- pose proof (Z.mod_pos_bound r (w sz) ltac:(auto with zarith)).
- remember ((r / w sz) mod 2) as rk.
- assert (rk = 1 -> r > M) by (subst rk; Z.rewrite_mod_small; intros;
- rewrite (Z.div_mod r (w sz)) by auto with zarith; nia).
- replace r with (w sz * rk + r mod w sz) by
- (subst rk; Z.rewrite_mod_small; rewrite <-Z.div_mod; auto with zarith).
- subst rk; rewrite Zmod_odd in *.
+ intros. assert (0 < 2^(k-1)) by Z.zero_bounds.
+ assert (q3 < w sz).
+ { apply Z.le_lt_trans with (m:=x/M); [ subst q3; break_innermost_match; lia | ].
+ autorewrite with weight_to_pow. rewrite <-k_eq. auto with zarith. }
+ Z.rewrite_mod_small.
repeat match goal with
- | _ => lia
- | H : _ |- _ => rewrite Z.mod_small in H by auto with zarith lia
- | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast
- | _ => break_innermost_match_step
+ | _ => progress autorewrite with push_Zmul
+ | H : q3 = ?e |- _ => progress replace (q3 * M) with (e * M) by (rewrite H; reflexivity)
+ | _ => rewrite (Z.mul_div_eq' x M) by lia
end.
+ break_innermost_match; Z.ltb_to_lt; lia.
+ Qed.
+
+ Lemma r_correct x q3 :
+ 0 <= x < M * 2 ^ k ->
+ 0 <= q3 ->
+ (exists b : bool, q3 = x / M + (if b then -1 else 0)) ->
+ r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w sz (x mod M).
+ Proof.
+ intros; cbv [r Let_In]. pose proof M_range'. assert (0 < 2^(k-1)) by Z.zero_bounds.
+ autorewrite with pull_partition. Z.rewrite_mod_small.
+ match goal with H : exists _, q3 = _ |- _ => destruct H end.
+ erewrite r_idea by eassumption.
+ pose proof (Z.mod_pos_bound x M ltac:(lia)).
+ rewrite Z.div_between_0_if with (b:=w sz) by (break_innermost_match; auto with zarith).
+ rewrite Z.mod_small with (b:=2) by (break_innermost_match; lia).
+ break_innermost_match; Z.ltb_to_lt; try lia; autorewrite with zsimplify_fast;
+ repeat match goal with
+ | |- exists e, _ /\ _ /\ ?f ?x = ?f e => exists x; split; [ | split ]
+ | _ => rewrite Z.mod_small in * by lia
+ | _ => progress Z.rewrite_mod_small
+ | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast
+ | _ => lia
+ | _ => reflexivity
+ end.
Qed.
End Proofs.
Section Def.
Context (sz_eq_1 : sz = 1%nat). (* this is needed to get rid of branches in the templates; a different definition would be needed for sizes other than 1, but would be able to use the same proofs. *)
- Local Hint Resolve q1_correct q3_correct r_correct final_reduce_correct.
+ Local Hint Resolve q1_correct q3_correct r_correct.
(* muselect relies on an initially-set flag, so pull it out of q3 *)
Definition fancy_reduce_muSelect_first xt :=
@@ -5346,8 +5481,7 @@ Module BarrettReduction.
dlet_nd q1t := q1 xt in
dlet_nd twoq := mul_high (fill (sz * 2) mut) (fill (sz * 2) q1t) (fill (sz * 2) muSelect) in
dlet_nd q3t := shiftr (sz+1) twoq 1 in
- dlet_nd rt := r xt q3t in
- final_reduce rt.
+ r xt q3t.
Lemma fancy_reduce_muSelect_first_correct x :
0 <= x < M * 2^k ->
@@ -5356,38 +5490,9 @@ Module BarrettReduction.
Proof.
intros. pose proof w_eq_22k.
erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by
- (eauto with nia; intros; try rewrite q3'_correct; eauto with nia ).
+ (eauto with nia; intros; try rewrite q3'_correct; try rewrite <-k_eq; eauto with nia ).
reflexivity.
Qed.
-
-
- (*
- Definition fancy_reduce xLow xHigh :=
- dlet_nd muSelect : list Z := Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut) in
- dlet_nd q1t := [Z.rshi (2 ^ width) xHigh xLow (k - 1); Z.rshi (2 ^ width) 0 xHigh (k - 1)] in
- dlet_nd a0b0 : list Z := widemul (low mut) (low q1t) in
- dlet_nd ab : list Z := wideadd [high a0b0; high q1t] [low q1t; 0] in
- dlet_nd a0 := wideadd ab [muSelect;0] in
- dlet_nd q3t := [Z.rshi (2 ^ width) (nth_default 0 a0 1) (nth_default 0 a0 0) 1; Z.rshi (2 ^ width) (nth_default 0 a0 2) (nth_default 0 a0 1) 1] in
- dlet_nd r2 : list Z := widemul (low q3t) Mt in
- dlet_nd rt : list Z := widesub [xLow; xHigh] r2 in
- dlet_nd maybe_M : list Z := Positional.select (Z.cc_l (nth_default 0 (high rt) 0)) (repeat 0 sz) Mt in
- dlet_nd diff : list Z * Z := Rows.sub w sz (low rt) maybe_M in
- dlet_nd rt0 :=fst diff in
- Z.add_modulo (nth_default 0 rt0 0) 0 M.
-*)
- (* This simplifies the shiftr inside q3, specializing it to the # limbs so we don't end up with maps in the final expression *)
- Derive q3'
- SuchThat (forall xt q1t, q3' xt q1t = q3 xt q1t)
- As q3'_correct.
- Proof.
- intros. cbv [q3]. rewrite sz_eq_1. autorewrite with natsimplify.
- cbv [shiftr shiftr' seq]. break_match; try lia; [ ].
- rewrite Proper_Let_In_nd_changebody; [ | reflexivity | repeat intro ].
- 2 : apply Proper_Let_In_nd_changebody;
- [ reflexivity | repeat intro; autorewrite with push_map push_nth_default; reflexivity ].
- subst q3'; reflexivity.
- Qed.
Derive fancy_reduce'
SuchThat (
@@ -5399,7 +5504,7 @@ Module BarrettReduction.
Proof.
intros. assert (k = width) as width_eq_k by nia.
erewrite <-fancy_reduce_muSelect_first_correct by nia.
- cbv [fancy_reduce_muSelect_first q1 q3 shiftr final_reduce cond_subM].
+ cbv [fancy_reduce_muSelect_first q1 q3 shiftr r cond_subM].
break_match; try solve [exfalso; lia].
match goal with |- ?g ?x = ?rhs =>
let f := (match (eval pattern x in rhs) with ?f _ => f end) in
@@ -5407,7 +5512,6 @@ Module BarrettReduction.
end.
Qed.
- (* TODO: Can't deal with input by shifting and stuff! *)
Derive fancy_reduce
SuchThat (
forall xLow xHigh,
@@ -5433,31 +5537,11 @@ Module BarrettReduction.
autorewrite with zsimplify.
rewrite <-Z.mod_pull_div by Z.zero_bounds.
autorewrite with zsimplify. reflexivity. }
- cbv [fancy_reduce' reduce q1 q3 shiftr final_reduce cond_subM].
+ cbv [fancy_reduce' reduce q1 q3 shiftr r cond_subM].
autorewrite with natsimplify. (* TODO: maybe need to get rid of hd 0? *)
cbv [shiftr' seq]. autorewrite with push_map push_nth_default.
subst fancy_reduce; reflexivity.
Qed.
- (* TODO: decide which version to use
- Derive fancy_reduce
- SuchThat (
- forall x,
- 0 <= x < M * 2 ^ k ->
- 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
- fancy_reduce (partition w (sz*2) x) = partition w sz (x mod M))
- As fancy_reduce_correct.
- Proof.
- intros. pose proof w_eq_22k.
- erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by eauto with lia.
- assert (width = k) by nia.
- cbv [reduce q1 q3 shiftr final_reduce cond_subM].
- break_match; try solve [exfalso; lia].
- match goal with |- ?g ?x = ?rhs =>
- let f := (match (eval pattern x in rhs) with ?f _ => f end) in
- assert (f = g); subst fancy_reduce; reflexivity
- end.
- Qed.
- *)
End Def.
End Fancy.
End Fancy.
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v
index c0078b117..504a26a0c 100644
--- a/src/PushButtonSynthesis/BarrettReduction.v
+++ b/src/PushButtonSynthesis/BarrettReduction.v
@@ -177,6 +177,7 @@ Section rbarrett_red.
| progress cbv [weight]
| rewrite mut_correct
| rewrite Mt_correct
+ | rewrite UniformWeight.uweight_eq_alt'
| rewrite Z.pow_mul_r by lia ].
Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *)
@@ -185,6 +186,7 @@ Section rbarrett_red.
Proof using M curve_good.
cbv [barrett_red_correct]; intros.
assert (1 < machine_wordsize) by apply use_curve_good.
+ pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)).
rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions.
prove_correctness' ltac:(fun _ => idtac) use_curve_good.
{ cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }