diff options
author | Jade Philipoom <jadep@google.com> | 2018-04-11 17:49:06 +0200 |
---|---|---|
committer | Jade Philipoom <jadep@google.com> | 2018-04-11 17:49:06 +0200 |
commit | 5bcaffd6bcb26d1643484e2551e10cfbd78f2d22 (patch) | |
tree | e1fcad007ff990f90da54729499445058b40224c /src | |
parent | 0bdfa57b5c24a34f6fafe8a97c1ce6453ce2cd83 (diff) |
barrett reduction definition and proof
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 463 |
1 files changed, 463 insertions, 0 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 71fca4645..4f4eb5683 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -7747,6 +7747,469 @@ fun var : type -> Type => End X25519_32. *) + +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. + +Module BarrettReduction. + (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) + Section Generic. + Context {T} (rep : T -> Z -> Prop) + (k : Z) (k_pos : 0 < k) + (low : T -> Z) + (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) + (shiftr : T -> Z -> T) + (shiftr_correct : forall a x n, + rep a x -> + 0 <= n <= k -> + rep (shiftr a n) (x / 2 ^ n)) + (mul_high : T -> T -> Z -> T) + (mul_high_correct : forall a b x y x0y1, + rep a x -> + rep b y -> + 2 ^ k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> + rep (mul_high a b x0y1) (x * y / 2 ^ k)) + (mul : Z -> Z -> T) + (mul_correct : forall x y, + 0 <= x < 2^k -> + 0 <= y < 2^k -> + rep (mul x y) (x * y)) + (sub : T -> T -> T) + (sub_correct : forall a b x y, + rep a x -> + rep b y -> + 0 <= x - y < 2^k * 2^k -> + rep (sub a b) (x - y)) + (cond_sub1 : T -> Z -> Z) + (cond_sub1_correct : forall a x y, + rep a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x <? 2 ^ k) then x else x - y) + (cond_sub2 : Z -> Z -> Z) + (cond_sub2_correct : forall x y, cond_sub2 x y = if (x <? y) then x else x - y). + Context (xt mut : T) (M muSelect: Z). + + Let mu := 2 ^ (2 * k) / M. + Context x (mu_rep : rep mut mu) (x_rep : rep xt x). + Context (M_nz : 0 < M) + (x_range : 0 <= x < M * 2 ^ k) + (M_range : 2 ^ (k - 1) < M < 2 ^ k) + (M_good : 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - mu) + (muSelect_correct: muSelect = mu mod 2 ^ k * (x / 2 ^ (k - 1) / 2 ^ k)). + + Definition qt := + let q1 := shiftr xt (k - 1) in + let twoq := mul_high mut q1 muSelect in + shiftr twoq 1. + Definition reduce := + let r2 := mul (low qt) M in + let r := sub xt r2 in + let q3 := cond_sub1 r M in + cond_sub2 q3 M. + + Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). + Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed. + + Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1). + Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed. + + Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1). + Proof. + pose proof looser_bound. + subst mu. split. + { apply Z.div_le_lower_bound; omega. } + { apply Z.div_lt_upper_bound; try omega. + rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. } + Qed. + + Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). + Proof. + pose proof looser_bound. + split; [ solve [Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; auto with zarith. + rewrite <-pow_2k_eq. omega. + Qed. + Hint Resolve shiftr_x_bounds. + + Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. + + Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). + + Lemma q_correct : rep qt q . + Proof. + pose proof mu_bounds. cbv [qt]; subst q. + rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. + solve_rep. + Qed. + Hint Resolve q_correct. + + Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. + Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed. + Hint Resolve x_mod_small. + + Lemma q_bounds : 0 <= q < 2 ^ k. + Proof. + pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds. + split; subst q; [ solve [Z.zero_bounds] | ]. + edestruct q_nice_strong with (n:=M) as [? Hqnice]; + try rewrite Hqnice; auto; try omega; [ ]. + apply Z.le_lt_trans with (m:= x / M). + { break_match; omega. } + { apply Z.div_lt_upper_bound; omega. } + Qed. + + Lemma two_conditional_subtracts : + forall a x, + rep a x -> + 0 <= x < 2 * M -> + cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. + Proof. + intros. + erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). + break_match; Z.ltb_to_lt; try lia; discriminate. + Qed. + + Lemma r_bounds : 0 <= x - q * M < 2 * M. + Proof. + pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. + subst q mu; split. + { Z.zero_bounds. apply qn_small; omega. } + { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } + Qed. + + Lemma reduce_correct : reduce = x mod M. + Proof. + pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. + assert (2 * M < 2^k * 2^k) by nia. + rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). + cbv [reduce]. + erewrite low_correct by eauto. Z.rewrite_mod_small. + erewrite two_conditional_subtracts by solve_rep. + rewrite !cond_sub2_correct. + subst q; reflexivity. + Qed. + End Generic. + + Section BarrettReduction. + Context (k : Z) (Hk_positive : 0 < k). + Context (M muLow : Z). + Context (M_pos : 0 < M) + (muLow_eq : muLow + 2^k = 2^(2*k) / M) + (muLow_bounds : 0 <= muLow < 2^k) + (M_bound1 : 2 ^ (k - 1) < M < 2^k) + (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). + Context (pow2_k_bound : 4 <= 2^k). + + Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). + Context (nout : nat) (Hnout : nout = 2%nat). + Let w := weight k 1. + Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. + Let props : @weight_properties w := wprops k 1 k_range. + + Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. + + Let T : Type := Z * Z. + + Definition represents (t : T) (x : Z) := + fst t = x mod 2^k /\ snd t = x / 2^k /\ 0 <= x < 2^k * 2^k. + + Lemma represents_fst t x : + represents t x -> fst t = x mod 2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_snd t x : + represents t x -> snd t = x / 2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_fst_range t x : + represents t x -> 0 <= x mod 2^k < 2^k. + Proof. auto with zarith. Qed. + Hint Resolve represents_fst_range. + Lemma represents_snd_range t x : + represents t x -> 0 <= x / 2^k < 2^k. + Proof. + destruct 1 as [? [? ?] ]; intros. + auto using Z.div_lt_upper_bound with zarith. + Qed. + Hint Resolve represents_snd_range. + + Lemma represents_range t x : + represents t x -> 0 <= x < 2^k*2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_id x : + 0 <= x < 2^k * 2^k -> + represents (x mod 2^k, x / 2^k) x. + Proof. + intros; cbv [represents]; autorewrite with cancel_pair. + Z.rewrite_mod_small; tauto. + Qed. + + Local Ltac push_rep := + repeat match goal with + | H : represents ?t ?x |- _ => unique pose proof (represents_fst_range _ _ H) + | H : represents ?t ?x |- _ => unique pose proof (represents_snd_range _ _ H) + | H : represents ?t ?x |- _ => rewrite (represents_fst t x) in * by assumption + | H : represents ?t ?x |- _ => rewrite (represents_snd t x) in * by assumption + end. + + Definition shiftr (t : T) (n : Z) : T := + (Z.rshi (2^k) (snd t) (fst t) n, Z.rshi (2^k) 0 (snd t) n). + + Lemma shiftr_represents a i x : + represents a x -> + 0 <= i <= k -> + represents (shiftr a i) (x / 2 ^ i). + Proof. + cbv [shiftr]; intros; push_rep. + match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. + assert (0 < 2 ^ i) by auto with zarith. + assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. + assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by + (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). + repeat match goal with + | _ => rewrite Z.rshi_correct by auto with zarith + | _ => rewrite <-Z.div_mod''' by auto with zarith + | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.rewrite_mod_small + | |- context [represents ((?a / ?c) mod ?b, ?a / ?b / ?c)] => + rewrite (Z.div_div_comm a b c) by auto with zarith + | _ => solve [auto using represents_id, Z.div_lt_upper_bound with zarith lia] + end. + Qed. + + Context (Hw : forall i, w i = (2 ^ k) ^ Z.of_nat i). + Ltac change_weight := rewrite !Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r. + + Definition wideadd t1 t2 := + let sum := fst (Rows.add w 2 [fst t1; snd t1] [fst t2; snd t2]) in + (nth_default 0 sum 0, nth_default 0 sum 1). + + Definition widesub t1 t2 := + let sum := fst (Rows.sub w 2 [fst t1; snd t1] [fst t2; snd t2]) in + (nth_default 0 sum 0, nth_default 0 sum 1). + + Definition widemul x y := + let xy := BaseConversion.widemul k n nout x y in + (nth_default 0 xy 0, nth_default 0 xy 1). + + Lemma partition_represents x y : + 0 <= y < 2^k*2^k -> + x = Rows.partition w 2 y -> + represents (nth_default 0 x 0, nth_default 0 x 1) y. + Proof. + intros; subst x; cbv [represents Rows.partition]. + cbn; change_weight. Z.rewrite_mod_small. + auto with zarith. + Qed. + + Lemma eval_represents t x : + represents t x -> + eval w 2 [fst t; snd t] = x. + Proof. + intros; cbn. change_weight; push_rep. + autorewrite with zsimplify. reflexivity. + Qed. + + Ltac wide_op partitions_pf := + repeat match goal with + | _ => apply partition_represents; auto with zarith; [ ] + | _ => rewrite partitions_pf by auto + | _ => erewrite eval_represents by eauto + | _ => reflexivity + end. + + Lemma wideadd_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x + y < 2^k*2^k -> + represents (wideadd t1 t2) (x + y). + Proof. intros; cbv [wideadd]. wide_op Rows.add_partitions. Qed. + + Lemma widesub_represents t1 t2 x y : + represents t1 x -> + represents t2 y -> + 0 <= x - y < 2^k*2^k -> + represents (widesub t1 t2) (x - y). + Proof. intros; cbv [widesub]. wide_op Rows.sub_partitions. Qed. + + Lemma widemul_represents x y : + 0 <= x < 2^k -> + 0 <= y < 2^k -> + represents (widemul x y) (x * y). + Proof. + intros; cbv [widemul]. + rewrite BaseConversion.widemul_correct by auto with zarith. + autorewrite with push_nth_default. + auto using represents_id with zarith. + Qed. + + Definition mul_high (a b : T) a0b1 : T := + let a0b0 := widemul (fst a) (fst b) in + let ab := wideadd (snd a0b0, snd b) (fst b, 0) in + wideadd ab (a0b1, 0). + + Lemma mul_high_idea s a b a0 a1 b0 b1 : + s <> 0 -> + a = s * a1 + a0 -> + b = s * b1 + b0 -> + (a * b) / s = a0 * b0 / s + s * a1 * b1 + a1 * b0 + a0 * b1. + Proof. + intros. subst a b. autorewrite with push_Zmul. + ring_simplify_subterms. rewrite Z.pow_2_r. + rewrite Z.div_add_exact by (push_Zmod; autorewrite with zsimplify; omega). + repeat match goal with + | |- context [s * ?a * ?b * ?c] => + replace (s * a * b * c) with (a * b * c * s) by ring + | |- context [s * ?a * ?b] => + replace (s * a * b) with (a * b * s) by ring + end. + rewrite !Z.div_add by omega. + autorewrite with zsimplify. + rewrite (Z.mul_comm a0 b0). + ring_simplify. ring. + Qed. + + Lemma represents_trans t x y: + represents t y -> y = x -> + represents t x. + Proof. congruence. Qed. + + Lemma represents_add a b x y : + a = x -> b = y -> + 0 <= x < 2 ^ k -> + 0 <= y < 2 ^ k -> + represents (a,b) (x + 2^k*y). + Proof. intros; subst a b; repeat split; autorewrite with cancel_pair zsimplify; nia. Qed. + + Lemma represents_small x : + 0 <= x < 2^k -> + represents (x, 0) x. + Proof. + intros. + eapply represents_trans. + { eauto using represents_add with zarith. } + { ring. } + Qed. + + Lemma mul_high_represents a b x y a0b1 : + represents a x -> + represents b y -> + 2^k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + a0b1 = x mod 2^k * (y / 2^k) -> + represents (mul_high a b a0b1) ((x * y) / 2^k). + Proof. + cbv [mul_high]; rewrite Z.pow_add_r, Z.pow_1_r by omega; intros. + assert (0 <= x * y / 2^k < 2^k*2^k) by (Z.div_mod_to_quot_rem; nia). + + rewrite mul_high_idea with (a:=x) (b:=y) (a0 := fst a) (a1 := snd a) (b0 := fst b) (b1 := snd b) in * + by (push_rep; Z.div_mod_to_quot_rem; lia). + + push_rep. subst a0b1. + assert (y / 2 ^ k < 2) by (apply Z.div_lt_upper_bound; omega). + replace (x / 2 ^ k) with 1 in * by (rewrite Z.div_between_1; lia). + autorewrite with zsimplify_fast in *. + + eapply represents_trans. + { repeat (apply wideadd_represents; + [ | apply represents_small; Z.div_mod_to_quot_rem; nia| ]). + erewrite represents_snd; [ | apply widemul_represents; solve [ auto with zarith ] ]. + { apply represents_add; try reflexivity; solve [auto with zarith]. } + { match goal with H : 0 <= ?x + ?y < ?z |- 0 <= ?x < ?z => + split; [ solve [Z.zero_bounds] | ]; + eapply Z.le_lt_trans with (m:= x + y); nia + end. } + { omega. } } + { ring. } + Qed. + + Definition cond_sub1 (a : T) y : Z := + let maybe_y := Z.zselect (Z.cc_l (snd a)) 0 y in + fst (Z.sub_get_borrow_full (2^k) (fst a) maybe_y). + + Lemma cc_l_only_bit x s: 0 <= x < 2 * s -> Z.cc_l (x / s) = 0 <-> x < s. + Proof. + cbv [Z.cc_l]; intros. + rewrite Z.land_ones, Z.pow_1_r by omega. + rewrite Z.div_between_0_if by omega. + break_match; Z.ltb_to_lt; Z.rewrite_mod_small; omega. + Qed. + + Lemma cond_sub1_correct a x y : + represents a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x <? 2 ^ k) then x else x - y. + Proof. + intros; cbv [cond_sub1]. rewrite Z.zselect_correct. push_rep. + break_match; Z.ltb_to_lt; rewrite cc_l_only_bit in *; try omega; + autorewrite with zsimplify_fast to_div_mod pull_Zmod; auto with zarith. + Qed. + + Definition cond_sub2 x y := Z.add_modulo x 0 y. + Lemma cond_sub2_correct x y : + cond_sub2 x y = if (x <? y) then x else x - y. + Proof. + cbv [cond_sub2]. rewrite Z.add_modulo_correct. + autorewrite with zsimplify_fast. break_match; Z.ltb_to_lt; omega. + Qed. + + Section Defn. + Context (xLow xHigh : Z) (xLow_bounds : 0 <= xLow < 2^k) (xHigh_bounds : 0 <= xHigh < M). + Let xt := (xLow, xHigh). + Let x := xLow + 2^k * xHigh. + + Lemma x_rep : represents xt x. + Proof. cbv [represents]; subst xt x; autorewrite with cancel_pair zsimplify; repeat split; nia. Qed. + + Lemma x_bounds : 0 <= x < M * 2 ^ k. + Proof. subst x; nia. Qed. + + Definition muSelect := Z.zselect (Z.cc_m (2 ^ k) xHigh) 0 muLow. + + Local Hint Resolve Z.div_nonneg Z.div_lt_upper_bound. + + Lemma muSelect_correct : + muSelect = (2 ^ (2 * k) / M) mod 2 ^ k * ((x / 2 ^ (k - 1)) / 2 ^ k). + Proof. + (* assertions to help arith tactics *) + pose proof x_bounds. + assert (2^k * M < 2 ^ (2*k)) by (rewrite <-Z.add_diag, Z.pow_add_r; nia). + assert (0 <= x / (2 ^ k * (2 ^ k / 2)) < 2) by (Z.div_mod_to_quot_rem; auto with nia). + assert (0 < 2 ^ k / 2) by Z.zero_bounds. + assert (2 ^ (k - 1) <> 0) by auto with zarith. + + cbv [muSelect]. rewrite <-muLow_eq. + rewrite Z.zselect_correct, Z.cc_m_eq by nia. + replace xHigh with (x / 2^k) by (subst x; autorewrite with zsimplify; lia). + autorewrite with pull_Zdiv push_Zpow. + rewrite (Z.mul_comm (2 ^ k / 2)). + break_match; [ ring | ]. + match goal with H : 0 <= ?x < 2, H' : ?x <> 0 |- _ => replace x with 1 by omega end. + autorewrite with zsimplify; reflexivity. + Qed. + + Definition barrett_reduce : Z := + reduce k fst shiftr mul_high widemul widesub cond_sub1 cond_sub2 xt (muLow, 1) M muSelect. + + Lemma barrett_reduce_correct : + barrett_reduce = x mod M. + Proof. + intros; cbv [barrett_reduce]. + apply reduce_correct with (rep:=represents); try omega; + auto using shiftr_represents, mul_high_represents, widemul_represents, widesub_represents, + cond_sub1_correct, cond_sub2_correct, x_bounds, muSelect_correct, represents_fst, x_rep. + rewrite <-muLow_eq. cbv [represents]; repeat split; autorewrite with cancel_pair zsimplify; nia. + Qed. + End Defn. + End BarrettReduction. +End BarrettReduction. + Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. Require Import Crypto.Util.ZUtil.EquivModulo. |