aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jadep@mit.edu>2019-03-07 15:18:42 -0500
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2019-03-25 06:13:45 -0400
commit3e4edb9a9b8cc15bdc02b9005e0b94561645b77b (patch)
tree765489dd5ce02b3baf141190b1f52e3eb2ffb1fb /src
parentbbabd295594448f12161075c5d19dd369ed04a53 (diff)
Get new Barrett proofs to generate Fancy code as before
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic.v1315
-rw-r--r--src/Arithmetic/BarrettReduction/Generalized.v9
-rw-r--r--src/Fancy/Barrett256.v21
-rw-r--r--src/PushButtonSynthesis/BarrettReduction.v97
-rw-r--r--src/PushButtonSynthesis/BarrettReductionReificationCache.v6
-rw-r--r--src/Rewriter.v20
6 files changed, 762 insertions, 706 deletions
diff --git a/src/Arithmetic.v b/src/Arithmetic.v
index 1a25532f3..65eff76a4 100644
--- a/src/Arithmetic.v
+++ b/src/Arithmetic.v
@@ -56,6 +56,8 @@ Require Import Crypto.Util.Equality.
Require Import Crypto.Util.Tactics.SetEvars.
Import Coq.Lists.List ListNotations. Local Open Scope Z_scope.
+Hint Rewrite Nat.add_1_r : natsimplify. (* TODO : put in a better location *)
+
Module Associational.
Definition eval (p:list (Z*Z)) : Z :=
fold_right (fun x y => x + y) 0%Z (map (fun t => fst t * snd t) p).
@@ -1831,6 +1833,15 @@ Module Partition.
partition n x = partition n y.
Proof. apply partition_Proper. Qed.
+ Lemma nth_default_partition d n x i :
+ (i < n)%nat ->
+ nth_default d (partition n x) i = x mod weight (S i) / weight i.
+ Proof.
+ cbv [partition]; intros.
+ rewrite map_nth_default with (x:=0%nat) by distr_length.
+ autorewrite with push_nth_default natsimplify. reflexivity.
+ Qed.
+
Fixpoint recursive_partition n i x :=
match n with
| O => []
@@ -1877,6 +1888,7 @@ Module Partition.
rewrite Nat.min_l by omega.
reflexivity.
Qed.
+
End Partition.
Hint Rewrite length_partition length_recursive_partition : distr_length.
Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval.
@@ -2498,7 +2510,7 @@ Module Rows.
Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
fold_right (fun next_row (state : list Z * Z)=>
- let out_carry := sum_rows next_row (fst state) in
+ let out_carry := sum_rows (fst state) next_row in
(fst out_carry, snd state + snd out_carry)) start_state inp.
(* In order for the output to have the right length and bounds,
@@ -2508,10 +2520,10 @@ Module Rows.
flatten' (hd default inp, 0) (hd default (tl inp) :: tl (tl inp)).
Lemma flatten'_cons state r inp :
- flatten' state (r :: inp) = (fst (sum_rows r (fst (flatten' state inp))), snd (flatten' state inp) + snd (sum_rows r (fst (flatten' state inp)))).
+ flatten' state (r :: inp) = (fst (sum_rows (fst (flatten' state inp)) r), snd (flatten' state inp) + snd (sum_rows (fst (flatten' state inp)) r)).
Proof using Type. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed.
Lemma flatten'_snoc state r inp :
- flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows r (fst state)), snd state + snd (sum_rows r (fst state))) inp.
+ flatten' state (inp ++ r :: nil) = flatten' (fst (sum_rows (fst state) r), snd state + snd (sum_rows (fst state) r)) inp.
Proof using Type. cbv [flatten']; autorewrite with list push_fold_right. reflexivity. Qed.
Lemma flatten'_nil state : flatten' state [] = state. Proof using Type. reflexivity. Qed.
Hint Rewrite flatten'_cons flatten'_snoc flatten'_nil : push_flatten.
@@ -3004,52 +3016,57 @@ Module BaseConversion.
(* multiply two (n*k)-bit numbers by converting them to n k-bit limbs each, multiplying, then converting back *)
Section widemul.
Context (log2base : Z) (log2base_pos : 0 < log2base).
- Context (n : nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base)
- (nout : nat) (nout_2 : nout = 2%nat). (* nout is always 2, but partial evaluation is overeager if it's a constant *)
+ Context (m n : nat) (m_nz : m <> 0%nat) (n_nz : n <> 0%nat) (n_le_log2base : Z.of_nat n <= log2base).
Let dw : nat -> Z := weight (log2base / Z.of_nat n) 1.
Let sw : nat -> Z := weight log2base 1.
+ Let mn := (m * n)%nat.
+ Let nout := (m * 2)%nat.
+ Local Lemma mn_nonzero : mn <> 0%nat. Proof. subst mn. apply Nat.neq_mul_0. auto. Qed.
+ Local Hint Resolve mn_nonzero.
+ Local Lemma nout_nonzero : nout <> 0%nat. Proof. subst nout. apply Nat.neq_mul_0. auto. Qed.
+ Local Hint Resolve nout_nonzero.
Local Lemma base_bounds : 0 < 1 <= log2base. Proof using log2base_pos. clear -log2base_pos; auto with zarith. Qed.
Local Lemma dbase_bounds : 0 < 1 <= log2base / Z.of_nat n. Proof using n_nz n_le_log2base. clear -n_nz n_le_log2base; auto with zarith. Qed.
Let dwprops : @weight_properties dw := wprops (log2base / Z.of_nat n) 1 dbase_bounds.
Let swprops : @weight_properties sw := wprops log2base 1 base_bounds.
+ Local Notation deval := (Positional.eval dw).
+ Local Notation seval := (Positional.eval sw).
Hint Resolve Z.gt_lt Z.positive_is_nonzero Nat2Z.is_nonneg.
- Definition widemul a b := mul_converted sw dw 1 1 n n nout (aligned_carries n nout) [a] [b].
+
+ (* TODO: Unspecialize from 2-limb *)
+ (* multi-limb version -- multiply two numbers that each have m limbs with (n*k) bits each, by converting them to numbers with (m*n) limbs of k bits each, multiplying, then converting back *)
+ Definition widemul a b := mul_converted sw dw m m mn mn nout (aligned_carries n nout) a b.
Lemma widemul_correct a b :
- 0 <= a * b < 2^log2base * 2^log2base ->
- widemul a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base].
- Proof using dwprops swprops nout_2.
- cbv [widemul]; intros.
- rewrite mul_converted_partitions by auto with zarith.
- subst nout.
- unfold sw in *; cbv [weight]; cbn.
- autorewrite with zsimplify.
- rewrite Z.pow_mul_r, Z.pow_2_r by omega.
- Z.rewrite_mod_small. reflexivity.
- Qed.
+ length a = m ->
+ length b = m ->
+ widemul a b = Partition.partition sw nout (seval m a * seval m b).
+ Proof. apply mul_converted_partitions; auto with zarith. Qed.
Derive widemul_inlined
SuchThat (forall a b,
- 0 <= a * b < 2^log2base * 2^log2base ->
- widemul_inlined a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ length a = m ->
+ length b = m ->
+ widemul_inlined a b = Partition.partition sw nout (seval m a * seval m b))
As widemul_inlined_correct.
Proof.
intros.
rewrite <-widemul_correct by auto.
cbv beta iota delta [widemul mul_converted].
- rewrite <-to_associational_inlined_correct with (p:=[a]).
- rewrite <-to_associational_inlined_correct with (p:=[b]).
+ rewrite <-to_associational_inlined_correct with (p:=a).
+ rewrite <-to_associational_inlined_correct with (p:=b).
rewrite <-from_associational_inlined_correct.
subst widemul_inlined; reflexivity.
Qed.
Derive widemul_inlined_reverse
SuchThat (forall a b,
- 0 <= a * b < 2^log2base * 2^log2base ->
- widemul_inlined_reverse a b = [(a * b) mod 2^log2base; (a * b) / 2^log2base])
+ length a = m ->
+ length b = m ->
+ widemul_inlined_reverse a b = Partition.partition sw nout (seval m a * seval m b))
As widemul_inlined_reverse_correct.
Proof.
intros.
@@ -3060,10 +3077,10 @@ Module BaseConversion.
[ | transitivity (from_associational sw dw idxs n p); [ | reflexivity ] ](* reverse to make addc chains line up *)
end.
{ subst widemul_inlined_reverse; reflexivity. }
- { rewrite from_associational_inlined_correct by (subst nout; auto).
+ { rewrite from_associational_inlined_correct by auto.
cbv [from_associational].
rewrite !Rows.flatten_correct by eauto using Rows.length_from_associational.
- rewrite !Rows.eval_from_associational by (subst nout; auto).
+ rewrite !Rows.eval_from_associational by auto.
f_equal.
rewrite !eval_carries, !Associational.bind_snd_correct, !Associational.eval_rev by auto.
reflexivity. }
@@ -3504,7 +3521,11 @@ Module UniformWeight.
{ transitivity (2 ^ 1); [ reflexivity | ].
apply Z.pow_le_mono_r; omega. }
Qed.
-
+ Lemma uweight_sum_indices lgr (Hr : 0 <= lgr) i j : uweight lgr (i + j) = uweight lgr i * uweight lgr j.
+ Proof.
+ rewrite !uweight_eq_alt by lia.
+ rewrite Nat2Z.inj_add; auto using Z.pow_add_r with zarith.
+ Qed.
Lemma uweight_1 lgr : uweight lgr 1 = 2^lgr.
Proof using Type.
cbv [uweight weight].
@@ -3535,6 +3556,20 @@ Module UniformWeight.
auto using uweight_recursive_partition_change_start with omega.
Qed.
+ Lemma uweight_skipn_partition lgr (Hr : 0 < lgr) n x m :
+ skipn m (Partition.partition (uweight lgr) n x) = Partition.partition (uweight lgr) (n - m) (x / uweight lgr m).
+ Proof.
+ cbv [Partition.partition];
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress autorewrite with push_skipn natsimplify zsimplify_fast
+ | _ => rewrite skipn_seq by auto
+ | _ => rewrite weight_0 by auto using uwprops
+ | _ => rewrite Partition.recursive_partition_equiv' by auto using uwprops
+ | _ => auto using uweight_recursive_partition_change_start with zarith
+ end.
+ Qed.
+
Lemma uweight_partition_unique lgr (Hr : 0 < lgr) n ls :
length ls = n -> (forall x, List.In x ls -> 0 <= x <= 2^lgr - 1) ->
ls = Partition.partition (uweight lgr) n (Positional.eval (uweight lgr) n ls).
@@ -3565,6 +3600,29 @@ Module UniformWeight.
| [ |- ?x :: _ = ?x :: _ ] => apply f_equal
end ].
Qed.
+
+ Lemma uweight_eval_app' lgr (Hr : 0 <= lgr) n x y :
+ n = length x ->
+ Positional.eval (uweight lgr) (n + length y) (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y.
+ Proof using Type.
+ induction y using rev_ind;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress distr_length
+ | _ => progress autorewrite with push_eval zsimplify natsimplify
+ | _ => rewrite Nat.add_succ_r
+ | H : ?x = 0%nat |- _ => subst x
+ | _ => progress rewrite ?app_nil_r, ?app_assoc
+ | _ => reflexivity
+ end.
+ rewrite IHy by auto. rewrite uweight_sum_indices; lia.
+ Qed.
+
+ Lemma uweight_eval_app lgr (Hr : 0 <= lgr) n m x y :
+ n = length x ->
+ m = (n + length y)%nat ->
+ Positional.eval (uweight lgr) m (x ++ y) = Positional.eval (uweight lgr) n x + (uweight lgr n) * Positional.eval (uweight lgr) (length y) y.
+ Proof using Type. intros. subst m. apply uweight_eval_app'; lia. Qed.
End UniformWeight.
Module WordByWordMontgomery.
@@ -4821,627 +4879,588 @@ Module WordByWordMontgomery.
End WordByWordMontgomery.
Hint Rewrite Z2Nat.inj_succ using omega : push_Zto_nat. (* TODO: put in PullPush *)
-Hint Rewrite Nat.add_1_r : natsimplify. (* TODO : put in a better location *)
Module BarrettReduction.
- Locate barrett_reduction_small.
- Check barrett_reduction_small.
- Locate Associational.
- Section Generic.
- Context (k : Z) (k_pos : 0 < k) (b : Z) (b_ok : 2 < b).
- Local Notation weight := (UniformWeight.uweight b).
- Local Notation rep := (list Z).
-
- Definition eval (x : list Z) := Positional.eval weight (length x) x.
- Definition valid (x : list Z) :=
- x = Partition.partition weight (length x) (eval x).
-
- (* TODO: make a lemma sort of like this (actually, probably doesn't have to be specialized to uniform weights) to allow splitting [app] in a positional evaluation *)
- Lemma UniformWeight_eval_app :
- forall b y n x,
- length (x ++ y) = n ->
- Positional.eval (UniformWeight.uweight b) n (x ++ y)
- = Positional.eval (UniformWeight.uweight b) (length x) x
- + UniformWeight.uweight b (length x) * Positional.eval (UniformWeight.uweight b) (n - length x) y.
- Proof.
- induction y using rev_ind;
- repeat match goal with
- | _ => progress (intros; subst)
- | _ => progress distr_length; autorewrite with natsimplify in *
- | |- context [(?a + ?b - ?a)%nat] => replace (a + b - a)%nat with b by lia
- | _ => progress rewrite ?app_nil_r, ?app_assoc, ?Nat.add_succ_r
- | _ => erewrite IHy by distr_length
- | _ => progress autorewrite with natsimplify push_eval zsimplify_fast
- | _ => reflexivity
- end.
- ring_simplify.
- (* TODO: need to say here that uweight (i + j) = uweight i * uweight j *)
- Admitted.
-
- Definition low : rep -> rep := firstn (Z.to_nat k).
- Lemma low_correct : forall x, valid x -> eval (low x) = eval x mod b ^ k.
- Proof.
- cbv [eval low]; intros.
- match goal with |- context [firstn ?n ?l] =>
- rewrite <-(firstn_skipn n l);
- replace (firstn n (firstn n l ++ skipn n l)) with (firstn n l)
- by (rewrite firstn_skipn; reflexivity)
- end.
- erewrite UniformWeight_eval_app by distr_length.
- distr_length.
- apply Nat.min_case_strong; intros.
- { (* weight k = b ^ k, so second term goes away by mod *)
- admit. }
- { autorewrite with push_skipn push_eval zsimplify.
- SearchAbout Partition.partition.
- (* annoying, how do I prove valid -> mod_small? *)
- (* might need to use recursive_paritition? *)
- (*
- if length x <= k:
- skipn k x = []
- second term goes away by eval_nil
- *)
- distr_length.
- 2:distr_length.
- repeat match goal with
- | |- context [Nat.min ?x (S (length ?p))] =>
- destruct (lt_dec (length p) x);
- [ rewrite (Nat.min_r x) by lia | rewrite (Nat.min_l x) by lia ]
- end.
-
- Print UniformWeight.
-
- apply natlike_ind with (x:=k); try omega; destruct x using rev_ind.
- { autorewrite with push_eval zsimplify. reflexivity. }
- { autorewrite with push_eval zsimplify. reflexivity. }
- { intros; autorewrite with push_eval zsimplify distr_length push_firstn. reflexivity. }
- { intros. clear IHx.
- autorewrite with push_Zto_nat distr_length natsimplify push_firstn in *.
- repeat match goal with
- | |- context [Nat.min ?x (S (length ?p))] =>
- destruct (lt_dec (length p) x);
- [ rewrite (Nat.min_r x) by lia | rewrite (Nat.min_l x) by lia ]
- | |- context [firstn (?a - ?b)%nat [_] ] =>
- destruct (le_dec a b); [ rewrite (not_le_minus_0 a b) by lia
- | rewrite (firstn_all2 (n:=(a-b)%nat)) by (distr_length; lia) ];
- try lia
- end.
- { rewrite firstn_all2 by lia.
- autorewrite with push_eval in *.
- rewrite Z.mod_small; [ reflexivity | ].
- (* since length x0 < S x1, it must be that length x0 <= x1 and therefore
- [weight (length x0) <= b ^x1] *)
- (* since length x0 < S x1, it must be that length x0 <= x1 and therefore
- [weight (length x0) <= b ^x1] *)
-
- { admit. }
- { SearchAbout (_ - _ = 0)%nat.
- SearchAbout (S _ - _)%nat.
- rewrite not_le_minus_0.
- 2:lia.
-
-
- match goal with |- context [firstn ?n ?l] =>
- rewrite <-(firstn_skipn n l);
- replace (firstn n (firstn n l ++ skipn n l)) with (firstn n l) by (rewrite firstn_skipn; reflexivity)
- end.
- distr_length.
- autorewrite with push_Zto_nat push_firstn push_skipn.
- Print Rewrite HintDb push_skipn.
- SearchAbout Nat.min S.
- repeat apply Nat.min_case_strong; intros; try lia.
- { SearchAbout Positional.eval app.
-
- match goal with
- | |- context [(?a + (?b - ?a))%nat] => replace (a + (b - a))%nat with b by lia
- end.
-
- {
- Focus 9.
- app
- distr_length.
- SearchAbout firstn nil.
-
- rewrite <-firstn
- autorewrite with distr_length.
- SearchAbout Nat.min.
- apply Nat.min_case_strong; intros.
- 2:autorewrite with push_firstn.
- SearchAbout Positional.eval.
- SearchAbout firstn.
- Admitted.
- Print Positional.
-
-
- (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *)
+ Import Partition.
Section Generic.
- Context {T} (rep : T -> Z -> Prop)
- (k : Z) (k_pos : 0 < k)
- (low : T -> Z)
- (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k)
- (shiftr : T -> Z -> T)
- (shiftr_correct : forall a x n,
- rep a x ->
- 0 <= n <= k ->
- rep (shiftr a n) (x / 2 ^ n))
- (mul_high : T -> T -> Z -> T)
- (mul_high_correct : forall a b x y x0y1,
- rep a x ->
- rep b y ->
- 2 ^ k <= x < 2^(k+1) ->
- 0 <= y < 2^(k+1) ->
- x0y1 = x mod 2 ^ k * (y / 2 ^ k) ->
- rep (mul_high a b x0y1) (x * y / 2 ^ k))
- (mul : Z -> Z -> T)
- (mul_correct : forall x y,
- 0 <= x < 2^k ->
- 0 <= y < 2^k ->
- rep (mul x y) (x * y))
- (sub : T -> T -> T)
- (sub_correct : forall a b x y,
- rep a x ->
- rep b y ->
- 0 <= x - y < 2^k * 2^k ->
- rep (sub a b) (x - y))
- (cond_sub1 : T -> Z -> Z)
- (cond_sub1_correct : forall a x y,
- rep a x ->
- 0 <= x < 2 * y ->
- 0 <= y < 2 ^ k ->
- cond_sub1 a y = if (x <? 2 ^ k) then x else x - y)
- (cond_sub2 : Z -> Z -> Z)
- (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y).
- Context (xt mut : T) (M muSelect: Z).
-
- Let mu := 2 ^ (2 * k) / M.
- Context x (mu_rep : rep mut mu) (x_rep : rep xt x).
- Context (M_nz : 0 < M)
- (x_range : 0 <= x < M * 2 ^ k)
- (M_range : 2 ^ (k - 1) < M < 2 ^ k)
- (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu)
- (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)).
-
- Definition qt :=
- dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *)
- dlet_nd q1 := shiftr xt (k - 1) in
- dlet_nd twoq := mul_high mut q1 muSelect in
- shiftr twoq 1.
- Definition reduce :=
- dlet_nd qt := qt in
- dlet_nd r2 := mul (low qt) M in
- dlet_nd r := sub xt r2 in
- let q3 := cond_sub1 r M in
- cond_sub2 q3 M.
-
- Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k).
- Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed.
-
- Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1).
- Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed.
-
- Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1).
+ Context (b k M mu width : Z) (n : nat)
+ (b_ok : 1 < b)
+ (k_pos : 0 < k)
+ (M_range : b ^ (k - 1) < M < b ^ k)
+ (mu_eq : mu = b ^ (2 * k) / M)
+ (strong_bound : b ^ 1 * (b ^ (2 * k) mod M) <= b ^ (k + 1) - mu).
+ Local Notation weight := (UniformWeight.uweight width).
+ Local Notation partition := (Partition.partition weight).
+ Context (q1 : list Z -> list Z)
+ (q1_correct :
+ forall x,
+ 0 <= x < b ^ (2 * k) ->
+ q1 (partition (n*2)%nat x) = partition (n+1)%nat (x / b ^ (k - 1)))
+ (q3 : list Z -> list Z -> list Z)
+ (q3_correct :
+ forall x q1,
+ 0 <= x < b ^ (2 * k) ->
+ q1 = x / b ^ (k - 1) ->
+ q3 (partition (n*2) x) (partition (n+1) q1) = partition (n+1) ((mu * q1) / b ^ (k + 1)))
+ (r : list Z -> list Z -> list Z)
+ (r_correct :
+ forall x q1 q3,
+ 0 <= x < b ^ (2 * k) ->
+ q1 = x / b ^ (k - 1) ->
+ q3 = (mu * q1) / b ^ (k + 1) ->
+ r (partition (n*2) x) (partition (n+1) q3) = partition (n*2) (x - q3 * M))
+ (final_reduce : list Z -> list Z)
+ (final_reduce_correct :
+ forall r,
+ 0 <= r < 2 * M ->
+ final_reduce (partition (n*2) r) = partition n (if dec (r < M) then r else r - M)).
+
+ Context (x : Z) (x_range : 0 <= x < M * b ^ k)
+ (xt : list Z) (xt_correct : xt = partition (n*2) x).
+
+ (* TODO: make barrett Z figure out some of this itself *)
+ Local Lemma M_pos : 0 < M.
+ Proof. assert (0 <= b ^ (k - 1)) by Z.zero_bounds. lia. Qed.
+ Local Lemma x_upper : x < b ^ (2 * k).
Proof.
- pose proof looser_bound.
- subst mu. split.
- { apply Z.div_le_lower_bound; omega. }
- { apply Z.div_lt_upper_bound; try omega.
- rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. }
+ assert (0 < b ^ k) by Z.zero_bounds.
+ apply Z.lt_le_trans with (m:= M * b^k); [ lia | ].
+ transitivity (b^k * b^k); [ nia | ].
+ rewrite <-Z.pow_2_r, <-Z.pow_mul_r by lia.
+ rewrite (Z.mul_comm k 2); reflexivity.
Qed.
+ Local Lemma xmod_lt_M : x mod b ^ (k - 1) <= M.
+ Proof. pose proof (Z.mod_pos_bound x (b ^ (k - 1)) ltac:(Z.zero_bounds)). lia. Qed.
+ Local Hint Resolve M_pos x_upper xmod_lt_M.
- Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1).
- Proof.
- pose proof looser_bound.
- split; [ solve [Z.zero_bounds] | ].
- apply Z.div_lt_upper_bound; auto with zarith.
- rewrite <-pow_2k_eq. omega.
- Qed.
- Hint Resolve shiftr_x_bounds.
-
- Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega.
-
- Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1).
+ Definition reduce :=
+ dlet_nd q1t := q1 xt in
+ dlet_nd q3t := q3 xt q1t in
+ dlet_nd rt := r xt q3t in
+ final_reduce rt.
- Lemma q_correct : rep qt q .
+ Lemma reduce_correct : reduce = partition n (x mod M).
Proof.
- pose proof mu_bounds. cbv [qt]; subst q.
- rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds.
- solve_rep.
- Qed.
- Hint Resolve q_correct.
-
- Lemma x_mod_small : x mod 2 ^ (k - 1) <= M.
- Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed.
- Hint Resolve x_mod_small.
-
- Lemma q_bounds : 0 <= q < 2 ^ k.
- Proof.
- pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds.
- split; subst q; [ solve [Z.zero_bounds] | ].
- edestruct q_nice_strong with (n:=M) as [? Hqnice];
- try rewrite Hqnice; auto; try omega; [ ].
- apply Z.le_lt_trans with (m:= x / M).
- { break_match; omega. }
- { apply Z.div_lt_upper_bound; omega. }
- Qed.
-
- Lemma two_conditional_subtracts :
- forall a x,
- rep a x ->
- 0 <= x < 2 * M ->
- cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M.
- Proof.
- intros.
- erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega).
- break_match; Z.ltb_to_lt; try lia; discriminate.
- Qed.
-
- Lemma r_bounds : 0 <= x - q * M < 2 * M.
- Proof.
- pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small.
- subst q mu; split.
- { Z.zero_bounds. apply qn_small; omega. }
- { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. }
- Qed.
-
- Lemma reduce_correct : reduce = x mod M.
- Proof.
- pose proof looser_bound. pose proof r_bounds. pose proof q_bounds.
- assert (2 * M < 2^k * 2^k) by nia.
- rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega).
cbv [reduce Let_In].
- erewrite low_correct by eauto. Z.rewrite_mod_small.
- erewrite two_conditional_subtracts by solve_rep.
- rewrite !cond_sub2_correct.
- subst q; reflexivity.
- Qed.
- End Generic.
-
- Section BarrettReduction.
- Context (k : Z) (k_bound : 2 <= k).
- Context (M muLow : Z).
- Context (M_pos : 0 < M)
- (muLow_eq : muLow + 2^k = 2^(2*k) / M)
- (muLow_bounds : 0 <= muLow < 2^k)
- (M_bound1 : 2 ^ (k - 1) < M < 2^k)
- (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)).
-
- Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k).
- Context (nout : nat) (Hnout : nout = 2%nat).
- Let w := weight k 1.
- Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed.
- Let props : @weight_properties w := wprops k 1 k_range.
-
- Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval.
-
- Definition low (t : list Z) : Z := nth_default 0 t 0.
- Definition high (t : list Z) : Z := nth_default 0 t 1.
- Definition represents (t : list Z) (x : Z) :=
- t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k.
-
- Lemma represents_eq t x :
- represents t x -> t = [x mod 2^k; x / 2^k].
- Proof. cbv [represents]; tauto. Qed.
-
- Lemma represents_length t x : represents t x -> length t = 2%nat.
- Proof. cbv [represents]; intuition. subst t; reflexivity. Qed.
-
- Lemma represents_low t x :
- represents t x -> low t = x mod 2^k.
- Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
-
- Lemma represents_high t x :
- represents t x -> high t = x / 2^k.
- Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed.
-
- Lemma represents_low_range t x :
- represents t x -> 0 <= x mod 2^k < 2^k.
- Proof. auto with zarith. Qed.
-
- Lemma represents_high_range t x :
- represents t x -> 0 <= x / 2^k < 2^k.
- Proof.
- destruct 1 as [? [? ?] ]; intros.
- auto using Z.div_lt_upper_bound with zarith.
- Qed.
- Hint Resolve represents_length represents_low_range represents_high_range.
-
- Lemma represents_range t x :
- represents t x -> 0 <= x < 2^k*2^k.
- Proof. cbv [represents]; tauto. Qed.
-
- Lemma represents_id x :
- 0 <= x < 2^k * 2^k ->
- represents [x mod 2^k; x / 2^k] x.
- Proof.
- intros; cbv [represents]; autorewrite with cancel_pair.
- Z.rewrite_mod_small; tauto.
- Qed.
-
- Local Ltac push_rep :=
- repeat match goal with
- | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H)
- | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H)
- | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption
- | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption
- end.
-
- Definition shiftr (t : list Z) (n : Z) : list Z :=
- [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n].
-
- Lemma shiftr_represents a i x :
- represents a x ->
- 0 <= i <= k ->
- represents (shiftr a i) (x / 2 ^ i).
- Proof.
- cbv [shiftr]; intros; push_rep.
- match goal with H : _ |- _ => pose proof (represents_range _ _ H) end.
- assert (0 < 2 ^ i) by auto with zarith.
- assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia.
- assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by
- (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith).
- repeat match goal with
- | _ => rewrite Z.rshi_correct by auto with zarith
- | _ => rewrite <-Z.div_mod''' by auto with zarith
- | _ => progress autorewrite with zsimplify_fast
- | _ => progress Z.rewrite_mod_small
- | |- context [represents [(?a / ?c) mod ?b; ?a / ?b / ?c] ] =>
- rewrite (Z.div_div_comm a b c) by auto with zarith
- | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia]
- end.
+ rewrite xt_correct, q1_correct, q3_correct by auto with lia.
+ erewrite r_correct by eauto with lia. rewrite final_reduce_correct; [ | ].
+ { rewrite barrett_reduction_small_strong with (b:=b) (k:=k) (m:=mu) (offset:=1) by
+ auto using Z.lt_gt with zarith.
+ break_innermost_match; Z.ltb_to_lt; (lia || reflexivity). }
+ { split; [ Z.zero_bounds; apply qn_small | apply r_small_strong];
+ auto using Z.lt_gt with zarith. }
Qed.
- Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i).
- Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r.
-
- Definition wideadd t1 t2 := fst (Rows.add w 2 t1 t2).
- (* TODO: use this definition once issue #352 is resolved *)
- (* Definition widesub t1 t2 := fst (Rows.sub w 2 t1 t2). *)
- Definition widesub (t1 t2 : list Z) :=
- let t1_0 := hd 0 t1 in
- let t1_1 := hd 0 (tl t1) in
- let t2_0 := hd 0 t2 in
- let t2_1 := hd 0 (tl t2) in
- dlet_nd x0 := Z.sub_get_borrow_full (2^k) t1_0 t2_0 in
- dlet_nd x1 := Z.sub_with_get_borrow_full (2^k) (snd x0) t1_1 t2_1 in
- [fst x0; fst x1].
- Definition widemul := BaseConversion.widemul_inlined k n nout.
-
- Lemma partition_represents x :
- 0 <= x < 2^k*2^k ->
- represents (Partition.partition w 2 x) x.
- Proof.
- intros; cbn. change_weight.
- Z.rewrite_mod_small.
- autorewrite with zsimplify_fast.
- auto using represents_id.
- Qed.
-
- Lemma eval_represents t x :
- represents t x -> Positional.eval w 2 t = x.
- Proof.
- intros; rewrite (represents_eq t x) by assumption.
- cbn. change_weight; push_rep.
- autorewrite with zsimplify. reflexivity.
- Qed.
-
- Ltac wide_op partitions_pf :=
- repeat match goal with
- | _ => rewrite partitions_pf by eauto
- | _ => rewrite partitions_pf by auto with zarith
- | _ => erewrite eval_represents by eauto
- | _ => solve [auto using partition_represents, represents_id]
- end.
-
- Lemma wideadd_represents t1 t2 x y :
- represents t1 x ->
- represents t2 y ->
- 0 <= x + y < 2^k*2^k ->
- represents (wideadd t1 t2) (x + y).
- Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed.
-
- Lemma widesub_represents t1 t2 x y :
- represents t1 x ->
- represents t2 y ->
- 0 <= x - y < 2^k*2^k ->
- represents (widesub t1 t2) (x - y).
- Proof.
- intros; cbv [widesub Let_In].
- rewrite (represents_eq t1 x) by assumption.
- rewrite (represents_eq t2 y) by assumption.
- cbn [hd tl].
- autorewrite with to_div_mod.
- pull_Zmod.
- match goal with |- represents [?m; ?d] ?x =>
- replace d with (x / 2 ^ k); [solve [auto using represents_id] |] end.
- rewrite <-(Z.mod_small ((x - y) / 2^k) (2^k)) by (split; try apply Z.div_lt_upper_bound; Z.zero_bounds).
- f_equal.
- transitivity ((x mod 2^k - y mod 2^k + 2^k * (x / 2 ^ k) - 2^k * (y / 2^k)) / 2^k). {
- rewrite (Z.div_mod x (2^k)) at 1 by auto using Z.pow_nonzero with omega.
- rewrite (Z.div_mod y (2^k)) at 1 by auto using Z.pow_nonzero with omega.
- f_equal. ring. }
- autorewrite with zsimplify.
- ring.
- Qed.
- (* Works with Rows.sub-based widesub definition
- Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed.
- *)
-
- (* TODO: MOVE Equivlalent Keys decl to Arithmetic? *)
- Declare Equivalent Keys BaseConversion.widemul BaseConversion.widemul_inlined.
- Lemma widemul_represents x y :
- 0 <= x < 2^k ->
- 0 <= y < 2^k ->
- represents (widemul x y) (x * y).
- Proof.
- intros; cbv [widemul].
- assert (0 <= x * y < 2^k*2^k) by auto with zarith.
- wide_op BaseConversion.widemul_correct.
- Qed.
-
- Definition mul_high (a b : list Z) a0b1 : list Z :=
- dlet_nd a0b0 := widemul (low a) (low b) in
- dlet_nd ab := wideadd [high a0b0; high b] [low b; 0] in
- wideadd ab [a0b1; 0].
-
- Lemma mul_high_idea d a b a0 a1 b0 b1 :
- d <> 0 ->
- a = d * a1 + a0 ->
- b = d * b1 + b0 ->
- (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1.
- Proof.
- intros. subst a b. autorewrite with push_Zmul.
- ring_simplify_subterms. rewrite Z.pow_2_r.
- rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
- repeat match goal with
- | |- context [d * ?a * ?b * ?c] =>
- replace (d * a * b * c) with (a * b * c * d) by ring
- | |- context [d * ?a * ?b] =>
- replace (d * a * b) with (a * b * d) by ring
- end.
- rewrite !Z.div_add by omega.
- autorewrite with zsimplify.
- rewrite (Z.mul_comm a0 b0).
- ring_simplify. ring.
- Qed.
+ End Generic.
- Lemma represents_trans t x y:
- represents t y -> y = x ->
- represents t x.
- Proof. congruence. Qed.
+ (* Non-standard implementation -- uses specialized instructions and b=2 *)
+ Module Fancy.
+ Section Fancy.
+ Context (M mu width k : Z)
+ (sz : nat) (sz_nz : sz <> 0%nat)
+ (width_ok : 1 < width)
+ (k_pos : 0 < k) (* this can be inferred from other arguments but is easier to put here for tactics *)
+ (k_eq : k = width * Z.of_nat sz).
+ (* sz = 1, width = k = 256 *)
+ Local Notation w := (UniformWeight.uweight width). Local Notation eval := (Positional.eval w).
+ Context (mut Mt : list Z) (mut_correct : mut = partition w (sz+1) mu) (Mt_correct : Mt = partition w sz M).
+ Context (mu_eq : mu = 2 ^ (2 * k) / M) (M_range : 2^(k-1) < M < 2^k).
+
+ Local Lemma wprops : @weight_properties w. Proof. apply UniformWeight.uwprops; auto with lia. Qed.
+ Local Hint Resolve wprops.
+ Hint Rewrite mut_correct Mt_correct : pull_partition.
+
+ Lemma w_eq_2k : w sz = 2^k. Admitted. (* TODO *)
+ Lemma mu_range : 2^k <= mu < 2^(k+1).
+ Admitted. (* TODO *)
+ Lemma M_range' : 0 <= M < w sz. (* more convenient form, especially for mod_small *)
+ Proof. assert (0 <= 2 ^ (k-1)) by Z.zero_bounds. pose proof w_eq_2k; lia. Qed.
+
+ Definition shiftr' (m : nat) (t : list Z) (n : Z) : list Z :=
+ map (fun i => Z.rshi (2^width) (nth_default 0 t (S i)) (nth_default 0 t i) n) (seq 0 m).
+
+ Definition shiftr (m : nat) (t : list Z) (n : Z) : list Z :=
+ (* if width <= n, drop limbs first *)
+ if dec (width <= n)
+ then shiftr' m (skipn (Z.to_nat (n / width)) t) (n mod width)
+ else shiftr' m t n.
+
+ Definition wideadd t1 t2 := fst (Rows.add w (sz*2) t1 t2).
+ Definition widesub t1 t2 := fst (Rows.sub w (sz*2) t1 t2).
+ Definition widemul := BaseConversion.widemul_inlined width sz 2.
+ (* widemul_inlined takes the following argument order : (width of limbs in input) (# limbs in input) (# parts to split each limb into before multiplying) *)
+
+ Definition fill (n : nat) (a : list Z) := a ++ Positional.zeros (n - length a).
+ Definition low : list Z -> list Z := firstn sz.
+ Definition high : list Z -> list Z := skipn sz.
+ Definition mul_high (a b : list Z) a0b1 : list Z :=
+ dlet_nd a0b0 := widemul (low a) (low b) in
+ dlet_nd ab := wideadd (high a0b0 ++ high b) (fill (sz*2) (low b)) in
+ wideadd ab a0b1.
+
+ (* select based on the most significant bit of xHigh *)
+ Definition muSelect xt :=
+ let xHigh := nth_default 0 xt (sz*2 - 1) in
+ Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut).
+
+ Definition cond_sub (a y : list Z) : list Z :=
+ let cond := Z.cc_l (nth_default 0 (high a) 0) in (* a[k] = least significant bit of (high a) *)
+ dlet_nd maybe_y := Positional.select cond (repeat 0 sz) y in
+ dlet_nd diff := Rows.sub w sz (low a) maybe_y in (* (a mod (w sz) - y) mod (w sz)) = (a - y) mod (w sz); since we know a - y is < w sz this is okay by mod_small *)
+ fst diff.
+
+ Definition cond_subM x :=
+ if Nat.eq_dec sz 1
+ then [Z.add_modulo (nth_default 0 x 0) 0 M] (* use the special instruction if we can *)
+ else Rows.conditional_sub w sz x Mt.
+
+ Definition q1 (xt : list Z) := shiftr (sz+1) xt (k - 1).
+
+ Definition q3 (xt q1t : list Z) :=
+ dlet_nd muSelect := muSelect xt in (* make sure muSelect is not inlined in the output *)
+ dlet_nd twoq := mul_high (fill (sz*2) mut) (fill (sz*2) q1t) (fill (sz*2) muSelect) in
+ shiftr (sz+1) twoq 1.
+
+ Definition r (xt q3t : list Z) :=
+ dlet_nd r2 := widemul (low q3t) Mt in
+ widesub xt r2.
+
+ Definition final_reduce (rt : list Z) :=
+ dlet_nd rt := cond_sub rt Mt in
+ cond_subM rt.
+
+ Section Proofs.
+
+ (* TODO: move *)
+ Lemma divides_pow_le b n m : 0 <= n <= m -> (b ^ n | b ^ m).
+ Proof.
+ intros. replace m with (n + (m - n)) by ring.
+ rewrite Z.pow_add_r by lia.
+ apply Z.divide_factor_l.
+ Qed.
- Lemma represents_add x y :
- 0 <= x < 2 ^ k ->
- 0 <= y < 2 ^ k ->
- represents [x;y] (x + 2^k*y).
- Proof.
- intros; cbv [represents]; autorewrite with zsimplify.
- repeat split; (reflexivity || nia).
- Qed.
+ (* TODO: move to UW or Weight *)
+ Lemma mod_mod_weight a i j :
+ (i <= j)%nat -> (a mod (w j)) mod (w i) = a mod (w i).
+ Proof.
+ intros. rewrite <-Znumtheory.Zmod_div_mod; auto; [ ].
+ rewrite !UniformWeight.uweight_eq_alt'.
+ apply divides_pow_le. nia.
+ Qed.
- Lemma represents_small x :
- 0 <= x < 2^k ->
- represents [x; 0] x.
- Proof.
- intros.
- eapply represents_trans.
- { eauto using represents_add with zarith. }
- { ring. }
- Qed.
+ Lemma shiftr'_correct m n :
+ forall t tn,
+ (m <= tn)%nat -> 0 <= t < w tn -> 0 <= n < width ->
+ shiftr' m (partition w tn t) n = partition w m (t / 2 ^ n).
+ Proof.
+ cbv [shiftr']. induction m; intros; [ reflexivity | ].
+ rewrite !partition_step, seq_snoc.
+ autorewrite with distr_length natsimplify push_map push_nth_default.
+ rewrite IHm by auto with zarith.
+ rewrite Z.rshi_correct, UniformWeight.uweight_S by auto with zarith.
+ rewrite <-Z.mod_pull_div by auto with zarith.
+ destruct (Nat.eq_dec (S m) tn); [subst tn | ]; rewrite !nth_default_partition by omega.
+ { rewrite nth_default_out_of_bounds by distr_length.
+ autorewrite with zsimplify. Z.rewrite_mod_small.
+ rewrite Z.div_div_comm by auto with zarith; reflexivity. }
+ { (* oh god the worst *)
+ rewrite !UniformWeight.uweight_S by auto with zarith.
+ rewrite <-!Z.mod_pull_div by auto with zarith.
+ rewrite Z.mod_pull_div with (b:=2^n) by auto with zarith.
+ rewrite <-Z.div_div by auto with zarith.
+ rewrite (Z.div_div_comm _ (2^width) (w m)) by auto with zarith.
+ rewrite Z.mod_pull_div with (b:=2^width) by auto with zarith.
+ rewrite Z.mul_div_eq'; auto using Z.lt_gt with zarith.
+ rewrite Z.rem_mul_r with (b:=2^width) (c:=2^width) by auto with zarith.
+ autorewrite with zsimplify.
+ rewrite <-Z.rem_mul_r by auto with zarith.
+ rewrite !Z.mod_pull_div by auto with zarith.
+ rewrite <-Znumtheory.Zmod_div_mod by
+ (try solve [Z.zero_bounds];
+ rewrite <-!Z.pow_add_r by auto with zarith;
+ auto using Z.mul_divide_mono_r, divides_pow_le with zarith).
+ rewrite Z.div_div_comm by auto with zarith.
+ repeat (f_equal; try ring). }
+ Qed.
+ Lemma shiftr_correct m n :
+ forall t tn,
+ (Z.to_nat (n / width) <= tn)%nat ->
+ (m <= tn - Z.to_nat (n / width))%nat -> 0 <= t < w tn -> 0 <= n ->
+ shiftr m (partition w tn t) n = partition w m (t / 2 ^ n).
+ Proof.
+ cbv [shiftr]; intros.
+ break_innermost_match; [ | solve [auto using shiftr'_correct with zarith] ].
+ pose proof (Z.mod_pos_bound n width ltac:(omega)).
+ assert (t / 2 ^ (n - n mod width) < w (tn - Z.to_nat (n / width))).
+ { apply Z.div_lt_upper_bound; [solve [Z.zero_bounds] | ].
+ rewrite UniformWeight.uweight_eq_alt' in *.
+ rewrite <-Z.pow_add_r, Nat2Z.inj_sub, Z2Nat.id, <-Z.mul_div_eq by auto with zarith.
+ autorewrite with push_Zmul zsimplify. auto with zarith. }
+ repeat match goal with
+ | _ => progress rewrite ?UniformWeight.uweight_skipn_partition, ?UniformWeight.uweight_eq_alt' by auto with lia
+ | _ => rewrite Z2Nat.id by Z.zero_bounds
+ | _ => rewrite Z.mul_div_eq_full by auto with zarith
+ | _ => rewrite shiftr'_correct by auto with zarith
+ | _ => progress rewrite ?Z.div_div, <-?Z.pow_add_r by auto with zarith
+ end.
+ autorewrite with zsimplify. reflexivity.
+ Qed.
+ Hint Rewrite shiftr_correct using (solve [auto with lia]) : pull_partition.
+
+ (* 2 ^ (k + 1) bits fit in sz + 1 limbs because we know 2^k bits fit in sz and 1 <= width *)
+ Lemma q1_correct x :
+ 0 <= x < w (sz * 2) ->
+ q1 (partition w (sz*2)%nat x) = partition w (sz+1)%nat (x / 2 ^ (k - 1)).
+ Proof.
+ cbv [q1]; intros. assert (1 <= Z.of_nat sz) by (destruct sz; lia).
+ assert (Z.to_nat ((k-1) / width) < sz)%nat. {
+ subst k. rewrite <-Z.add_opp_r. autorewrite with zsimplify.
+ apply Nat2Z.inj_lt. rewrite Z2Nat.id by lia. lia. }
+ assert (0 <= k - 1) by nia.
+ autorewrite with pull_partition. reflexivity.
+ Qed.
- Lemma mul_high_represents a b x y a0b1 :
- represents a x ->
- represents b y ->
- 2^k <= x < 2^(k+1) ->
- 0 <= y < 2^(k+1) ->
- a0b1 = x mod 2^k * (y / 2^k) ->
- represents (mul_high a b a0b1) ((x * y) / 2^k).
- Proof.
- cbv [mul_high Let_In]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros.
- assert (4 <= 2 ^ k) by (transitivity (Z.pow 2 2); auto with zarith).
- assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem_in_goal; nia).
-
- rewrite mul_high_idea with (a:=x) (b:=y) (a0 := low a) (a1 := high a) (b0 := low b) (b1 := high b) in *
- by (push_rep; Z.div_mod_to_quot_rem_in_goal; lia).
-
- push_rep. subst a0b1.
- assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega).
- replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia).
- autorewrite with zsimplify_fast in *.
-
- eapply represents_trans.
- { repeat (apply wideadd_represents;
- [ | apply represents_small; Z.div_mod_to_quot_rem_in_goal; nia| ]).
- erewrite represents_high; [ | apply widemul_represents; solve [ auto with zarith ] ].
- { apply represents_add; try reflexivity; solve [auto with zarith]. }
- { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z =>
- split; [ solve [Z.zero_bounds] | ];
- eapply Z.le_lt_trans with (m:= x + y); nia
- end. }
- { omega. } }
- { ring. }
- Qed.
+ Lemma low_correct n a : (sz <= n)%nat -> low (partition w n a) = partition w sz a.
+ Admitted.
+ Lemma high_correct a : high (partition w (sz*2) a) = partition w sz (a / w sz).
+ Admitted.
+ Lemma fill_correct n m a :
+ (n <= m)%nat ->
+ fill m (partition w n a) = partition w m (a mod w n).
+ Admitted.
+ Hint Rewrite low_correct high_correct fill_correct using lia : pull_partition.
+
+ (* TODO: move *)
+ Lemma partition_app n m a b :
+ partition w n a ++ partition w m b = partition w (n+m) (a mod w n + b * w n).
+ Admitted.
+
+ Lemma wideadd_correct a b :
+ wideadd (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a + b).
+ Proof.
+ cbv [wideadd]. rewrite Rows.add_partitions by (distr_length; auto).
+ autorewrite with push_eval.
+ apply partition_eq_mod; auto with zarith.
+ Qed.
+ Lemma widesub_correct a b :
+ widesub (partition w (sz*2) a) (partition w (sz*2) b) = partition w (sz*2) (a - b).
+ Proof.
+ cbv [widesub]. rewrite Rows.sub_partitions by (distr_length; auto).
+ autorewrite with push_eval.
+ apply partition_eq_mod; auto with zarith.
+ Qed.
+ Lemma widemul_correct a b :
+ widemul (partition w sz a) (partition w sz b) = partition w (sz*2) ((a mod w sz) * (b mod w sz)).
+ Proof.
+ cbv [widemul]. rewrite BaseConversion.widemul_inlined_correct; (distr_length; auto).
+ autorewrite with push_eval. reflexivity.
+ Qed.
+ Hint Rewrite widemul_correct widesub_correct wideadd_correct using lia : pull_partition.
+
+ Lemma mul_high_idea d a b a0 a1 b0 b1 :
+ d <> 0 ->
+ a = d * a1 + a0 ->
+ b = d * b1 + b0 ->
+ (a * b) / d = a0 * b0 / d + d * a1 * b1 + a1 * b0 + a0 * b1.
+ Proof.
+ intros. subst a b. autorewrite with push_Zmul.
+ ring_simplify_subterms. rewrite Z.pow_2_r.
+ rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
+ repeat match goal with
+ | |- context [d * ?a * ?b * ?c] =>
+ replace (d * a * b * c) with (a * b * c * d) by ring
+ | |- context [d * ?a * ?b] =>
+ replace (d * a * b) with (a * b * d) by ring
+ end.
+ rewrite !Z.div_add by omega.
+ autorewrite with zsimplify.
+ rewrite (Z.mul_comm a0 b0).
+ ring_simplify. ring.
+ Qed.
- Definition cond_sub1 (a : list Z) y : Z :=
- dlet_nd maybe_y := Z.zselect (Z.cc_l (high a)) 0 y in
- dlet_nd diff := Z.sub_get_borrow_full (2^k) (low a) maybe_y in
- fst diff.
+ (* TODO: move these? *)
+ Local Hint Resolve Z.lt_gt : zarith.
+ Hint Rewrite Z.div_add' using solve [auto with zarith] : zsimplify.
+
+ Lemma mul_high_correct a b
+ (Ha : a / w sz = 1)
+ a0b1 (Ha0b1 : a0b1 = a mod w sz * (b / w sz)) :
+ mul_high (partition w (sz*2) a) (partition w (sz*2) b) (partition w (sz*2) a0b1) =
+ partition w (sz*2) (a * b / w sz).
+ Proof.
+ cbv [mul_high Let_In].
+ erewrite mul_high_idea by auto using Z.div_mod with zarith.
+ repeat match goal with
+ | _ => progress autorewrite with pull_partition
+ | _ => progress rewrite ?Ha, ?Ha0b1
+ | _ => rewrite partition_app by auto;
+ replace (sz+sz)%nat with (sz*2)%nat by lia
+ | _ => rewrite Z.mod_pull_div by auto with zarith
+ | _ => progress Z.rewrite_mod_small
+ | _ => f_equal; ring
+ end.
+ Qed.
- Lemma cc_l_only_bit : forall x s, 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s.
- Proof.
- cbv [Z.cc_l]; intros.
- rewrite Z.div_between_0_if by omega.
- break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega.
- Qed.
+ Lemma muSelect_correct x :
+ 0 <= x < w (sz * 2) ->
+ muSelect (partition w (sz*2) x) = partition w (sz+1) (mu mod (w sz) * (x / 2 ^ (k - 1) / (w sz))).
+ Admitted.
+ Hint Rewrite muSelect_correct using lia : pull_partition.
+
+ (* TODO: move *)
+ Lemma pow_pos_le a b : 0 < a -> 0 < b -> a <= a ^ b.
+ Proof.
+ intros; transitivity (a ^ 1).
+ { rewrite Z.pow_1_r; reflexivity. }
+ { apply Z.pow_le_mono; auto with zarith. }
+ Qed.
+ Hint Resolve pow_pos_le : zarith.
+
+ Lemma q3_correct x (Hx : 0 <= x < w (sz * 2)) q1 (Hq1 : q1 = x / 2 ^ (k - 1)) :
+ q3 (partition w (sz*2) x) (partition w (sz+1) q1) = partition w (sz+1) ((mu*q1) / 2 ^ (k + 1)).
+ Proof.
+ cbv [q3 Let_In]. intros.
+ assert (0 <= q1 < w (sz+1)) by admit.
+ autorewrite with pull_partition.
+ rewrite ?fill_correct, ?muSelect_correct by lia.
+ rewrite <-Hq1.
+ rewrite mul_high_correct; [ | | ].
+ 2: { (* mu mod w (sz + 1) / w sz = 1 *)
+ (* highest bit of mu is 1 -- do this by rewriting mod small using bounds on mu, then take axiom for highest bit *)
+ admit. }
+ 2: {
+ rewrite mod_mod_weight by lia.
+ push_Zmod. rewrite Z.mod_pull_div by auto with zarith.
+ assert (0 < w sz) by auto with zarith.
+ assert (w (sz+1) <= w (sz+1) * w sz) by (apply Z.le_mul_diag_r; auto with zarith).
+ Z.rewrite_mod_small.
+ autorewrite with natsimplify in *.
+ rewrite UniformWeight.uweight_S in * by auto with zarith.
+ assert (0 <= q1 / w sz < 2^width).
+ { split; [ solve [Z.zero_bounds] | ].
+ apply Z.div_lt_upper_bound; auto with nia. }
+ pose proof (Z.mod_pos_bound mu (w sz) ltac:(auto)).
+ apply Z.mod_small. nia. }
+ rewrite shiftr_correct; auto with zarith.
+ 2: admit. (* (Z.to_nat (1 / width) <= sz * 2)%nat *)
+ 2: admit. (* (sz + 1 <= sz * 2 - Z.to_nat (1 / width))%nat *)
+ 2: admit. (* 0 <= mu mod w (sz + 1) * (q1 mod w (sz + 1)) / w sz < w (sz * 2) *)
+ (* TODO: see if things are easier using div_mod_to_quot_rem *)
+ Z.rewrite_mod_small.
+ pose proof mu_range.
+ replace (2 ^ (k+1)) with (w sz * 2) in * by admit.
+ assert (0 < 2 ^ k) by (Z.zero_bounds; nia).
+ assert (2 * w sz <= w (sz + 1)).
+ { autorewrite with natsimplify. rewrite UniformWeight.uweight_S by auto with zarith.
+ apply Z.mul_le_mono_nonneg_r; auto with zarith. }
+ Z.rewrite_mod_small.
+ rewrite Z.div_div by auto with zarith.
+ repeat (f_equal; try ring).
+ Admitted.
+
+ Lemma cond_sub_correct a b :
+ cond_sub (partition w (sz*2) a) (partition w sz b)
+ = partition w sz (if dec ((a / w sz) mod 2 = 0)
+ then a
+ else a - b).
+ Proof.
+ cbv [cond_sub Let_In Z.cc_l]. autorewrite with pull_partition.
+ rewrite nth_default_partition by lia.
+ rewrite weight_0 by auto. autorewrite with zsimplify_fast.
+ rewrite UniformWeight.uweight_eq_alt' with (n:=1%nat). autorewrite with push_Zof_nat zsimplify.
+ rewrite <-Znumtheory.Zmod_div_mod by auto using Zpow_facts.Zpower_divide with zarith.
+ rewrite Positional.select_eq with (n:=sz) by (distr_length; apply w).
+ rewrite Rows.sub_partitions by (break_innermost_match; distr_length; auto).
+ break_innermost_match; autorewrite with push_eval zsimplify_fast;
+ apply partition_eq_mod; auto with zarith.
+ Qed.
+ Hint Rewrite cond_sub_correct : pull_partition.
+ Lemma cond_subM_correct a :
+ cond_subM (partition w sz a)
+ = partition w sz (if dec (a mod w sz < M)
+ then a
+ else a - M).
+ Proof.
+ Admitted.
+ Hint Rewrite cond_subM_correct : pull_partition.
+
+ (* TODO: move this with barrett Z lemmas & prove *)
+ Lemma q3_range x :
+ 0 <= x < 2 ^ (2 * k) ->
+ 0 <= (mu * (x / 2 ^ (k-1))) / 2 ^ (k+1) < 2^k.
+ Proof using Type.
+ Admitted.
- Lemma cond_sub1_correct a x y :
- represents a x ->
- 0 <= x < 2 * y ->
- 0 <= y < 2 ^ k ->
- cond_sub1 a y = if (x <? 2 ^ k) then x else x - y.
- Proof.
- intros; cbv [cond_sub1 Let_In]. rewrite Z.zselect_correct. push_rep.
- break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega;
- autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith.
- Qed.
+ Lemma w_eq_22k : w (sz * 2) = 2 ^ (2 * k).
+ Proof.
+ replace (sz * 2)%nat with (sz + sz)%nat by lia.
+ rewrite UniformWeight.uweight_sum_indices, w_eq_2k, <-Z.pow_add_r by lia.
+ f_equal; lia.
+ Qed.
- Definition cond_sub2 x y := Z.add_modulo x 0 y.
- Lemma cond_sub2_correct x y :
- cond_sub2 x y = if (x <? y) then x else x - y.
- Proof.
- cbv [cond_sub2]. rewrite Z.add_modulo_correct.
- autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega.
- Qed.
+ Lemma r_correct x q1 q3 :
+ 0 <= x < w (sz * 2) ->
+ q1 = x / 2 ^ (k - 1) ->
+ q3 = (mu*q1) / 2 ^ (k + 1) ->
+ r (partition w (sz*2) x) (partition w (sz+1) q3) = partition w (sz*2) (x - q3 * M).
+ Proof.
+ cbv [r Let_In]; intros. pose proof w_eq_22k.
+ autorewrite with pull_partition.
+ assert (0 <= q3 < w sz).
+ { subst q3 q1. rewrite w_eq_2k. apply q3_range; lia. }
+ pose proof M_range'. Z.rewrite_mod_small.
+ autorewrite with zsimplify. reflexivity.
+ Qed.
- Section Defn.
- Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M).
- Let xt := [xLow; xHigh].
- Let x := xLow + 2^k * xHigh.
-
- Lemma x_rep : represents xt x.
- Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed.
-
- Lemma x_bounds : 0 <= x < M * 2 ^ k.
- Proof. subst x; nia. Qed.
-
- Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow.
-
- Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound.
- Local Hint Resolve shiftr_represents mul_high_represents widemul_represents widesub_represents
- cond_sub1_correct cond_sub2_correct represents_low represents_add.
-
- Lemma muSelect_correct :
- muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k).
- Proof.
- (* assertions to help arith tactics *)
- pose proof x_bounds.
- assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia).
- assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem_in_goal; auto with nia).
- assert (0 < 2 ^ k / 2) by Z.zero_bounds.
- assert (2 ^ (k - 1) <> 0) by auto with zarith.
- assert (2 < 2 ^ k) by (eapply Z.le_lt_trans with (m:=2 ^ 1); auto with zarith).
-
- cbv [muSelect]. rewrite <-muLow_eq.
- rewrite Z.zselect_correct, Z.cc_m_eq by auto with zarith.
- replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia).
- autorewrite with pull_Zdiv push_Zpow.
- rewrite (Z.mul_comm (2 ^ k / 2)).
- break_match; [ ring | ].
- match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end.
- autorewrite with zsimplify; reflexivity.
- Qed.
+ Lemma final_reduce_correct r:
+ 0 <= r < 2 * M ->
+ final_reduce (partition w (sz*2) r) = partition w sz (if dec (r < M) then r else r - M).
+ Proof.
+ intros; pose proof M_range'.
+ cbv [final_reduce Let_In]; intros; autorewrite with pull_partition.
+ apply partition_eq_mod; auto with zarith.
+ replace (2 ^ (k + 1)) with (2 * w sz) in *
+ by (rewrite Z.pow_add_r by auto with zarith; autorewrite with zsimplify; rewrite <-w_eq_2k; lia).
+ pose proof (Z.mod_pos_bound r (w sz) ltac:(auto with zarith)).
+ remember ((r / w sz) mod 2) as rk.
+ assert (rk = 1 -> r > M) by (subst rk; Z.rewrite_mod_small; intros;
+ rewrite (Z.div_mod r (w sz)) by auto with zarith; nia).
+ replace r with (w sz * rk + r mod w sz) by
+ (subst rk; Z.rewrite_mod_small; rewrite <-Z.div_mod; auto with zarith).
+ subst rk; rewrite Zmod_odd in *.
+ repeat match goal with
+ | _ => lia
+ | H : _ |- _ => rewrite Z.mod_small in H by auto with zarith lia
+ | _ => progress (push_Zmod; pull_Zmod); autorewrite with zsimplify_fast
+ | _ => break_innermost_match_step
+ end.
+ Qed.
+ End Proofs.
+
+ Section Def.
+ Context (sz_eq_1 : sz = 1%nat). (* this is needed to get rid of branches in the templates; a different definition would be needed for sizes other than 1, but would be able to use the same proofs. *)
+ Local Hint Resolve q1_correct q3_correct r_correct final_reduce_correct.
+
+ (* muselect relies on an initially-set flag, so pull it out of q3 *)
+ Definition fancy_reduce_muSelect_first xt :=
+ dlet_nd muSelect := muSelect xt in
+ dlet_nd q1t := q1 xt in
+ dlet_nd twoq := mul_high (fill (sz * 2) mut) (fill (sz * 2) q1t) (fill (sz * 2) muSelect) in
+ dlet_nd q3t := shiftr (sz+1) twoq 1 in
+ dlet_nd rt := r xt q3t in
+ final_reduce rt.
+
+ Lemma fancy_reduce_muSelect_first_correct x :
+ 0 <= x < M * 2^k ->
+ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+ fancy_reduce_muSelect_first (partition w (sz*2) x) = partition w sz (x mod M).
+ Proof.
+ intros. pose proof w_eq_22k.
+ erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by
+ (eauto with nia; intros; try rewrite q3'_correct; eauto with nia ).
+ reflexivity.
+ Qed.
+
+
+ (*
+ Definition fancy_reduce xLow xHigh :=
+ dlet_nd muSelect : list Z := Positional.select (Z.cc_m (2 ^ width) xHigh) (repeat 0 sz) (low mut) in
+ dlet_nd q1t := [Z.rshi (2 ^ width) xHigh xLow (k - 1); Z.rshi (2 ^ width) 0 xHigh (k - 1)] in
+ dlet_nd a0b0 : list Z := widemul (low mut) (low q1t) in
+ dlet_nd ab : list Z := wideadd [high a0b0; high q1t] [low q1t; 0] in
+ dlet_nd a0 := wideadd ab [muSelect;0] in
+ dlet_nd q3t := [Z.rshi (2 ^ width) (nth_default 0 a0 1) (nth_default 0 a0 0) 1; Z.rshi (2 ^ width) (nth_default 0 a0 2) (nth_default 0 a0 1) 1] in
+ dlet_nd r2 : list Z := widemul (low q3t) Mt in
+ dlet_nd rt : list Z := widesub [xLow; xHigh] r2 in
+ dlet_nd maybe_M : list Z := Positional.select (Z.cc_l (nth_default 0 (high rt) 0)) (repeat 0 sz) Mt in
+ dlet_nd diff : list Z * Z := Rows.sub w sz (low rt) maybe_M in
+ dlet_nd rt0 :=fst diff in
+ Z.add_modulo (nth_default 0 rt0 0) 0 M.
+*)
+ (* This simplifies the shiftr inside q3, specializing it to the # limbs so we don't end up with maps in the final expression *)
+ Derive q3'
+ SuchThat (forall xt q1t, q3' xt q1t = q3 xt q1t)
+ As q3'_correct.
+ Proof.
+ intros. cbv [q3]. rewrite sz_eq_1. autorewrite with natsimplify.
+ cbv [shiftr shiftr' seq]. break_match; try lia; [ ].
+ rewrite Proper_Let_In_nd_changebody; [ | reflexivity | repeat intro ].
+ 2 : apply Proper_Let_In_nd_changebody;
+ [ reflexivity | repeat intro; autorewrite with push_map push_nth_default; reflexivity ].
+ subst q3'; reflexivity.
+ Qed.
- Lemma mu_rep : represents [muLow; 1] (2 ^ (2 * k) / M).
- Proof. rewrite <-muLow_eq. eapply represents_trans; auto with zarith. Qed.
+ Derive fancy_reduce'
+ SuchThat (
+ forall x,
+ 0 <= x < M * 2^k ->
+ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+ fancy_reduce' (partition w (sz*2) x) = partition w sz (x mod M))
+ As fancy_reduce'_correct.
+ Proof.
+ intros. assert (k = width) as width_eq_k by nia.
+ erewrite <-fancy_reduce_muSelect_first_correct by nia.
+ cbv [fancy_reduce_muSelect_first q1 q3 shiftr final_reduce cond_subM].
+ break_match; try solve [exfalso; lia].
+ match goal with |- ?g ?x = ?rhs =>
+ let f := (match (eval pattern x in rhs) with ?f _ => f end) in
+ assert (f = g); subst fancy_reduce'; reflexivity
+ end.
+ Qed.
- Derive barrett_reduce
- SuchThat (barrett_reduce = x mod M)
- As barrett_reduce_correct.
- Proof.
- erewrite <-reduce_correct with (rep:=represents) (muSelect:=muSelect) (k0:=k) (mut:=[muLow;1]) (xt0:=xt)
- by (auto using x_bounds, muSelect_correct, x_rep, mu_rep; omega).
- subst barrett_reduce. reflexivity.
- Qed.
- End Defn.
- End BarrettReduction.
+ (* TODO: Can't deal with input by shifting and stuff! *)
+ Derive fancy_reduce
+ SuchThat (
+ forall xLow xHigh,
+ 0 <= xLow < 2 ^ k ->
+ 0 <= xHigh < M ->
+ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+ fancy_reduce xLow xHigh = (xLow + 2^k * xHigh) mod M)
+ As fancy_reduce_correct.
+ Proof.
+ intros. pose proof w_eq_22k. assert (k = width) as width_eq_k by nia.
+ replace ((xLow + 2^k * xHigh) mod M) with (hd 0 (partition w sz ((xLow + 2^k * xHigh) mod M))).
+ 2 : { rewrite sz_eq_1. rewrite width_eq_k in *. cbv [Partition.partition map seq hd].
+ rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia.
+ autorewrite with zsimplify.
+ assert (0 < 2 ^ (width - 1)) by Z.zero_bounds.
+ pose proof (Z.mod_pos_bound (xLow + 2^width * xHigh) M ltac:(lia)).
+ autorewrite with zsimplify. reflexivity. }
+ rewrite <-fancy_reduce'_correct by nia.
+ replace (partition w (sz*2) (xLow + 2^k * xHigh)) with [xLow; xHigh].
+ 2 : { replace (sz * 2)%nat with 2%nat by (subst sz; lia).
+ rewrite width_eq_k in *. cbv [Partition.partition map seq].
+ rewrite !UniformWeight.uweight_S, !weight_0 by auto with zarith lia.
+ autorewrite with zsimplify.
+ rewrite <-Z.mod_pull_div by Z.zero_bounds.
+ autorewrite with zsimplify. reflexivity. }
+ cbv [fancy_reduce' reduce q1 q3 shiftr final_reduce cond_subM].
+ autorewrite with natsimplify. (* TODO: maybe need to get rid of hd 0? *)
+ cbv [shiftr' seq]. autorewrite with push_map push_nth_default.
+ subst fancy_reduce; reflexivity.
+ Qed.
+ (* TODO: decide which version to use
+ Derive fancy_reduce
+ SuchThat (
+ forall x,
+ 0 <= x < M * 2 ^ k ->
+ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu ->
+ fancy_reduce (partition w (sz*2) x) = partition w sz (x mod M))
+ As fancy_reduce_correct.
+ Proof.
+ intros. pose proof w_eq_22k.
+ erewrite <-reduce_correct with (b:=2) (k:=k) (mu:=mu) by eauto with lia.
+ assert (width = k) by nia.
+ cbv [reduce q1 q3 shiftr final_reduce cond_subM].
+ break_match; try solve [exfalso; lia].
+ match goal with |- ?g ?x = ?rhs =>
+ let f := (match (eval pattern x in rhs) with ?f _ => f end) in
+ assert (f = g); subst fancy_reduce; reflexivity
+ end.
+ Qed.
+ *)
+ End Def.
+ End Fancy.
+ End Fancy.
End BarrettReduction.
Module MontgomeryReduction.
@@ -5454,14 +5473,13 @@ Module MontgomeryReduction.
Context (Zlog2R : Z) .
Let w : nat -> Z := weight Zlog2R 1.
Context (n:nat) (Hn_nz: n <> 0%nat) (n_good : Zlog2R mod Z.of_nat n = 0).
- Context (R_big_enough : n <= Zlog2R)
+ Context (R_big_enough : 2 <= Zlog2R)
(R_two_pow : 2^Zlog2R = R).
Let w_mul : nat -> Z := weight (Zlog2R / n) 1.
- Context (nout : nat) (Hnout : nout = 2%nat).
Definition montred' (lo hi : Z) :=
- dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R n nout lo N') 0 in
- dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R n nout N y) in
+ dlet_nd y := nth_default 0 (BaseConversion.widemul_inlined Zlog2R 1 2 [lo] [N']) 0 in
+ dlet_nd t1_t2 := (BaseConversion.widemul_inlined_reverse Zlog2R 1 2 [N] [y]) in
dlet_nd sum_carry := Rows.add (weight Zlog2R 1) 2 [lo;hi] t1_t2 in
dlet_nd y' := Z.zselect (snd sum_carry) 0 N in
dlet_nd lo''_carry := Z.sub_get_borrow_full R (nth_default 0 (fst sum_carry) 1) y' in
@@ -5489,10 +5507,15 @@ Module MontgomeryReduction.
Local Lemma eval2 x y : Positional.eval w 2 [x;y] = x + R * y.
Proof. cbn. change_weight. ring. Qed.
+ Local Lemma eval1 x : Positional.eval w 1 [x] = x.
+ Proof. cbn. change_weight. ring. Qed.
Hint Rewrite BaseConversion.widemul_inlined_reverse_correct BaseConversion.widemul_inlined_correct
using (autorewrite with widemul push_nth_default; solve [solve_range]) : widemul.
+ (* TODO: move *)
+ Hint Rewrite Nat.mul_1_l : natsimplify.
+
Lemma montred'_eq lo hi T (HT_range: 0 <= T < R * N)
(Hlo: lo = T mod R) (Hhi: hi = T / R):
montred' lo hi = reduce_via_partial N R N' T.
@@ -5501,13 +5524,15 @@ Module MontgomeryReduction.
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In].
rewrite Hlo, Hhi.
assert (0 <= (T mod R) * N' < w 2) by (solve_range).
-
autorewrite with widemul.
rewrite Rows.add_partitions, Rows.add_div by (distr_length; apply wprops; omega).
- rewrite R_two_pow.
- cbv [Partition.partition seq]. rewrite !eval2.
- autorewrite with push_nth_default push_map.
- autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct.
+ (* rewrite R_two_pow. *)
+ cbv [Partition.partition seq].
+ repeat match goal with
+ | _ => progress rewrite ?eval1, ?eval2
+ | _ => progress rewrite ?Z.zselect_correct, ?Z.add_modulo_correct
+ | _ => progress autorewrite with natsimplify push_nth_default push_map to_div_mod
+ end.
change_weight.
(* pull out value before last modular reduction *)
@@ -5515,16 +5540,18 @@ Module MontgomeryReduction.
let P := fresh "H" in assert (x = y) as P; [|rewrite P; reflexivity] end.
autorewrite with zsimplify.
+ Z.rewrite_mod_small.
+ autorewrite with zsimplify.
rewrite (Z.mul_comm (((T mod R) * N') mod R) N) in *.
- break_match; try reflexivity; Z.ltb_to_lt; rewrite Z.div_small_iff in * by omega;
- repeat match goal with
- | _ => progress autorewrite with zsimplify_fast
- | |- context [?x mod (R * R)] =>
- unique pose proof (Z.mod_pos_bound x (R * R));
- try rewrite (Z.mod_small x (R * R)) in * by Z.rewrite_mod_small_solver
- | _ => omega
- | _ => progress Z.rewrite_mod_small
- end.
+ match goal with
+ |- context [(?x - (if dec (?a / ?b = 0) then 0 else ?y)) mod ?m
+ = if (?b <=? ?a) then (?x - ?y) mod ?m else ?x ] =>
+ assert (a / b = 0 <-> a < b) by
+ (rewrite Z.div_between_0_if by (Z.div_mod_to_quot_rem; nia);
+ break_match; Z.ltb_to_lt; lia)
+ end.
+ break_match; Z.ltb_to_lt; try reflexivity; try lia; [ ].
+ autorewrite with zsimplify_fast. Z.rewrite_mod_small. reflexivity.
Qed.
Lemma montred'_correct lo hi T (HT_range: 0 <= T < R * N)
diff --git a/src/Arithmetic/BarrettReduction/Generalized.v b/src/Arithmetic/BarrettReduction/Generalized.v
index c2885bc77..93aa9452f 100644
--- a/src/Arithmetic/BarrettReduction/Generalized.v
+++ b/src/Arithmetic/BarrettReduction/Generalized.v
@@ -218,6 +218,15 @@ Section barrett.
autorewrite with push_Zmul zsimplify zstrip_div.
auto with lia.
Qed.
+
+ Theorem barrett_reduction_small_strong
+ : a mod n = if r <? n then r else r-n.
+ Proof using a a_good a_nonneg a_small b base_good epsilon k k_big_enough k_good m m_good n n_good n_large n_pos n_reasonable offset offset_nonneg q r.
+ pose proof r_small_strong. pose proof qn_small.
+ destruct (r <? n) eqn:Hr; try rewrite Hr; Z.ltb_to_lt; try lia.
+ { symmetry; apply (Zmod_unique a n q); subst r; lia. }
+ { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. }
+ Qed.
End StrongerBounds.
End barrett_algorithm.
End barrett.
diff --git a/src/Fancy/Barrett256.v b/src/Fancy/Barrett256.v
index 21da40aa7..0474e6e07 100644
--- a/src/Fancy/Barrett256.v
+++ b/src/Fancy/Barrett256.v
@@ -29,27 +29,13 @@ Module Barrett256.
Lemma barrett_red256_correct :
COperationSpecifications.BarrettReduction.barrett_red_correct machine_wordsize M (expr.Interp (@ident.gen_interp cast_oor) barrett_red256).
Proof.
- apply barrett_red_correct with (n:=2%nat) (nout:=2%nat) (machine_wordsize:=machine_wordsize).
+ apply barrett_red_correct with (machine_wordsize:=machine_wordsize).
{ lazy. reflexivity. }
{ apply barrett_red256_eq. }
Qed.
Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize).
- Lemma barrett_reduce_correct_specialized :
- forall (xLow xHigh : Z),
- 0 <= xLow < 2 ^ machine_wordsize ->
- 0 <= xHigh < M ->
- BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M.
- Proof.
- intros.
- apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *;
- try omega;
- try match goal with
- | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega
- end; lazy; try split; congruence.
- Qed.
-
Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) :=
of_Expr 6%positive
(make_consts [(RegMuLow, muLow); (RegMod, M); (RegZero, 0)])
@@ -136,7 +122,10 @@ Module Barrett256.
| _ => econstructor
end. }
{ cbn. cbv [muLow M].
- repeat (econstructor; [ solve [valid_expr_subgoal] | intros ]).
+ repeat (match goal with
+ | _ => eapply valid_LetInZZ
+ | _ => eapply valid_LetInZ
+ end; [ solve [valid_expr_subgoal] | intros ]).
econstructor. valid_expr_subgoal. }
{ reflexivity. }
Qed.
diff --git a/src/PushButtonSynthesis/BarrettReduction.v b/src/PushButtonSynthesis/BarrettReduction.v
index 265958c09..c0078b117 100644
--- a/src/PushButtonSynthesis/BarrettReduction.v
+++ b/src/PushButtonSynthesis/BarrettReduction.v
@@ -37,8 +37,7 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU
Local Opaque reified_barrett_red_gen. (* needed for making [autorewrite] not take a very long time *)
Section rbarrett_red.
- Context (k M : Z) (n nout : nat)
- (machine_wordsize : Z).
+ Context (M machine_wordsize : Z).
Let value_range := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Let flag_range := r[0 ~> 1]%zrange.
@@ -70,6 +69,31 @@ Section rbarrett_red.
Qed.
Local Hint Extern 1 => apply fancy_args_good: typeclass_instances. (* This is a kludge *)
+ Lemma mut_correct :
+ 0 < machine_wordsize ->
+ Partition.partition (UniformWeight.uweight machine_wordsize) (1 + 1) (muLow + 2 ^ machine_wordsize) = [muLow; 1].
+ Proof.
+ intros; cbn. subst muLow.
+ assert (0 < 2^machine_wordsize) by ZeroBounds.Z.zero_bounds.
+ pose proof (Z.mod_pos_bound mu (2^machine_wordsize) ltac:(lia)).
+ rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia.
+ autorewrite with zsimplify.
+ Modulo.push_Zmod. autorewrite with zsimplify. Modulo.pull_Zmod.
+ rewrite <-Modulo.Z.mod_pull_div by lia.
+ autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small.
+ reflexivity.
+ Qed.
+ Lemma Mt_correct :
+ 0 < machine_wordsize ->
+ 2^(machine_wordsize - 1) < M < 2^machine_wordsize ->
+ Partition.partition (UniformWeight.uweight machine_wordsize) 1 M = [M].
+ Proof.
+ intros; cbn. assert (0 < 2^(machine_wordsize-1)) by ZeroBounds.Z.zero_bounds.
+ rewrite !UniformWeight.uweight_S, weight_0; auto using UniformWeight.uwprops with lia.
+ autorewrite with zsimplify. RewriteModSmall.Z.rewrite_mod_small.
+ reflexivity.
+ Qed.
+
(** Note: If you change the name or type signature of this
function, you will need to update the code in CLI.v *)
Definition check_args {T} (res : Pipeline.ErrorT T)
@@ -78,19 +102,11 @@ Section rbarrett_red.
(fun '(b, e) k => if b:bool then Error e else k)
res
[
- ((negb (2 <=? k))%Z, Pipeline.Value_not_ltZ "k < 2" 2 k);
- ((n =? 0)%nat, Pipeline.Values_not_provably_distinctZ "n = 0" (Z.of_nat n) 0);
- ((negb (0 <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 0" 0 M);
- (negb (muLow + 2 ^ k =? 2 ^ (2 * k) / M)%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^k ≠ 2 ^ (2 * k) / M" (muLow + 2^k) (2 ^ (2 * k) / M));
- ((negb (0 <=? muLow))%Z, Pipeline.Value_not_leZ "muLow < 0" 0 muLow);
- ((negb (muLow <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ muLow" muLow (2^k));
- ((negb (2 ^ (k-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(k-1)" (2^(k-1)) M);
- ((negb (M <? 2 ^ k))%Z, Pipeline.Value_not_ltZ "2 ^ k ≤ M" M (2^k));
- (negb ((2 * (2 ^ (2 * k) mod M) <=? 2 ^ (k + 1) - (muLow + 2 ^ k)))%Z, Pipeline.Value_not_leZ ("(2 * (2 ^ (2 * k) mod M) 2 ^ (k + 1) - (muLow + 2 ^ k)") (2 * (2 ^ (2 * k) mod M)) (2 ^ (k + 1) - (muLow + 2 ^ k)));
- (negb (Z.of_nat n <=? k)%Z, Pipeline.Value_not_leZ "k < n" (Z.of_nat n) k);
- (negb (k =? machine_wordsize)%Z, Pipeline.Values_not_provably_equalZ "k ≠ machine_wordsize" k machine_wordsize);
- (negb (n =? 2)%nat, Pipeline.Values_not_provably_equalZ "n ≠ 2" (Z.of_nat n) 2);
- (negb (nout =? 2)%nat, Pipeline.Values_not_provably_equalZ "nout ≠ 2" (Z.of_nat nout) 2)].
+ ((negb (1 <? machine_wordsize))%Z, Pipeline.Value_not_ltZ "machine_wordsize ≤ 1" 1 machine_wordsize);
+ ((negb (2 ^ (machine_wordsize-1) <? M))%Z, Pipeline.Value_not_ltZ "M ≤ 2^(machine_wordsize-1)" (2^(machine_wordsize-1)) M);
+ ((negb (M <? 2 ^ machine_wordsize))%Z, Pipeline.Value_not_ltZ "2 ^ machine_wordsize ≤ M" M (2^machine_wordsize));
+ ((negb (muLow + 2 ^ machine_wordsize =? ((2 ^ 2) ^ machine_wordsize) / M))%Z, Pipeline.Values_not_provably_equalZ "muLow + 2^machine_wordsize ≠ (2 ^ 2) ^ machine_wordsize) / M" (muLow + 2^machine_wordsize) (((2 ^ 2) ^ machine_wordsize) / M));
+ (negb ((2 * (((2 ^ 2) ^ machine_wordsize) mod M) <=? 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)))%Z, Pipeline.Value_not_leZ ("(2 * ((2 ^ 2) ^ machine_wordsize) mod M) 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize)") (2 * (((2 ^ 2) ^ machine_wordsize) mod M)) (2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize))) ].
Local Arguments Z.mul !_ !_.
Local Ltac use_curve_good_t :=
@@ -107,24 +123,17 @@ Section rbarrett_red.
Context (curve_good : check_args (Success tt) = Success tt).
Lemma use_curve_good
- : 2 <= k
- /\ 0 < M
- /\ muLow + 2 ^ k = 2 ^ (2 * k) / M
- /\ 0 <= muLow < 2 ^ k
- /\ 2 ^ (k - 1) < M < 2 ^ k
- /\ 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2 ^ k)
- /\ n <> 0%nat
- /\ Z.of_nat n <= k
- /\ k = machine_wordsize
- /\ n = 2%nat
- /\ nout = 2%nat.
+ : 1 < machine_wordsize
+ /\ 2 ^ (machine_wordsize - 1) <= M < 2 ^ machine_wordsize
+ /\ muLow + 2 ^ machine_wordsize = (2 ^ 2) ^ machine_wordsize / M
+ /\ 2 ^ (machine_wordsize - 1) < M < 2 ^ machine_wordsize
+ /\ 2 * ((2 ^ 2) ^ machine_wordsize mod M) <= 2 ^ (machine_wordsize + 1) - (muLow + 2 ^ machine_wordsize).
Proof using curve_good.
clear -curve_good.
cbv [check_args fold_right] in curve_good.
break_innermost_match_hyps; try discriminate.
rewrite Bool.negb_false_iff in *.
Z.ltb_to_lt.
- rewrite NPeano.Nat.eqb_neq in *.
intros.
repeat apply conj.
{ use_curve_good_t. }
@@ -134,12 +143,6 @@ Section rbarrett_red.
{ use_curve_good_t. }
{ use_curve_good_t. }
{ use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
- { use_curve_good_t. }
Qed.
Definition barrett_red
@@ -148,7 +151,12 @@ Section rbarrett_red.
fancy_args (* fancy *)
possible_values
(reified_barrett_red_gen
- @ GallinaReify.Reify machine_wordsize @ GallinaReify.Reify M @ GallinaReify.Reify muLow @ GallinaReify.Reify 2%nat @ GallinaReify.Reify 2%nat)
+ @ GallinaReify.Reify M
+ @ GallinaReify.Reify machine_wordsize
+ @ GallinaReify.Reify machine_wordsize
+ @ GallinaReify.Reify 1%nat
+ @ GallinaReify.Reify [muLow;1]
+ @ GallinaReify.Reify [M])
(bound, (bound, tt))
bound.
@@ -162,24 +170,27 @@ Section rbarrett_red.
Local Ltac solve_barrett_red_preconditions :=
repeat first [ lia
| assumption
+ | match goal with |- ?x = ?x => reflexivity end
| apply use_curve_good
| progress autorewrite with zsimplify
| progress intros
| progress cbv [weight]
+ | rewrite mut_correct
+ | rewrite Mt_correct
| rewrite Z.pow_mul_r by lia ].
Local Strategy -100 [barrett_red]. (* needed for making Qed not take forever *)
Lemma barrett_red_correct res (Hres : barrett_red = Success res)
- : barrett_red_correct k M (expr.Interp (@ident.gen_interp cast_oor) res).
- Proof using k M curve_good.
+ : barrett_red_correct machine_wordsize M (expr.Interp (@ident.gen_interp cast_oor) res).
+ Proof using M curve_good.
cbv [barrett_red_correct]; intros.
- assert (2 <= k) by apply use_curve_good.
- rewrite <-barrett_reduce_correct with (muLow := muLow) (n:=n) (nout:=nout) by solve_barrett_red_preconditions.
+ assert (1 < machine_wordsize) by apply use_curve_good.
+ rewrite <-Fancy.fancy_reduce_correct with (mu := muLow + 2^machine_wordsize) (width:=machine_wordsize) (sz:=1%nat) (mut:=[muLow;1]) (Mt:=[M]) by solve_barrett_red_preconditions.
prove_correctness' ltac:(fun _ => idtac) use_curve_good.
- { congruence. }
- { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
- subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
- { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower].
- subst k. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbv [ZRange.type.base.option.is_bounded_by ZRange.type.base.is_bounded_by bound is_bounded_by_bool value_range upper lower]. rewrite Bool.andb_true_iff, !Z.leb_le. lia. }
+ { cbn. econstructor. }
+ { cbn. econstructor. }
+ { cbn. econstructor. }
Qed.
End rbarrett_red.
diff --git a/src/PushButtonSynthesis/BarrettReductionReificationCache.v b/src/PushButtonSynthesis/BarrettReductionReificationCache.v
index 4c538087e..265ada2d2 100644
--- a/src/PushButtonSynthesis/BarrettReductionReificationCache.v
+++ b/src/PushButtonSynthesis/BarrettReductionReificationCache.v
@@ -14,15 +14,15 @@ Local Set Keyed Unification. (* needed for making [autorewrite] fast, c.f. COQBU
Module Export BarrettReduction.
(* all the list operations from for_reification.ident *)
Strategy 100 [length seq repeat combine map flat_map partition app rev fold_right update_nth nth_default ].
- Strategy -10 [barrett_reduce reduce].
+ Strategy -10 [Fancy.fancy_reduce reduce].
Derive reified_barrett_red_gen
- SuchThat (is_reification_of reified_barrett_red_gen barrett_reduce)
+ SuchThat (is_reification_of reified_barrett_red_gen Fancy.fancy_reduce)
As reified_barrett_red_gen_correct.
Proof. Time cache_reify (). Time Qed.
Module Export ReifyHints.
- Hint Extern 1 (_ = _) => apply_cached_reification barrett_reduce (proj1 reified_barrett_red_gen_correct) : reify_cache_gen.
+ Hint Extern 1 (_ = _) => apply_cached_reification Fancy.fancy_reduce (proj1 reified_barrett_red_gen_correct) : reify_cache_gen.
Hint Immediate (proj2 reified_barrett_red_gen_correct) : wf_gen_cache.
Hint Rewrite (proj1 reified_barrett_red_gen_correct) : interp_gen_cache.
End ReifyHints.
diff --git a/src/Rewriter.v b/src/Rewriter.v
index 1aefc35e8..01d7614e3 100644
--- a/src/Rewriter.v
+++ b/src/Rewriter.v
@@ -2437,6 +2437,26 @@ Module Compilers.
; (forall c x,
Z.abs c <= Z.abs max_const_val
-> 'c * x = x * 'c)
+
+ (* transform +- to + *)
+ ; (forall s y x,
+ Z.add_get_carry_full s x (- y)
+ = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb))
+ ; (forall s y x,
+ Z.add_get_carry_full s (- y) x
+ = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb))
+ ; (forall s y x,
+ Z.add_with_get_carry_full s 0 x (- y)
+ = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb))
+ ; (forall s y x,
+ Z.add_with_get_carry_full s 0 (- y) x
+ = dlet vb := Z.sub_get_borrow_full s x y in (fst vb, - snd vb))
+ ; (forall s c y x,
+ Z.add_with_get_carry_full s (- c) (- y) x
+ = dlet vb := Z.sub_with_get_borrow_full s c x y in (fst vb, - snd vb))
+ ; (forall s c y x,
+ Z.add_with_get_carry_full s (- c) x (- y)
+ = dlet vb := Z.sub_with_get_borrow_full s c x y in (fst vb, - snd vb))
]
; reify
[ (* [do_again], so that if one of the arguments is concrete, we automatically get the rewrite rule for [Z_cast] applying to it *)