aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-04-11 17:49:06 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-04-11 17:49:06 +0200
commit5bcaffd6bcb26d1643484e2551e10cfbd78f2d22 (patch)
treee1fcad007ff990f90da54729499445058b40224c /src
parent0bdfa57b5c24a34f6fafe8a97c1ce6453ce2cd83 (diff)
barrett reduction definition and proof
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v463
1 files changed, 463 insertions, 0 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 71fca4645..4f4eb5683 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -7747,6 +7747,469 @@ fun var : type -> Type =>
End X25519_32.
*)
+
+Require Import Crypto.Arithmetic.BarrettReduction.Generalized.
+Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo.
+Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi.
+
+Module BarrettReduction.
+ (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *)
+ Section Generic.
+ Context {T} (rep : T -> Z -> Prop)
+ (k : Z) (k_pos : 0 < k)
+ (low : T -> Z)
+ (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k)
+ (shiftr : T -> Z -> T)
+ (shiftr_correct : forall a x n,
+ rep a x ->
+ 0 <= n <= k ->
+ rep (shiftr a n) (x / 2 ^ n))
+ (mul_high : T -> T -> Z -> T)
+ (mul_high_correct : forall a b x y x0y1,
+ rep a x ->
+ rep b y ->
+ 2 ^ k <= x < 2^(k+1) ->
+ 0 <= y < 2^(k+1) ->
+ x0y1 = x mod 2 ^ k * (y / 2 ^ k) ->
+ rep (mul_high a b x0y1) (x * y / 2 ^ k))
+ (mul : Z -> Z -> T)
+ (mul_correct : forall x y,
+ 0 <= x < 2^k ->
+ 0 <= y < 2^k ->
+ rep (mul x y) (x * y))
+ (sub : T -> T -> T)
+ (sub_correct : forall a b x y,
+ rep a x ->
+ rep b y ->
+ 0 <= x - y < 2^k * 2^k ->
+ rep (sub a b) (x - y))
+ (cond_sub1 : T -> Z -> Z)
+ (cond_sub1_correct : forall a x y,
+ rep a x ->
+ 0 <= x < 2 * y ->
+ 0 <= y < 2 ^ k ->
+ cond_sub1 a y = if (x <? 2 ^ k) then x else x - y)
+ (cond_sub2 : Z -> Z -> Z)
+ (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y).
+ Context (xt mut : T) (M muSelect: Z).
+
+ Let mu := 2 ^ (2 * k) / M.
+ Context x (mu_rep : rep mut mu) (x_rep : rep xt x).
+ Context (M_nz : 0 < M)
+ (x_range : 0 <= x < M * 2 ^ k)
+ (M_range : 2 ^ (k - 1) < M < 2 ^ k)
+ (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu)
+ (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)).
+
+ Definition qt :=
+ let q1 := shiftr xt (k - 1) in
+ let twoq := mul_high mut q1 muSelect in
+ shiftr twoq 1.
+ Definition reduce :=
+ let r2 := mul (low qt) M in
+ let r := sub xt r2 in
+ let q3 := cond_sub1 r M in
+ cond_sub2 q3 M.
+
+ Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k).
+ Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed.
+
+ Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1).
+ Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed.
+
+ Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1).
+ Proof.
+ pose proof looser_bound.
+ subst mu. split.
+ { apply Z.div_le_lower_bound; omega. }
+ { apply Z.div_lt_upper_bound; try omega.
+ rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. }
+ Qed.
+
+ Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1).
+ Proof.
+ pose proof looser_bound.
+ split; [ solve [Z.zero_bounds] | ].
+ apply Z.div_lt_upper_bound; auto with zarith.
+ rewrite <-pow_2k_eq. omega.
+ Qed.
+ Hint Resolve shiftr_x_bounds.
+
+ Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega.
+
+ Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1).
+
+ Lemma q_correct : rep qt q .
+ Proof.
+ pose proof mu_bounds. cbv [qt]; subst q.
+ rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds.
+ solve_rep.
+ Qed.
+ Hint Resolve q_correct.
+
+ Lemma x_mod_small : x mod 2 ^ (k - 1) <= M.
+ Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed.
+ Hint Resolve x_mod_small.
+
+ Lemma q_bounds : 0 <= q < 2 ^ k.
+ Proof.
+ pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds.
+ split; subst q; [ solve [Z.zero_bounds] | ].
+ edestruct q_nice_strong with (n:=M) as [? Hqnice];
+ try rewrite Hqnice; auto; try omega; [ ].
+ apply Z.le_lt_trans with (m:= x / M).
+ { break_match; omega. }
+ { apply Z.div_lt_upper_bound; omega. }
+ Qed.
+
+ Lemma two_conditional_subtracts :
+ forall a x,
+ rep a x ->
+ 0 <= x < 2 * M ->
+ cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M.
+ Proof.
+ intros.
+ erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega).
+ break_match; Z.ltb_to_lt; try lia; discriminate.
+ Qed.
+
+ Lemma r_bounds : 0 <= x - q * M < 2 * M.
+ Proof.
+ pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small.
+ subst q mu; split.
+ { Z.zero_bounds. apply qn_small; omega. }
+ { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. }
+ Qed.
+
+ Lemma reduce_correct : reduce = x mod M.
+ Proof.
+ pose proof looser_bound. pose proof r_bounds. pose proof q_bounds.
+ assert (2 * M < 2^k * 2^k) by nia.
+ rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega).
+ cbv [reduce].
+ erewrite low_correct by eauto. Z.rewrite_mod_small.
+ erewrite two_conditional_subtracts by solve_rep.
+ rewrite !cond_sub2_correct.
+ subst q; reflexivity.
+ Qed.
+ End Generic.
+
+ Section BarrettReduction.
+ Context (k : Z) (Hk_positive : 0 < k).
+ Context (M muLow : Z).
+ Context (M_pos : 0 < M)
+ (muLow_eq : muLow + 2^k = 2^(2*k) / M)
+ (muLow_bounds : 0 <= muLow < 2^k)
+ (M_bound1 : 2 ^ (k - 1) < M < 2^k)
+ (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)).
+ Context (pow2_k_bound : 4 <= 2^k).
+
+ Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k).
+ Context (nout : nat) (Hnout : nout = 2%nat).
+ Let w := weight k 1.
+ Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed.
+ Let props : @weight_properties w := wprops k 1 k_range.
+
+ Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval.
+
+ Let T : Type := Z * Z.
+
+ Definition represents (t : T) (x : Z) :=
+ fst t = x mod 2^k /\ snd t = x / 2^k /\ 0 <= x < 2^k * 2^k.
+
+ Lemma represents_fst t x :
+ represents t x -> fst t = x mod 2^k.
+ Proof. cbv [represents]; tauto. Qed.
+
+ Lemma represents_snd t x :
+ represents t x -> snd t = x / 2^k.
+ Proof. cbv [represents]; tauto. Qed.
+
+ Lemma represents_fst_range t x :
+ represents t x -> 0 <= x mod 2^k < 2^k.
+ Proof. auto with zarith. Qed.
+ Hint Resolve represents_fst_range.
+ Lemma represents_snd_range t x :
+ represents t x -> 0 <= x / 2^k < 2^k.
+ Proof.
+ destruct 1 as [? [? ?] ]; intros.
+ auto using Z.div_lt_upper_bound with zarith.
+ Qed.
+ Hint Resolve represents_snd_range.
+
+ Lemma represents_range t x :
+ represents t x -> 0 <= x < 2^k*2^k.
+ Proof. cbv [represents]; tauto. Qed.
+
+ Lemma represents_id x :
+ 0 <= x < 2^k * 2^k ->
+ represents (x mod 2^k, x / 2^k) x.
+ Proof.
+ intros; cbv [represents]; autorewrite with cancel_pair.
+ Z.rewrite_mod_small; tauto.
+ Qed.
+
+ Local Ltac push_rep :=
+ repeat match goal with
+ | H : represents ?t ?x |- _ => unique pose proof (represents_fst_range _ _ H)
+ | H : represents ?t ?x |- _ => unique pose proof (represents_snd_range _ _ H)
+ | H : represents ?t ?x |- _ => rewrite (represents_fst t x) in * by assumption
+ | H : represents ?t ?x |- _ => rewrite (represents_snd t x) in * by assumption
+ end.
+
+ Definition shiftr (t : T) (n : Z) : T :=
+ (Z.rshi (2^k) (snd t) (fst t) n, Z.rshi (2^k) 0 (snd t) n).
+
+ Lemma shiftr_represents a i x :
+ represents a x ->
+ 0 <= i <= k ->
+ represents (shiftr a i) (x / 2 ^ i).
+ Proof.
+ cbv [shiftr]; intros; push_rep.
+ match goal with H : _ |- _ => pose proof (represents_range _ _ H) end.
+ assert (0 < 2 ^ i) by auto with zarith.
+ assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia.
+ assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by
+ (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith).
+ repeat match goal with
+ | _ => rewrite Z.rshi_correct by auto with zarith
+ | _ => rewrite <-Z.div_mod''' by auto with zarith
+ | _ => progress autorewrite with zsimplify_fast
+ | _ => progress Z.rewrite_mod_small
+ | |- context [represents ((?a / ?c) mod ?b, ?a / ?b / ?c)] =>
+ rewrite (Z.div_div_comm a b c) by auto with zarith
+ | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia]
+ end.
+ Qed.
+
+ Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i).
+ Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r.
+
+ Definition wideadd t1 t2 :=
+ let sum := fst (Rows.add w 2 [fst t1; snd t1] [fst t2; snd t2]) in
+ (nth_default 0 sum 0, nth_default 0 sum 1).
+
+ Definition widesub t1 t2 :=
+ let sum := fst (Rows.sub w 2 [fst t1; snd t1] [fst t2; snd t2]) in
+ (nth_default 0 sum 0, nth_default 0 sum 1).
+
+ Definition widemul x y :=
+ let xy := BaseConversion.widemul k n nout x y in
+ (nth_default 0 xy 0, nth_default 0 xy 1).
+
+ Lemma partition_represents x y :
+ 0 <= y < 2^k*2^k ->
+ x = Rows.partition w 2 y ->
+ represents (nth_default 0 x 0, nth_default 0 x 1) y.
+ Proof.
+ intros; subst x; cbv [represents Rows.partition].
+ cbn; change_weight. Z.rewrite_mod_small.
+ auto with zarith.
+ Qed.
+
+ Lemma eval_represents t x :
+ represents t x ->
+ eval w 2 [fst t; snd t] = x.
+ Proof.
+ intros; cbn. change_weight; push_rep.
+ autorewrite with zsimplify. reflexivity.
+ Qed.
+
+ Ltac wide_op partitions_pf :=
+ repeat match goal with
+ | _ => apply partition_represents; auto with zarith; [ ]
+ | _ => rewrite partitions_pf by auto
+ | _ => erewrite eval_represents by eauto
+ | _ => reflexivity
+ end.
+
+ Lemma wideadd_represents t1 t2 x y :
+ represents t1 x ->
+ represents t2 y ->
+ 0 <= x + y < 2^k*2^k ->
+ represents (wideadd t1 t2) (x + y).
+ Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed.
+
+ Lemma widesub_represents t1 t2 x y :
+ represents t1 x ->
+ represents t2 y ->
+ 0 <= x - y < 2^k*2^k ->
+ represents (widesub t1 t2) (x - y).
+ Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed.
+
+ Lemma widemul_represents x y :
+ 0 <= x < 2^k ->
+ 0 <= y < 2^k ->
+ represents (widemul x y) (x * y).
+ Proof.
+ intros; cbv [widemul].
+ rewrite BaseConversion.widemul_correct by auto with zarith.
+ autorewrite with push_nth_default.
+ auto using represents_id with zarith.
+ Qed.
+
+ Definition mul_high (a b : T) a0b1 : T :=
+ let a0b0 := widemul (fst a) (fst b) in
+ let ab := wideadd (snd a0b0, snd b) (fst b, 0) in
+ wideadd ab (a0b1, 0).
+
+ Lemma mul_high_idea s a b a0 a1 b0 b1 :
+ s <> 0 ->
+ a = s * a1 + a0 ->
+ b = s * b1 + b0 ->
+ (a * b) / s = a0 * b0 / s + s * a1 * b1 + a1 * b0 + a0 * b1.
+ Proof.
+ intros. subst a b. autorewrite with push_Zmul.
+ ring_simplify_subterms. rewrite Z.pow_2_r.
+ rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega).
+ repeat match goal with
+ | |- context [s * ?a * ?b * ?c] =>
+ replace (s * a * b * c) with (a * b * c * s) by ring
+ | |- context [s * ?a * ?b] =>
+ replace (s * a * b) with (a * b * s) by ring
+ end.
+ rewrite !Z.div_add by omega.
+ autorewrite with zsimplify.
+ rewrite (Z.mul_comm a0 b0).
+ ring_simplify. ring.
+ Qed.
+
+ Lemma represents_trans t x y:
+ represents t y -> y = x ->
+ represents t x.
+ Proof. congruence. Qed.
+
+ Lemma represents_add a b x y :
+ a = x -> b = y ->
+ 0 <= x < 2 ^ k ->
+ 0 <= y < 2 ^ k ->
+ represents (a,b) (x + 2^k*y).
+ Proof. intros; subst a b; repeat split; autorewrite with cancel_pair zsimplify; nia. Qed.
+
+ Lemma represents_small x :
+ 0 <= x < 2^k ->
+ represents (x, 0) x.
+ Proof.
+ intros.
+ eapply represents_trans.
+ { eauto using represents_add with zarith. }
+ { ring. }
+ Qed.
+
+ Lemma mul_high_represents a b x y a0b1 :
+ represents a x ->
+ represents b y ->
+ 2^k <= x < 2^(k+1) ->
+ 0 <= y < 2^(k+1) ->
+ a0b1 = x mod 2^k * (y / 2^k) ->
+ represents (mul_high a b a0b1) ((x * y) / 2^k).
+ Proof.
+ cbv [mul_high]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros.
+ assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem; nia).
+
+ rewrite mul_high_idea with (a:=x) (b:=y) (a0 := fst a) (a1 := snd a) (b0 := fst b) (b1 := snd b) in *
+ by (push_rep; Z.div_mod_to_quot_rem; lia).
+
+ push_rep. subst a0b1.
+ assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega).
+ replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia).
+ autorewrite with zsimplify_fast in *.
+
+ eapply represents_trans.
+ { repeat (apply wideadd_represents;
+ [ | apply represents_small; Z.div_mod_to_quot_rem; nia| ]).
+ erewrite represents_snd; [ | apply widemul_represents; solve [ auto with zarith ] ].
+ { apply represents_add; try reflexivity; solve [auto with zarith]. }
+ { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z =>
+ split; [ solve [Z.zero_bounds] | ];
+ eapply Z.le_lt_trans with (m:= x + y); nia
+ end. }
+ { omega. } }
+ { ring. }
+ Qed.
+
+ Definition cond_sub1 (a : T) y : Z :=
+ let maybe_y := Z.zselect (Z.cc_l (snd a)) 0 y in
+ fst (Z.sub_get_borrow_full (2^k) (fst a) maybe_y).
+
+ Lemma cc_l_only_bit x s: 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s.
+ Proof.
+ cbv [Z.cc_l]; intros.
+ rewrite Z.land_ones, Z.pow_1_r by omega.
+ rewrite Z.div_between_0_if by omega.
+ break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega.
+ Qed.
+
+ Lemma cond_sub1_correct a x y :
+ represents a x ->
+ 0 <= x < 2 * y ->
+ 0 <= y < 2 ^ k ->
+ cond_sub1 a y = if (x <? 2 ^ k) then x else x - y.
+ Proof.
+ intros; cbv [cond_sub1]. rewrite Z.zselect_correct. push_rep.
+ break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega;
+ autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith.
+ Qed.
+
+ Definition cond_sub2 x y := Z.add_modulo x 0 y.
+ Lemma cond_sub2_correct x y :
+ cond_sub2 x y = if (x <? y) then x else x - y.
+ Proof.
+ cbv [cond_sub2]. rewrite Z.add_modulo_correct.
+ autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega.
+ Qed.
+
+ Section Defn.
+ Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M).
+ Let xt := (xLow, xHigh).
+ Let x := xLow + 2^k * xHigh.
+
+ Lemma x_rep : represents xt x.
+ Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed.
+
+ Lemma x_bounds : 0 <= x < M * 2 ^ k.
+ Proof. subst x; nia. Qed.
+
+ Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow.
+
+ Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound.
+
+ Lemma muSelect_correct :
+ muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k).
+ Proof.
+ (* assertions to help arith tactics *)
+ pose proof x_bounds.
+ assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia).
+ assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem; auto with nia).
+ assert (0 < 2 ^ k / 2) by Z.zero_bounds.
+ assert (2 ^ (k - 1) <> 0) by auto with zarith.
+
+ cbv [muSelect]. rewrite <-muLow_eq.
+ rewrite Z.zselect_correct, Z.cc_m_eq by nia.
+ replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia).
+ autorewrite with pull_Zdiv push_Zpow.
+ rewrite (Z.mul_comm (2 ^ k / 2)).
+ break_match; [ ring | ].
+ match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+
+ Definition barrett_reduce : Z :=
+ reduce k fst shiftr mul_high widemul widesub cond_sub1 cond_sub2 xt (muLow, 1) M muSelect.
+
+ Lemma barrett_reduce_correct :
+ barrett_reduce = x mod M.
+ Proof.
+ intros; cbv [barrett_reduce].
+ apply reduce_correct with (rep:=represents); try omega;
+ auto using shiftr_represents, mul_high_represents, widemul_represents, widesub_represents,
+ cond_sub1_correct, cond_sub2_correct, x_bounds, muSelect_correct, represents_fst, x_rep.
+ rewrite <-muLow_eq. cbv [represents]; repeat split; autorewrite with cancel_pair zsimplify; nia.
+ Qed.
+ End Defn.
+ End BarrettReduction.
+End BarrettReduction.
+
Require Import Crypto.Arithmetic.MontgomeryReduction.Definition.
Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs.
Require Import Crypto.Util.ZUtil.EquivModulo.