From 0a556929568e0fc3255cc160fa4b35a75eb14f60 Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Sat, 24 Feb 2018 18:20:12 -0500 Subject: Review comments. Major change is porting everything to Z and using Z.div_mod_to_quot_rem which is a handy sledgehammer. Z is also a nice simplification. Dealing with subtraction is tidier, though I do have 0 <= x goals everywhere as a result. --- src/Arithmetic/BarrettReduction/RidiculousFish.v | 526 ++++++++++------------- 1 file changed, 230 insertions(+), 296 deletions(-) (limited to 'src/Arithmetic') diff --git a/src/Arithmetic/BarrettReduction/RidiculousFish.v b/src/Arithmetic/BarrettReduction/RidiculousFish.v index b9efa1615..e3f00d7a9 100644 --- a/src/Arithmetic/BarrettReduction/RidiculousFish.v +++ b/src/Arithmetic/BarrettReduction/RidiculousFish.v @@ -1,8 +1,10 @@ Require Import Crypto.Util.Notations. -Require Import Coq.NArith.BinNat. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Coq.ZArith.ZArith. Require Import Coq.micromega.Lia. -Open Scope N_scope. +Open Scope Z_scope. (** This file proves an implementation of the variant of Barrett reduction @@ -10,74 +12,70 @@ Open Scope N_scope. http://www.ridiculousfish.com/blog/posts/labor-of-division-episode-i.html *) -(* To simulate C behavior with fixed-width words, we introduce a word data type - and associated operations that automatically take a modulus. We'll then prove - the moduli are all no-ops (save one which is a subtraction). +(** To simulate C behavior with fixed-width words, we introduce a word data type + and associated operations that automatically take a modulus. We'll then + prove the moduli are all no-ops (save one which is a subtraction). - TODO(davidben): This does not perfectly match C's behavior. In C, operations - on types smaller than [int], notably [uint16_t], promote both types to [int] - first. 
Assuming a standard 32-bit [int], multiplying two [uint32_t]s is - always defined, but multiplying two [uint16_t]s may not be! However, we never - multiply [uint16_t]s, and we would otherwise need an unreasonably large - expression to overflow [int] by way of promoted [uint16_t] additions. *) + TODO(davidben): This does not perfectly match C's behavior. In C, operations + on types smaller than [int], notably [uint16_t], promote both types to [int] + first. Assuming a standard 32-bit [int], multiplying two [uint32_t]s is + always defined, but multiplying two [uint16_t]s may not be! However, we + never multiply [uint16_t]s, and we would otherwise need an unreasonably + large expression to overflow [int] by way of promoted [uint16_t] + additions. *) -(* First, the C data types will ultimately compile down to a bunch of extra - modulus operations, which we must remove. Define three primitive wrapping - operations which we'll instruct [cbn] to leave alone and match on in the - proof script. +(** First, the C data types will ultimately compile down to a bunch of extra + modulus operations, which we must remove. Define a primitive [wrap] + operation which we'll instruct [cbn] to leave alone and match on + in the proof script. *) +Definition wrap (bits val : Z) : Z := val mod 2^bits. - [wrap] is used when we expect [val] to already be reduced. [wrap'] is when we - expect one subtraction by [2^bits]. [sub_wrapped] can be implemented in terms - of [wrap'], but we make it a primitive so the proof script can treat it more - efficiently. *) -Definition wrap (bits val : N) : N := val mod 2^bits. -Definition wrap' := wrap. -Definition sub_wrapped (bits a b : N) : N := wrap' bits (a + 2^bits - b). +(** A [word] is a C data type with a specified bit size. *) +Inductive word (bits : Z) := +| Word : Z -> word bits. -(* A [word] is a C data type with a specified bit size. *) -Inductive word (bits : N) := -| Word : N -> word bits.
- -Definition to_N {bits : N} (a : word bits) : N := +Definition to_Z {bits : Z} (a : word bits) : Z := match a with Word v => v end. -Definition to_word {bits : N} (val : N) := +Definition to_word {bits : Z} (val : Z) := Word bits (wrap bits val). Definition word_binop - (aw bw : N) - (op : N -> N -> N) - (a : word aw) (b : word bw) : word (N.max aw bw) := - to_word (op (to_N a) (to_N b)). - -Definition word_add {aw bw : N} := word_binop aw bw N.add. -Definition word_sub {aw bw : N} := word_binop aw bw (sub_wrapped (N.max aw bw)). -Definition word_mul {aw bw : N} := word_binop aw bw N.mul. -Definition word_div {aw bw : N} := word_binop aw bw N.div. - -Definition word_shiftl {aw bw : N} (a : word aw) (b : word bw) : word aw := - to_word (N.shiftl (to_N a) (to_N b)). -Definition word_shiftr {aw bw : N} (a : word aw) (b : word bw) : word aw := - to_word (N.shiftr (to_N a) (to_N b)). - -Definition word_cast (from to : N) (a : word from) : word to := - to_word (to_N a). + (aw bw : Z) + (op : Z -> Z -> Z) + (a : word aw) (b : word bw) : word (Z.max aw bw) := + to_word (op (to_Z a) (to_Z b)). + +Definition word_add {aw bw : Z} := word_binop aw bw Z.add. +Definition word_sub {aw bw : Z} := word_binop aw bw Z.sub. +Definition word_mul {aw bw : Z} := word_binop aw bw Z.mul. +Definition word_div {aw bw : Z} := word_binop aw bw Z.div. + +(** The type of a shift expression in C is the type of the first argument, not + the larger of the two types. *) +Definition word_shiftl {aw bw : Z} (a : word aw) (b : word bw) : word aw := + to_word (Z.shiftl (to_Z a) (to_Z b)). +Definition word_shiftr {aw bw : Z} (a : word aw) (b : word bw) : word aw := + to_word (Z.shiftr (to_Z a) (to_Z b)). + +Definition word_cast (from to : Z) (a : word from) : word to := + to_word (to_Z a). Definition u16 := word 16. Definition u32 := word 32. Definition u64 := word 64. -Definition to_u16 {bits : N} := word_cast bits 16. -Definition to_u32 {bits : N} := word_cast bits 32.
-Definition to_u64 {bits : N} := word_cast bits 64. +Definition to_u16 {bits : Z} := word_cast bits 16. +Definition to_u32 {bits : Z} := word_cast bits 32. +Definition to_u64 {bits : Z} := word_cast bits 64. -(* [to_u32'] is identical to [to_u32], but it is expected to be a - subtraction of 2^32, rather than a no-op. *) -Definition to_u32' {bits : N} (a : word bits) : u32 := - Word 32 (wrap' 32 (to_N a)). +Definition Z_to_u16 := @to_word 16. +Definition Z_to_u32 := @to_word 32. +Definition Z_to_u64 := @to_word 64. -Definition N_to_u16 := @to_word 16. -Definition N_to_u32 := @to_word 32. -Definition N_to_u64 := @to_word 64. +(** [to_u32'] is identical to [to_u32], but it is expected to be a subtraction + of 2^32, rather than a no-op. *) +Definition to_u32' {bits : Z} (a : word bits) : u32 := + Z_to_u32 ((to_Z a) - 2^32). Delimit Scope word_scope with word. Bind Scope word_scope with word. @@ -87,14 +85,22 @@ Infix "*" := word_mul : word_scope. Infix "/" := word_div : word_scope. Infix "<<" := word_shiftl : word_scope. Infix ">>" := word_shiftr : word_scope. -Notation "1" := (N_to_u32 1) : word_scope. -Notation "32" := (N_to_u32 32) : word_scope. +Notation "1" := (Z_to_u32 1) : word_scope. +Notation "32" := (Z_to_u32 32) : word_scope. + +Definition num_bits (a : Z) : Z := + match a with + | Z0 => 0 + | Zpos p => Zpos (Pos.size p) + | Zneg _ => 0 (* A [u32] is always non-negative, so this does not come up. *) + end. Definition BN_num_bits_word (a : u32) : u32 := - N_to_u32 (N.size (to_N a)). + Z_to_u32 (num_bits (to_Z a)). -(* See [mod_u16] in https://boringssl-review.googlesource.com/#/c/boringssl/+/25887. *) +(** See [mod_u16] in + https://boringssl-review.googlesource.com/#/c/boringssl/+/25887. *) Definition mod_u16 (n : u32) (d : u16) : u16 := let p := to_u32 (BN_num_bits_word (to_u32 (d - 1))) in let m := to_u32' ((((to_u64 1) << (32 + p)) + d - 1) / d) in @@ -105,347 +111,275 @@ Definition mod_u16 (n : u32) (d : u16) : u16 := to_u16 n. 
-(* Proofs *) +(** Proofs *) -Lemma wrap_bound (a b : N) : - wrap a b < 2^a. +Lemma num_bits_log2_up (a : Z) : + num_bits (a - 1) = Z.log2_up a. +Proof. + unfold Z.log2_up, Z.log2, num_bits. + destruct (1 ?= a) eqn:H. + { (* [1 = a] *) + rewrite Z.compare_eq_iff in H. + rewrite <- H. + reflexivity. } + { (* [1 < a] *) + rewrite Z.compare_lt_iff in H. + assert (0 < a - 1) by lia. + rewrite Z.sub_1_r in *. + destruct (Z.pred a); try lia. + destruct p; cbn; lia. } + { (* [1 > a] *) + rewrite Z.compare_gt_iff in H. + assert (a - 1 < 0) by lia. + destruct (a - 1); lia. } +Qed. + +Lemma wrap_bound (a b : Z) (H : 0 < a) : + 0 <= wrap a b < 2^a. Proof. unfold wrap. - apply N.mod_upper_bound. - apply N.pow_nonzero. + Z.div_mod_to_quot_rem. lia. Qed. -Lemma remove_outer_wrap (a a' b : N) (H : a' <= a) : +Lemma remove_outer_wrap (a a' b : Z) (H : 0 < a' <= a) : wrap a (wrap a' b) = wrap a' b. Proof. unfold wrap. - apply N.mod_small. - assert (b mod 2^a' < 2^a') by (apply N.mod_upper_bound; apply N.pow_nonzero; lia). - assert (2^a' <= 2^a) by (apply N.pow_le_mono_r; lia). + apply Z.mod_small. + Z.div_mod_to_quot_rem. + assert (2^a' <= 2^a) by (apply Z.pow_le_mono_r; lia). lia. Qed. -Lemma remove_inner_wrap (a a' b : N) (H : a' >= a) : +Lemma divide_pow (a b c : Z) (Ha : 0 < a) (Hb : 0 <= b <= c) : + (a^b | a^c). +Proof. + exists (a^(c - b)). + rewrite <- Z.pow_add_r; try f_equal; lia. +Qed. + +Lemma remove_inner_wrap (a a' b : Z) (H : 0 < a <= a') : wrap a (wrap a' b) = wrap a b. Proof. unfold wrap. - rewrite <- (N.sub_add a a') by lia. - rewrite N.pow_add_r. - rewrite N.mul_comm. - rewrite N.mod_mul_r by (apply N.pow_nonzero; lia). - rewrite N.mul_comm. - rewrite N.mod_add by (apply N.pow_nonzero; lia). - apply N.mod_mod. - apply N.pow_nonzero. - lia. + rewrite (Znumtheory.Zmod_div_mod (2^a) (2^a') b); + try apply Z.pow_pos_nonneg; + try apply divide_pow; + lia. Qed. -Lemma double_wrap (a a' b : N) : - wrap a (wrap a' b) = wrap (N.min a a') b. 
+Lemma double_wrap (a a' b : Z) (Ha : 0 < a) (Ha' : 0 < a') : + wrap a (wrap a' b) = wrap (Z.min a a') b. Proof. destruct (a = b) : +Lemma a_minus_b_div_2_plus_b (a b : Z) (H : a >= b) : (a - b) / 2 + b = (a + b) / 2. -Proof. - rewrite <- N.div_add by lia. - f_equal. - lia. -Qed. - -Lemma p_bound (d bits : N) (H : 1 < d < 2^bits) : - 1 <= N.size (d - 1) <= bits. -Proof. - rewrite N.size_log2 by lia. - assert (N.log2 (d - 1) < bits) by (apply N.log2_lt_pow2; lia). - lia. -Qed. +Proof. Z.div_mod_to_quot_rem; lia. Qed. -Lemma ceil_log2_lt (d : N) (H : 1 < d) : - 2^(N.size (d - 1)) < d + d. +Lemma p_bound (d bits : Z) (Hbits : 0 < bits) (H : 1 < d < 2^bits) : + 0 < Z.log2_up d <= bits. Proof. - pose proof (N.size_le (d - 1)). - rewrite N.succ_double_spec in *. - lia. + split. + - apply Z.log2_up_pos; lia. + - apply Z.log2_up_le_pow2; lia. Qed. -Lemma ceil_log2_ge (d : N) (H : 1 < d) : - 2^(N.size (d - 1)) >= d. +Lemma pow2_log2_up_bounds (d : Z) (H : 1 < d) : + d <= 2^(Z.log2_up d) < d + d. Proof. - pose proof (N.size_gt (d - 1)). + pose proof (Z.log2_up_spec d H). + assert (2 ^ Z.log2_up d = 2 * (2 ^ Z.pred (Z.log2_up d))). + { rewrite <- Z.pow_succ_r by (apply Z.lt_le_pred; apply Z.log2_up_pos; lia). + f_equal. + lia. } lia. Qed. -Lemma m_bound (d bits : N) (H : 1 < d < 2^bits) : - 2^bits <= (2^(bits + N.size (d - 1)) + d - 1) / d < 2 * 2^bits. +Lemma m_bound (d bits : Z) (Hbits : 0 < bits) (H : 1 < d < 2^bits) : + 2^bits <= (2^(bits + Z.log2_up d) + d - 1) / d < 2 * 2^bits. Proof. - rewrite !N.pow_add_r. + rewrite !Z.pow_add_r by (auto with zarith). + pose proof (pow2_log2_up_bounds d (proj1 H)). split. - { (* 2^e <= m *) - apply N.div_le_lower_bound; try lia. - assert (2^(N.size (d - 1)) >= d). - { apply ceil_log2_ge; lia. } - nia. } - { (* m < 2 * 2^e *) - apply N.div_lt_upper_bound; try lia. - assert (2^(N.size (d - 1)) < d + d). - { apply ceil_log2_lt; lia. } - (* Subtraction is saturating, so nia needs a bit of help to group - [d - 1] together. 
*) - rewrite <- N.add_sub_assoc by lia. - nia. } -Qed. - -Lemma mod_sub_once (a b : N) (H : b <= a < 2*b) : - a mod b = a - b. -Proof. - assert (1 = a / b). - { apply (N.div_unique a b 1 (a - b)); lia. } - rewrite N.mod_eq; nia. + - apply Z.div_le_lower_bound; nia. + - apply Z.div_lt_upper_bound; nia. Qed. -Lemma division_shrinks (a b c : N) (Hnz : b <> 0) (H : a < c) : +Lemma division_shrinks (a b c : Z) (Hb : 0 < b) (H : 0 <= a < c) : a / b < c. -Proof. - apply N.div_lt_upper_bound; try lia. - nia. -Qed. +Proof. Z.div_mod_to_quot_rem; nia. Qed. -Lemma mul_div_ceil_ge (a b : N) (H : b <> 0) : +Lemma mul_div_ceil_ge (a b : Z) (Hb : 0 < b) : a <= b * ((a + b - 1) / b). -Proof. - rewrite (N.div_mod a b) by lia. - set (q := a / b). - set (r := a mod b). - assert (r < b) by (apply N.mod_upper_bound; lia). - destruct (r =? 0) eqn:Hr. - { (* r = 0 *) - apply N.eqb_eq in Hr. - replace r with 0. - rewrite !N.add_0_r. - replace ((b * q + b - 1) / b) with q by - (rewrite <- N.add_sub_assoc by lia; - apply (N.div_unique _ b q (b - 1)); - lia). - lia. } - { (* r != 0 *) - apply N.eqb_neq in Hr. - replace ((b * q + r + b - 1) / b) with (q + 1) by - (replace (b * q + r + b - 1) with (b * (q + 1) + (r - 1)) by lia; - apply (N.div_unique _ b (q + 1) (r - 1)); - lia). - lia. } -Qed. +Proof. Z.div_mod_to_quot_rem; nia. Qed. -(* This is the main theorem this algorithm relies on. *) -Lemma division_algorithm (bits n d k : N) - (Hn : n < 2^bits) +(** This is the main theorem this algorithm relies on. *) +Lemma division_algorithm (bits n d k : Z) + (Hbits : 0 < bits) + (Hn : 0 <= n < 2^bits) (Hd : 1 < d) - (Hk : bits + N.size (d - 1) <= k) : + (Hk : bits + Z.log2_up d <= k) : n * ((2^k + d - 1) / d) / 2^k = n / d. Proof. + (* Help [lia] and [nia] out. *) + pose proof (Z.log2_up_nonneg d). + pose proof (pow2_log2_up_bounds d Hd). + assert (0 < 2^k) by (apply Z.pow_pos_nonneg; lia). + assert (n/d <= n * ((2^k + d - 1) / d) / 2^k). { (* This direction is true independent of [Hk]. 
*) - apply N.div_le_lower_bound; try apply N.pow_nonzero; try lia. - rewrite N.div_mul_le by lia. - apply N.div_le_upper_bound; try apply N.pow_nonzero; try lia. - rewrite (N.mul_comm n _). - rewrite N.mul_assoc. assert (2^k <= d * ((2 ^ k + d - 1) / d)) by (apply mul_div_ceil_ge; lia). + apply Z.div_le_lower_bound; try lia. + rewrite Z.div_mul_le by (try auto with zarith; lia). + apply Z.div_le_upper_bound; try lia. nia. } assert (n * ((2^k + d - 1) / d) / 2^k <= n / d). - { assert (div_mul_le_inner : n * ((2 ^ k + d - 1) / d) / 2^k <= - (n * (2 ^ k + d - 1) / d) / 2 ^ k). - { apply N.div_le_mono; try (apply N.pow_nonzero; lia). - apply N.div_mul_le. - lia. } + { assert (div_mul_le_inner : n * ((2^k + d - 1) / d) / 2^k <= + (n * (2^k + d - 1) / d) / 2^k). + { apply Z.div_le_mono; try apply Z.div_mul_le; lia. } rewrite div_mul_le_inner. replace (n * (2^k + d - 1)) with (n * 2^k + n * (d - 1)) by nia. (* Swap [d] and [2^k] in the denominator. *) - rewrite N.div_div by (try apply N.pow_nonzero; lia). - rewrite (N.mul_comm d (2^k)). - rewrite <- N.div_div by (try apply N.pow_nonzero; lia). + rewrite Z.div_div by lia. + rewrite (Z.mul_comm d (2^k)). + rewrite <- Z.div_div by lia. (* We now can extract the [n * 2^k] term from the inner division. *) - rewrite N.div_add_l by (try apply N.pow_nonzero; lia). + rewrite Z.div_add_l by lia. (* We show the error term is zero, which allows the inequality to hold. *) assert (zero_error : n * (d - 1) / 2 ^ k = 0). - { apply N.div_small. - assert (2^bits * 2^(N.size (d - 1)) <= 2^k) by - (rewrite <- N.pow_add_r; apply N.pow_le_mono_r; lia). - assert (d <= 2^(N.size (d - 1))) by (apply N.ge_le; apply ceil_log2_ge; lia). + { apply Z.div_small. + assert (2^bits * 2^(Z.log2_up d) <= 2^k) by + (rewrite <- Z.pow_add_r; try apply Z.pow_le_mono_r; lia). nia. } rewrite zero_error. - rewrite N.add_0_r. - lia. } + auto with zarith. } lia. Qed. -(* Peel off one innermost [wrap], and replace it with the expected - bounds. 
Before doing so, introduce a lemma into the context that - the value is bounded. This is allows later goals to make use of - earlier bounds. *) +(** Peel off one innermost [wrap], and replace it with the expected + bounds. Before doing so, introduce a lemma into the context that the value + is bounded. This allows later goals to make use of earlier bounds. + + TODO(davidben): Most of the slowness comes from these intermediate + hypotheses. The proofs need them, but only those one or two layers down. Can + we be smarter about retaining them? *) Local Ltac replace_innermost_wrap expr := match expr with | context[wrap ?x ?y] => first [ replace_innermost_wrap y | - assert ((wrap x y) < 2^x) by (apply wrap_bound); + assert (0 <= (wrap x y) < 2^x) by (apply wrap_bound; lia); replace (wrap x y) with y in * ] - | context[wrap' ?x ?y] => - first [ replace_innermost_wrap y | - assert ((wrap' x y) < 2^x) by (unfold wrap'; apply wrap_bound); - replace (wrap' x y) with (y - 2^x) in * ] - | context[sub_wrapped ?x ?y ?z] => - first [ replace_innermost_wrap y | - replace_innermost_wrap z | - assert ((sub_wrapped x y z) < 2^x) by (unfold sub_wrapped, wrap'; - apply wrap_bound); - replace (sub_wrapped x y z) with (y - z) in * ] end. -Theorem mod_u16_spec (n d : N) (Hn : n < 2^32) (Hd : 1 < d < 2^16) : - to_N (mod_u16 (N_to_u32 n) (N_to_u16 d)) = n mod d. +(** Finally, the correctness proof. *) +Theorem mod_u16_spec (n d : Z) (Hn : 0 <= n < 2^32) (Hd : 1 < d < 2^16) : + to_Z (mod_u16 (Z_to_u32 n) (Z_to_u16 d)) = n mod d. Proof. unfold mod_u16. (* Evaluate away everything except for exponentations, which are - still useful, and our [wrap], etc., metadata. *) - cbn -[N.pow wrap wrap' sub_wrapped]. + still useful, and our [wrap] metadata. *) + cbn -[Z.pow num_bits wrap]. (* The solver needs some help getting these in the context. *) - assert (1 <= N.size (d - 1) <= 16) by (apply p_bound; lia). - assert (2^(N.size (d - 1)) < d + d) by (apply ceil_log2_lt; lia).
- assert (2^(N.size (d - 1)) < 2 ^ 17) by (unfold_exponents; lia). - assert (2^32 <= (2^(32 + N.size (d - 1)) + d - 1) / d < 2 * 2^32) by - (apply m_bound; unfold_exponents; lia). + assert (Hlog : 0 < Z.log2_up d <= 16) by (apply p_bound; lia). + assert (d <= 2^(Z.log2_up d) < d + d) by (apply pow2_log2_up_bounds; lia). + assert (2^(Z.log2_up d) < 2 ^ 17) by lia. + assert (2^32 <= (2^32 * 2^(Z.log2_up d) + d - 1) / d < 2 * 2^32) by + (rewrite <- Z.pow_add_r by lia; apply m_bound; lia). (* Convert shifts to multiplication and divison. *) - rewrite !N.shiftl_mul_pow2. - rewrite !N.shiftr_div_pow2. - rewrite !N.mul_1_l. - rewrite !N.pow_1_r. + rewrite !Z.shiftl_mul_pow2 by (apply wrap_bound; lia). + rewrite !Z.shiftr_div_pow2 by (apply wrap_bound; lia). + rewrite !Z.mul_1_l. + rewrite !Z.pow_1_r. (* Replace all the [to_word] calls with their expected bounds. *) repeat match goal with | [ |- ?x = n mod d ] => match goal with - (* Evaluate away the [N.min]s. *) - | _ => progress cbn -[N.pow N.add N.mul wrap wrap'] - (* Remove obviously pointless casts, before the solver expands too - much. *) - | [ |- context[wrap ?x (wrap' ?x ?y)] ] => - replace (wrap x (wrap' x y)) with (wrap' x y) by - (unfold wrap'; symmetry; apply remove_outer_wrap; lia) - | [ |- context[wrap ?x (sub_wrapped ?x ?y ?z)] ] => - replace (wrap x (sub_wrapped x y z)) with (sub_wrapped x y z) by - (unfold sub_wrapped, wrap'; - symmetry; - apply remove_outer_wrap; - lia) + (* Evaluate away the [Z.min]s. *) + | _ => progress cbn -[Z.pow Z.add Z.mul wrap] + (* Remove obviously pointless casts. The solver could handle them, + but this cuts down on the number of generated hypotheses. *) | [ |- context[wrap ?x (wrap ?x' ?y)] ] => - replace (wrap x (wrap x' y)) with (wrap (N.min x x') y) by + replace (wrap x (wrap x' y)) with (wrap (Z.min x x') y) by (symmetry; apply double_wrap; lia) - (* Perform [a_minus_b_div_2_plus_b] earlier, so the - generated hypotheses match too. 
*) + | [ |- context[wrap ?x n] ] => + replace (wrap x n) with n by (unfold wrap; rewrite Z.mod_small; lia) + | [ |- context[wrap ?x d] ] => + replace (wrap x d) with d by (unfold wrap; rewrite Z.mod_small; lia) + | [ |- context[wrap ?x (Zpos ?p)] ] => + replace (wrap x (Zpos p)) with (Zpos p) by reflexivity + (* Perform rewrites as soon as possible, so the generated hypotheses + match too. *) | _ => rewrite a_minus_b_div_2_plus_b + | _ => rewrite num_bits_log2_up + | _ => rewrite Z.pow_add_r by lia | _ => replace_innermost_wrap x end - | _ => rewrite sub_wrapped_noop end; - repeat unfold sub_wrapped, wrap', wrap in *; + repeat unfold wrap in *; repeat match goal with (* Normalize a few things. *) - | _ => apply N.le_ge - | _ => apply N.lt_gt + | _ => apply Z.le_ge + | _ => apply Z.lt_gt | [ |- _ = _ mod _ ] => symmetry (* Apply these transforms everywhere. *) - | _ => apply N.mod_small + | _ => apply Z.mod_small | _ => split + (* Everything is positive, so these tend to be easy. *) + | [ |- 0 <= _ * _ ] => nia (* Don't touch the goals that involve the final division expression. We'll dispatch them separately. *) | [ |- context[d * (_ / 2^_)] ] => fail 1 - (* The solver needs some help for this one. *) - | [ _ : ?x < ?y |- ?x / _ < ?y ] => apply division_shrinks - (* Smash the bounds subgoals away. *) - | _ => apply mod_sub_once - | _ => rewrite N.pow_add_r in * - | _ => apply N.div_le_upper_bound - | _ => apply N.div_lt_upper_bound - | _ => apply N.pow_nonzero - | _ => apply N.mul_le_mono_r - | _ => progress unfold_exponents - (* Erase irrelevant hypotheses before we start running the - slow solvers. *) - | [ H : context[n] |- _ ] => match goal with - | [ |- context[n] ] => fail 1 - | _ => clear H - end - | [ H : context[d] |- _ ] => match goal with - | [ |- context[d] ] => fail 1 - | _ => clear H - end + (* The solver needs some help for this one, otherwise the [Z.div_*] + theorems below break apart the numerator term. 
*) + | [ _ : 0 <= ?x < ?y |- ?x / _ < ?y ] => apply division_shrinks + (* Simplify the various inequalities. *) + | _ => apply Z.div_le_upper_bound + | _ => apply Z.div_lt_upper_bound + | _ => apply Z.div_le_lower_bound + | _ => apply Z.pow_nonneg + | _ => apply Z.pow_pos_nonneg + (* Only apply in one direction. *) + | [ |- 0 <= _ - _ ] => apply Z.le_0_sub + | _ => apply Z.mul_le_mono_nonneg_r (* [nia] is rather slow, and only one case needs it. *) - | [ _ : ?x < _, _ : ?y < _ |- ?x * ?y < ?z ] => nia + | [ _ : 0 <= ?x < _, _ : 0 <= ?y < _ |- ?x * ?y < ?z ] => nia | _ => lia end; + (* Clear the now unnecessary intermediate hypotheses. *) + clear - Hn Hd Hlog; + (* Invert the various transforms to get to the main theorem. *) repeat match goal with - | [ |- _ <= N.size _ ] => rewrite N.size_log2; lia - | [ |- context[?x * ?y + ?z * ?x]] => rewrite (N.mul_comm z x) - | [ |- context[?x + ?y - ?x]] => rewrite (N.add_comm x y) - | _ => rewrite N.div_div - | _ => apply N.pow_nonzero - | _ => rewrite <- N.pow_succ_r' - | _ => rewrite <- N.add_1_r - | _ => rewrite <- N.pow_add_r - | _ => rewrite <- N.mul_add_distr_l - | _ => rewrite N.sub_add - | _ => rewrite N.add_sub - | _ => rewrite N.add_sub_assoc - | _ => rewrite <- N.div_add_l - | _ => lia + | _ => rewrite Z.div_div by (try apply Z.pow_pos_nonneg; lia) + | _ => rewrite <- Z.pow_succ_r by lia + | _ => rewrite <- Z.add_1_r + | _ => rewrite Z.sub_add + | _ => rewrite Zplus_minus + | _ => rewrite <- Z.pow_add_r by lia + | _ => rewrite <- Z.mul_add_distr_l + | _ => rewrite <- Z.div_add_l by lia + | [ |- context[?x * ?y + ?z * ?x] ] => rewrite (Z.mul_comm z x) end; (* And finish with the main theorem. *) rewrite (division_algorithm 32) by lia; - try rewrite <- N.mod_eq; - try apply N.mul_div_le; - (* [lia] needs some help on these two. 
*) - try match goal with - | [ |- ?x mod ?y < 2^_ ] => - assert (x mod y < y) by (apply N.mod_upper_bound; lia) - | [ |- ?x * (?y / ?x) < 2^_ ] => - assert (x * (y / x) <= y) by (apply N.mul_div_le; lia) - end; - lia. + Z.div_mod_to_quot_rem; + nia. Qed. -- cgit v1.2.3