diff options
Diffstat (limited to 'src/Util/ZUtil/Shift.v')
-rw-r--r-- | src/Util/ZUtil/Shift.v | 393 |
1 files changed, 393 insertions, 0 deletions
diff --git a/src/Util/ZUtil/Shift.v b/src/Util/ZUtil/Shift.v new file mode 100644 index 000000000..b5fb79c13 --- /dev/null +++ b/src/Util/ZUtil/Shift.v @@ -0,0 +1,393 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.micromega.Lia. +Require Import Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Ones. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Testbit. +Require Import Crypto.Util.ZUtil.Pow2Mod. +Require Import Crypto.Util.ZUtil.Le. +Require Import Crypto.Util.ZUtil.Div. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.ZUtil.Notations. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Local Open Scope Z_scope. + +Module Z. + Lemma shiftr_add_shiftl_high : forall n m a b, 0 <= n <= m -> 0 <= a < 2 ^ n -> + Z.shiftr (a + (Z.shiftl b n)) m = Z.shiftr b (m - n). + Proof. + intros n m a b H H0. + rewrite !Z.shiftr_div_pow2, Z.shiftl_mul_pow2 by omega. + replace (2 ^ m) with (2 ^ n * 2 ^ (m - n)) by + (rewrite <-Z.pow_add_r by omega; f_equal; ring). + rewrite <-Z.div_div, Z.div_add, (Z.div_small a) ; try solve + [assumption || apply Z.pow_nonzero || apply Z.pow_pos_nonneg; omega]. + f_equal; ring. + Qed. + Hint Rewrite Z.shiftr_add_shiftl_high using zutil_arith : pull_Zshift. + Hint Rewrite <- Z.shiftr_add_shiftl_high using zutil_arith : push_Zshift. + + Lemma shiftr_add_shiftl_low : forall n m a b, 0 <= m <= n -> 0 <= a < 2 ^ n -> + Z.shiftr (a + (Z.shiftl b n)) m = Z.shiftr a m + Z.shiftr b (m - n). + Proof. + intros n m a b H H0. + rewrite !Z.shiftr_div_pow2, Z.shiftl_mul_pow2, Z.shiftr_mul_pow2 by omega. + replace (2 ^ n) with (2 ^ (n - m) * 2 ^ m) by + (rewrite <-Z.pow_add_r by omega; f_equal; ring). + rewrite Z.mul_assoc, Z.div_add by (apply Z.pow_nonzero; omega). + repeat f_equal; ring. + Qed. + Hint Rewrite Z.shiftr_add_shiftl_low using zutil_arith : pull_Zshift. + Hint Rewrite <- Z.shiftr_add_shiftl_low using zutil_arith : push_Zshift. + + Lemma testbit_add_shiftl_high : forall i, (0 <= i) -> forall a b n, (0 <= n <= i) -> + 0 <= a < 2 ^ n -> + Z.testbit (a + Z.shiftl b n) i = Z.testbit b (i - n). + Proof. + intros i ?. + apply natlike_ind with (x := i); [ intros a b n | intros x H0 H1 a b n | ]; intros; try assumption; + (destruct (Z.eq_dec 0 n); [ subst; rewrite Z.pow_0_r in *; + replace a with 0 by omega; f_equal; ring | ]); try omega. + rewrite <-Z.add_1_r at 1. rewrite <-Z.shiftr_spec by assumption. + replace (Z.succ x - n) with (x - (n - 1)) by ring. + rewrite shiftr_add_shiftl_low, <-Z.shiftl_opp_r with (a := b) by omega. + rewrite <-H1 with (a := Z.shiftr a 1); try omega; [ repeat f_equal; ring | ]. + rewrite Z.shiftr_div_pow2 by omega. + split; apply Z.div_pos || apply Z.div_lt_upper_bound; + try solve [rewrite ?Z.pow_1_r; omega]. + rewrite <-Z.pow_add_r by omega. + replace (1 + (n - 1)) with n by ring; omega. + Qed. + Hint Rewrite testbit_add_shiftl_high using zutil_arith : Ztestbit. + + Lemma shiftr_succ : forall n x, + Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1. + Proof. + intros. + rewrite Z.shiftr_shiftr by omega. + reflexivity. + Qed. + Hint Rewrite Z.shiftr_succ using zutil_arith : push_Zshift. + Hint Rewrite <- Z.shiftr_succ using zutil_arith : pull_Zshift. + + Lemma shiftr_1_r_le : forall a b, a <= b -> + Z.shiftr a 1 <= Z.shiftr b 1. + Proof. + intros. + rewrite !Z.shiftr_div_pow2, Z.pow_1_r by omega. + apply Z.div_le_mono; omega. + Qed. + Hint Resolve shiftr_1_r_le : zarith. + + Lemma shiftr_le : forall a b i : Z, 0 <= i -> a <= b -> a >> i <= b >> i. + Proof. + intros a b i ?; revert a b. apply natlike_ind with (x := i); intros; auto. + rewrite !shiftr_succ, shiftr_1_r_le; eauto. reflexivity. + Qed. + Hint Resolve shiftr_le : zarith. + + Lemma shiftr_ones' : forall a n, 0 <= a < 2 ^ n -> forall i, (0 <= i) -> + Z.shiftr a i <= Z.ones (n - i) \/ n <= i. + Proof. + intros a n H. + apply natlike_ind. + + unfold Z.ones. + rewrite Z.shiftr_0_r, Z.shiftl_1_l, Z.sub_0_r. + omega. + + intros x H0 H1. + destruct (Z_lt_le_dec x n); try omega. + intuition auto with zarith lia. + left. + rewrite shiftr_succ. + replace (n - Z.succ x) with (Z.pred (n - x)) by omega. + rewrite Z.ones_pred by omega. + apply Z.shiftr_1_r_le. + assumption. + Qed. + + Lemma shiftr_ones : forall a n i, 0 <= a < 2 ^ n -> (0 <= i) -> (i <= n) -> + Z.shiftr a i <= Z.ones (n - i) . + Proof. + intros a n i G G0 G1. + destruct (Z_le_lt_eq_dec i n G1). + + destruct (Z.shiftr_ones' a n G i G0); omega. + + subst; rewrite Z.sub_diag. + destruct (Z.eq_dec a 0). + - subst; rewrite Z.shiftr_0_l; reflexivity. + - rewrite Z.shiftr_eq_0; try omega; try reflexivity. + apply Z.log2_lt_pow2; omega. + Qed. + Hint Resolve shiftr_ones : zarith. + + Lemma shiftr_upper_bound : forall a n, 0 <= n -> 0 <= a <= 2 ^ n -> Z.shiftr a n <= 1. + Proof. + intros a ? ? [a_nonneg a_upper_bound]. + apply Z_le_lt_eq_dec in a_upper_bound. + destruct a_upper_bound. + + destruct (Z.eq_dec 0 a). + - subst; rewrite Z.shiftr_0_l; omega. + - rewrite Z.shiftr_eq_0; auto; try omega. + apply Z.log2_lt_pow2; auto; omega. + + subst. + rewrite Z.shiftr_div_pow2 by assumption. + rewrite Z.div_same; try omega. + assert (0 < 2 ^ n) by (apply Z.pow_pos_nonneg; omega). + omega. + Qed. + Hint Resolve shiftr_upper_bound : zarith. + + Lemma lor_shiftl : forall a b n, 0 <= n -> 0 <= a < 2 ^ n -> + Z.lor a (Z.shiftl b n) = a + (Z.shiftl b n). + Proof. + intros a b n H H0. + apply Z.bits_inj'; intros t ?. + rewrite Z.lor_spec, Z.shiftl_spec by assumption. + destruct (Z_lt_dec t n). + + rewrite Z.testbit_add_shiftl_low by omega. + rewrite Z.testbit_neg_r with (n := t - n) by omega. + apply Bool.orb_false_r. + + rewrite testbit_add_shiftl_high by omega. + replace (Z.testbit a t) with false; [ apply Bool.orb_false_l | ]. + symmetry. + apply Z.testbit_false; try omega. + rewrite Z.div_small; try reflexivity. + split; try eapply Z.lt_le_trans with (m := 2 ^ n); try omega. + apply Z.pow_le_mono_r; omega. + Qed. + Hint Rewrite <- Z.lor_shiftl using zutil_arith : convert_to_Ztestbit. + + Lemma lor_shiftl' : forall a b n, 0 <= n -> 0 <= a < 2 ^ n -> + Z.lor (Z.shiftl b n) a = (Z.shiftl b n) + a. + Proof. + intros; rewrite Z.lor_comm, Z.add_comm; apply lor_shiftl; assumption. + Qed. + Hint Rewrite <- Z.lor_shiftl' using zutil_arith : convert_to_Ztestbit. + + Lemma shiftl_spec_full a n m + : Z.testbit (a << n) m = if Z_lt_dec m n + then false + else if Z_le_dec 0 m + then Z.testbit a (m - n) + else false. + Proof. + repeat break_match; auto using Z.shiftl_spec_low, Z.shiftl_spec, Z.testbit_neg_r with omega. + Qed. + Hint Rewrite shiftl_spec_full : Ztestbit_full. + + Lemma shiftr_spec_full a n m + : Z.testbit (a >> n) m = if Z_lt_dec m (-n) + then false + else if Z_le_dec 0 m + then Z.testbit a (m + n) + else false. + Proof. + rewrite <- Z.shiftl_opp_r, shiftl_spec_full, Z.sub_opp_r; reflexivity. + Qed. + Hint Rewrite shiftr_spec_full : Ztestbit_full. + + Lemma testbit_add_shiftl_full i (Hi : 0 <= i) a b n (Ha : 0 <= a < 2^n) + : Z.testbit (a + b << n) i + = if (i <? n) then Z.testbit a i else Z.testbit b (i - n). + Proof. + assert (0 < 2^n) by omega. + assert (0 <= n) by eauto 2 with zarith. + pose proof (Zlt_cases i n); break_match; autorewrite with Ztestbit; reflexivity. + Qed. + Hint Rewrite testbit_add_shiftl_full using zutil_arith : Ztestbit. + + Lemma land_add_land : forall n m a b, (m <= n)%nat -> + Z.land ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.ones (Z.of_nat m)) = Z.land a (Z.ones (Z.of_nat m)). + Proof. + intros n m a b H. + rewrite !Z.land_ones by apply Nat2Z.is_nonneg. + rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. + replace (b * 2 ^ Z.of_nat n) with + ((b * 2 ^ Z.of_nat (n - m)) * 2 ^ Z.of_nat m) by + (rewrite (le_plus_minus m n) at 2; try assumption; + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg; ring). + rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat m)); omega). + symmetry. apply Znumtheory.Zmod_div_mod; try (apply Z.pow_pos_nonneg; omega). + rewrite (le_plus_minus m n) by assumption. + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg. + apply Z.divide_factor_l. + Qed. + + Lemma shiftl_add x y z : 0 <= z -> (x + y) << z = (x << z) + (y << z). + Proof. intros; autorewrite with Zshift_to_pow; lia. Qed. + Hint Rewrite shiftl_add using zutil_arith : push_Zshift. + Hint Rewrite <- shiftl_add using zutil_arith : pull_Zshift. + + Lemma shiftr_add x y z : z <= 0 -> (x + y) >> z = (x >> z) + (y >> z). + Proof. intros; autorewrite with Zshift_to_pow; lia. Qed. + Hint Rewrite shiftr_add using zutil_arith : push_Zshift. + Hint Rewrite <- shiftr_add using zutil_arith : pull_Zshift. + + Lemma shiftl_sub x y z : 0 <= z -> (x - y) << z = (x << z) - (y << z). + Proof. intros; autorewrite with Zshift_to_pow; lia. Qed. + Hint Rewrite shiftl_sub using zutil_arith : push_Zshift. + Hint Rewrite <- shiftl_sub using zutil_arith : pull_Zshift. + + Lemma shiftr_sub x y z : z <= 0 -> (x - y) >> z = (x >> z) - (y >> z). + Proof. intros; autorewrite with Zshift_to_pow; lia. Qed. + Hint Rewrite shiftr_sub using zutil_arith : push_Zshift. + Hint Rewrite <- shiftr_sub using zutil_arith : pull_Zshift. + + Lemma compare_add_shiftl : forall x1 y1 x2 y2 n, 0 <= n -> + Z.pow2_mod x1 n = x1 -> Z.pow2_mod x2 n = x2 -> + x1 + (y1 << n) ?= x2 + (y2 << n) = + if Z.eq_dec y1 y2 + then x1 ?= x2 + else y1 ?= y2. + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress subst y1 + | |- _ => rewrite Z.shiftl_mul_pow2 by omega + | |- _ => rewrite Z.add_compare_mono_r + | |- _ => rewrite <-Z.mul_sub_distr_r + | |- _ => break_innermost_match_step + | H : Z.pow2_mod _ _ = _ |- _ => rewrite Z.pow2_mod_id_iff in H by omega + | H : ?a <> ?b |- _ = (?a ?= ?b) => + case_eq (a ?= b); rewrite ?Z.compare_eq_iff, ?Z.compare_gt_iff, ?Z.compare_lt_iff + | |- _ + (_ * _) > _ + (_ * _) => cbv [Z.gt] + | |- _ + (_ * ?x) < _ + (_ * ?x) => + apply Z.lt_sub_lt_add; apply Z.lt_le_trans with (m := 1 * x); [omega|] + | |- _ => apply Z.mul_le_mono_nonneg_r; omega + | |- _ => reflexivity + | |- _ => congruence + end. + Qed. + + Lemma shiftl_opp_l a n + : Z.shiftl (-a) n = - Z.shiftl a n - (if Z_zerop (a mod 2 ^ (- n)) then 0 else 1). + Proof. + destruct (Z_dec 0 n) as [ [?|?] | ? ]; + subst; + rewrite ?Z.pow_neg_r by omega; + autorewrite with zsimplify_const; + [ | | simpl; omega ]. + { rewrite !Z.shiftl_mul_pow2 by omega. + nia. } + { rewrite !Z.shiftl_div_pow2 by omega. + rewrite Z.div_opp_l_complete by auto with zarith. + reflexivity. } + Qed. + Hint Rewrite shiftl_opp_l : push_Zshift. + Hint Rewrite <- shiftl_opp_l : pull_Zshift. + + Lemma shiftr_opp_l a n + : Z.shiftr (-a) n = - Z.shiftr a n - (if Z_zerop (a mod 2 ^ n) then 0 else 1). + Proof. + unfold Z.shiftr; rewrite shiftl_opp_l at 1; rewrite Z.opp_involutive. + reflexivity. + Qed. + Hint Rewrite shiftr_opp_l : push_Zshift. + Hint Rewrite <- shiftr_opp_l : pull_Zshift. + + Lemma shl_shr_lt x y n m (Hx : 0 <= x < 2^n) (Hy : 0 <= y < 2^n) (Hm : 0 <= m <= n) + : 0 <= (x >> (n - m)) + ((y << m) mod 2^n) < 2^n. + Proof. + cut (0 <= (x >> (n - m)) + ((y << m) mod 2^n) <= 2^n - 1); [ omega | ]. + assert (0 <= x <= 2^n - 1) by omega. + assert (0 <= y <= 2^n - 1) by omega. + assert (0 < 2 ^ (n - m)) by auto with zarith. + assert (0 <= y mod 2 ^ (n - m) < 2^(n-m)) by auto with zarith. + assert (0 <= y mod 2 ^ (n - m) <= 2 ^ (n - m) - 1) by omega. + assert (0 <= (y mod 2 ^ (n - m)) * 2^m <= (2^(n-m) - 1)*2^m) by auto with zarith. + assert (0 <= x / 2^(n-m) < 2^n / 2^(n-m)). + { split; Z.zero_bounds. + apply Z.div_lt_upper_bound; autorewrite with pull_Zpow zsimplify; nia. } + autorewrite with Zshift_to_pow. + split; Z.zero_bounds. + replace (2^n) with (2^(n-m) * 2^m) by (autorewrite with pull_Zpow; f_equal; omega). + rewrite Zmult_mod_distr_r. + autorewrite with pull_Zpow zsimplify push_Zmul in * |- . + nia. + Qed. + + Lemma add_shift_mod x y n m + (Hx : 0 <= x < 2^n) (Hy : 0 <= y) + (Hn : 0 <= n) (Hm : 0 < m) + : (x + y << n) mod (m * 2^n) = x + (y mod m) << n. + Proof. + pose proof (Z.mod_bound_pos y m). + specialize_by omega. + assert (0 < 2^n) by auto with zarith. + autorewrite with Zshift_to_pow. + rewrite Zplus_mod, !Zmult_mod_distr_r. + rewrite Zplus_mod, !Zmod_mod, <- Zplus_mod. + rewrite !(Zmod_eq (_ + _)) by nia. + etransitivity; [ | apply Z.add_0_r ]. + rewrite <- !Z.add_opp_r, <- !Z.add_assoc. + repeat apply f_equal. + ring_simplify. + cut (((x + y mod m * 2 ^ n) / (m * 2 ^ n)) = 0); [ nia | ]. + apply Z.div_small; split; nia. + Qed. + + Lemma add_mul_mod x y n m + (Hx : 0 <= x < 2^n) (Hy : 0 <= y) + (Hn : 0 <= n) (Hm : 0 < m) + : (x + y * 2^n) mod (m * 2^n) = x + (y mod m) * 2^n. + Proof. + generalize (add_shift_mod x y n m). + autorewrite with Zshift_to_pow; auto. + Qed. + + Lemma lt_pow_2_shiftr : forall a n, 0 <= a < 2 ^ n -> a >> n = 0. + Proof. + intros a n H. + destruct (Z_le_dec 0 n). + + rewrite Z.shiftr_div_pow2 by assumption. + auto using Z.div_small. + + assert (2 ^ n = 0) by (apply Z.pow_neg_r; omega). + omega. + Qed. + + Hint Rewrite Z.pow2_bits_eqb using zutil_arith : Ztestbit. + Lemma pow_2_shiftr : forall n, 0 <= n -> (2 ^ n) >> n = 1. + Proof. + intros; apply Z.bits_inj'; intros. + replace 1 with (2 ^ 0) by ring. + repeat match goal with + | |- _ => progress intros + | |- _ => progress rewrite ?Z.eqb_eq, ?Z.eqb_neq in * + | |- _ => progress autorewrite with Ztestbit + | |- context[Z.eqb ?a ?b] => case_eq (Z.eqb a b) + | |- _ => reflexivity || omega + end. + Qed. + + Lemma lt_mul_2_pow_2_shiftr : forall a n, 0 <= a < 2 * 2 ^ n -> + a >> n = if Z_lt_dec a (2 ^ n) then 0 else 1. + Proof. + intros a n H; break_match; [ apply lt_pow_2_shiftr; omega | ]. + destruct (Z_le_dec 0 n). + + replace (2 * 2 ^ n) with (2 ^ (n + 1)) in * + by (rewrite Z.pow_add_r; try omega; ring). + pose proof (Z.shiftr_ones a (n + 1) n H). + pose proof (Z.shiftr_le (2 ^ n) a n). + specialize_by omega. + replace (n + 1 - n) with 1 in * by ring. + replace (Z.ones 1) with 1 in * by reflexivity. + rewrite pow_2_shiftr in * by omega. + omega. + + assert (2 ^ n = 0) by (apply Z.pow_neg_r; omega). + omega. + Qed. + + Lemma shiftr_nonneg_le : forall a n, 0 <= a -> 0 <= n -> a >> n <= a. + Proof. + intros. + repeat match goal with + | [ H : _ <= _ |- _ ] + => rewrite Z.lt_eq_cases in H + | [ H : _ \/ _ |- _ ] => destruct H + | _ => progress subst + | _ => progress autorewrite with zsimplify Zshift_to_pow + | _ => solve [ auto with zarith omega ] + end. + Qed. + Hint Resolve shiftr_nonneg_le : zarith. +End Z. |