aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar David Benjamin <davidben@google.com>2018-02-24 18:20:12 -0500
committerGravatar Andres Erbsen <andreser@mit.edu>2018-03-09 18:02:01 -0500
commit0a556929568e0fc3255cc160fa4b35a75eb14f60 (patch)
tree0654f54711c6ca1d0ff0c5d2e2e319f3915a5c8f /src/Arithmetic
parent22f92f15e8b42cdb9db06a421986a36f4a76d05a (diff)
Review comments.
Major change is porting everything to Z and using Z.div_mod_to_quot_rem which is a handy sledgehammer. Z is also a nice simplification. Dealing with subtraction is tidier, though I do have 0 <= x goals everywhere as a result.
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/BarrettReduction/RidiculousFish.v526
1 files changed, 230 insertions, 296 deletions
diff --git a/src/Arithmetic/BarrettReduction/RidiculousFish.v b/src/Arithmetic/BarrettReduction/RidiculousFish.v
index b9efa1615..e3f00d7a9 100644
--- a/src/Arithmetic/BarrettReduction/RidiculousFish.v
+++ b/src/Arithmetic/BarrettReduction/RidiculousFish.v
@@ -1,8 +1,10 @@
Require Import Crypto.Util.Notations.
-Require Import Coq.NArith.BinNat.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem.
+Require Import Coq.ZArith.ZArith.
Require Import Coq.micromega.Lia.
-Open Scope N_scope.
+Open Scope Z_scope.
(** This file proves an implementation of the variant of Barrett reduction
@@ -10,74 +12,70 @@ Open Scope N_scope.
http://www.ridiculousfish.com/blog/posts/labor-of-division-episode-i.html *)
-(* To simulate C behavior with fixed-width words, we introduce a word data type
- and associated operations that automatically take a modulus. We'll then prove
- the moduli are all no-ops (save one which is a subtraction).
+(** To simulate C behavior with fixed-width words, we introduce a word data type
+ and associated operations that automatically take a modulus. We'll then
+ prove the moduli are all no-ops (save one which is a subtraction).
- TODO(davidben): This does not perfectly match C's behavior. In C, operations
- on types smaller than [int], notably [uint16_t], promote both types to [int]
- first. Assuming a standard 32-bit [int], multiplying two [uint32_t]s is
- always defined, but multiplying two [uint16_t]s may not be! However, we never
- multiply [uint16_t]s, and we would otherwise need an unreasonably large
- expression to overflow [int] by way of promoted [uint16_t] additions. *)
+ TODO(davidben): This does not perfectly match C's behavior. In C, operations
+ on types smaller than [int], notably [uint16_t], promote both types to [int]
+ first. Assuming a standard 32-bit [int], multiplying two [uint32_t]s is
+ always defined, but multiplying two [uint16_t]s may not be! However, we
+ never multiply [uint16_t]s, and we would otherwise need an unreasonably
+ large expression to overflow [int] by way of promoted [uint16_t]
+ additions. *)
-(* First, the C data types will ultimately compile down to a bunch of extra
- modulus operations, which we must remove. Define three primitive wrapping
- operations which we'll instruct [cbn] to leave alone and match on in the
- proof script.
+(** First, the C data types will ultimately compile down to a bunch of extra
+ modulus operations, which we must remove. Define a primitive [wrap]
+ operation operations which we'll instruct [cbn] to leave alone and match on
+ in the proof script. *)
+Definition wrap (bits val : Z) : Z := val mod 2^bits.
- [wrap] is used when we expect [val] to already be reduced. [wrap'] is when we
- expect one subtraction by [2^bits]. [sub_wrapped] can be implemented in terms
- of [wrap'], but we make it a primitive so the proof script can treat it more
- efficiently. *)
-Definition wrap (bits val : N) : N := val mod 2^bits.
-Definition wrap' := wrap.
-Definition sub_wrapped (bits a b : N) : N := wrap' bits (a + 2^bits - b).
+(** A [word] is a C data type with a specified bit size. *)
+Inductive word (bits : Z) :=
+| Word : Z -> word bits.
-(* A [word] is a C data type with a specified bit size. *)
-Inductive word (bits : N) :=
-| Word : N -> word bits.
-
-Definition to_N {bits : N} (a : word bits) : N :=
+Definition to_Z {bits : Z} (a : word bits) : Z :=
match a with Word v => v end.
-Definition to_word {bits : N} (val : N) :=
+Definition to_word {bits : Z} (val : Z) :=
Word bits (wrap bits val).
Definition word_binop
- (aw bw : N)
- (op : N -> N -> N)
- (a : word aw) (b : word bw) : word (N.max aw bw) :=
- to_word (op (to_N a) (to_N b)).
-
-Definition word_add {aw bw : N} := word_binop aw bw N.add.
-Definition word_sub {aw bw : N} := word_binop aw bw (sub_wrapped (N.max aw bw)).
-Definition word_mul {aw bw : N} := word_binop aw bw N.mul.
-Definition word_div {aw bw : N} := word_binop aw bw N.div.
-
-Definition word_shiftl {aw bw : N} (a : word aw) (b : word bw) : word aw :=
- to_word (N.shiftl (to_N a) (to_N b)).
-Definition word_shiftr {aw bw : N} (a : word aw) (b : word bw) : word aw :=
- to_word (N.shiftr (to_N a) (to_N b)).
-
-Definition word_cast (from to : N) (a : word from) : word to :=
- to_word (to_N a).
+ (aw bw : Z)
+ (op : Z -> Z -> Z)
+ (a : word aw) (b : word bw) : word (Z.max aw bw) :=
+ to_word (op (to_Z a) (to_Z b)).
+
+Definition word_add {aw bw : Z} := word_binop aw bw Z.add.
+Definition word_sub {aw bw : Z} := word_binop aw bw Z.sub.
+Definition word_mul {aw bw : Z} := word_binop aw bw Z.mul.
+Definition word_div {aw bw : Z} := word_binop aw bw Z.div.
+
+(** The type of a shift expression in C is the type of the first argument, not
+ the larger of the two type. *)
+Definition word_shiftl {aw bw : Z} (a : word aw) (b : word bw) : word aw :=
+ to_word (Z.shiftl (to_Z a) (to_Z b)).
+Definition word_shiftr {aw bw : Z} (a : word aw) (b : word bw) : word aw :=
+ to_word (Z.shiftr (to_Z a) (to_Z b)).
+
+Definition word_cast (from to : Z) (a : word from) : word to :=
+ to_word (to_Z a).
Definition u16 := word 16.
Definition u32 := word 32.
Definition u64 := word 64.
-Definition to_u16 {bits : N} := word_cast bits 16.
-Definition to_u32 {bits : N} := word_cast bits 32.
-Definition to_u64 {bits : N} := word_cast bits 64.
+Definition to_u16 {bits : Z} := word_cast bits 16.
+Definition to_u32 {bits : Z} := word_cast bits 32.
+Definition to_u64 {bits : Z} := word_cast bits 64.
-(* [to_u32'] is identical to [to_u32], but it is expected to be a
- subtraction of 2^32, rather than a no-op. *)
-Definition to_u32' {bits : N} (a : word bits) : u32 :=
- Word 32 (wrap' 32 (to_N a)).
+Definition Z_to_u16 := @to_word 16.
+Definition Z_to_u32 := @to_word 32.
+Definition Z_to_u64 := @to_word 64.
-Definition N_to_u16 := @to_word 16.
-Definition N_to_u32 := @to_word 32.
-Definition N_to_u64 := @to_word 64.
+(** [to_u32'] is identical to [to_u32], but it is expected to be a subtraction
+ of 2^32, rather than a no-op. *)
+Definition to_u32' {bits : Z} (a : word bits) : u32 :=
+ Z_to_u32 ((to_Z a) - 2^32).
Delimit Scope word_scope with word.
Bind Scope word_scope with word.
@@ -87,14 +85,22 @@ Infix "*" := word_mul : word_scope.
Infix "/" := word_div : word_scope.
Infix "<<" := word_shiftl : word_scope.
Infix ">>" := word_shiftr : word_scope.
-Notation "1" := (N_to_u32 1) : word_scope.
-Notation "32" := (N_to_u32 32) : word_scope.
+Notation "1" := (Z_to_u32 1) : word_scope.
+Notation "32" := (Z_to_u32 32) : word_scope.
+
+Definition num_bits (a : Z) : Z :=
+ match a with
+ | Z0 => 0
+ | Zpos p => Zpos (Pos.size p)
+ | Zneg _ => 0 (* A [u32] is always non-negative, so this does not come up. *)
+ end.
Definition BN_num_bits_word (a : u32) : u32 :=
- N_to_u32 (N.size (to_N a)).
+ Z_to_u32 (num_bits (to_Z a)).
-(* See [mod_u16] in https://boringssl-review.googlesource.com/#/c/boringssl/+/25887. *)
+(** See [mod_u16] in
+ https://boringssl-review.googlesource.com/#/c/boringssl/+/25887. *)
Definition mod_u16 (n : u32) (d : u16) : u16 :=
let p := to_u32 (BN_num_bits_word (to_u32 (d - 1))) in
let m := to_u32' ((((to_u64 1) << (32 + p)) + d - 1) / d) in
@@ -105,347 +111,275 @@ Definition mod_u16 (n : u32) (d : u16) : u16 :=
to_u16 n.
-(* Proofs *)
+(** Proofs *)
-Lemma wrap_bound (a b : N) :
- wrap a b < 2^a.
+Lemma num_bits_log2_up (a : Z) :
+ num_bits (a - 1) = Z.log2_up a.
+Proof.
+ unfold Z.log2_up, Z.log2, num_bits.
+ destruct (1 ?= a) eqn:H.
+ { (* [1 = a] *)
+ rewrite Z.compare_eq_iff in H.
+ rewrite <- H.
+ reflexivity. }
+ { (* [1 < a] *)
+ rewrite Z.compare_lt_iff in H.
+ assert (0 < a - 1) by lia.
+ rewrite Z.sub_1_r in *.
+ destruct (Z.pred a); try lia.
+ destruct p; cbn; lia. }
+ { (* [1 > a] *)
+ rewrite Z.compare_gt_iff in H.
+ assert (a - 1 < 0) by lia.
+ destruct (a - 1); lia. }
+Qed.
+
+Lemma wrap_bound (a b : Z) (H : 0 < a) :
+ 0 <= wrap a b < 2^a.
Proof.
unfold wrap.
- apply N.mod_upper_bound.
- apply N.pow_nonzero.
+ Z.div_mod_to_quot_rem.
lia.
Qed.
-Lemma remove_outer_wrap (a a' b : N) (H : a' <= a) :
+Lemma remove_outer_wrap (a a' b : Z) (H : 0 < a' <= a) :
wrap a (wrap a' b) = wrap a' b.
Proof.
unfold wrap.
- apply N.mod_small.
- assert (b mod 2^a' < 2^a') by (apply N.mod_upper_bound; apply N.pow_nonzero; lia).
- assert (2^a' <= 2^a) by (apply N.pow_le_mono_r; lia).
+ apply Z.mod_small.
+ Z.div_mod_to_quot_rem.
+ assert (2^a' <= 2^a) by (apply Z.pow_le_mono_r; lia).
lia.
Qed.
-Lemma remove_inner_wrap (a a' b : N) (H : a' >= a) :
+Lemma divide_pow (a b c : Z) (Ha : 0 < a) (Hb : 0 <= b <= c) :
+ (a^b | a^c).
+Proof.
+ exists (a^(c - b)).
+ rewrite <- Z.pow_add_r; try f_equal; lia.
+Qed.
+
+Lemma remove_inner_wrap (a a' b : Z) (H : 0 < a <= a') :
wrap a (wrap a' b) = wrap a b.
Proof.
unfold wrap.
- rewrite <- (N.sub_add a a') by lia.
- rewrite N.pow_add_r.
- rewrite N.mul_comm.
- rewrite N.mod_mul_r by (apply N.pow_nonzero; lia).
- rewrite N.mul_comm.
- rewrite N.mod_add by (apply N.pow_nonzero; lia).
- apply N.mod_mod.
- apply N.pow_nonzero.
- lia.
+ rewrite (Znumtheory.Zmod_div_mod (2^a) (2^a') b);
+ try apply Z.pow_pos_nonneg;
+ try apply divide_pow;
+ lia.
Qed.
-Lemma double_wrap (a a' b : N) :
- wrap a (wrap a' b) = wrap (N.min a a') b.
+Lemma double_wrap (a a' b : Z) (Ha : 0 < a) (Ha' : 0 < a') :
+ wrap a (wrap a' b) = wrap (Z.min a a') b.
Proof.
destruct (a <? a') eqn:H.
- { apply N.ltb_lt in H.
- rewrite N.min_l by lia.
+ { apply Z.ltb_lt in H.
+ rewrite Z.min_l by lia.
apply remove_inner_wrap.
lia. }
- { apply N.ltb_ge in H.
- rewrite N.min_r by lia.
+ { apply Z.ltb_ge in H.
+ rewrite Z.min_r by lia.
apply remove_outer_wrap.
lia. }
Qed.
-(* [lia] and [nia] need some help with exponents, and [cbn] alone
- often unfolds too much. *)
-Ltac unfold_exponents :=
- cbn [N.pow BinPos.Pos.pow BinPos.Pos.iter BinPos.Pos.mul] in *.
-
-Lemma sub_wrapped_noop (bits a b : N) (H : b <= a < 2^bits) :
- sub_wrapped bits a b = a - b.
-Proof.
- unfold sub_wrapped, wrap', wrap.
- rewrite N.add_comm.
- rewrite <- N.add_sub_assoc by lia.
- rewrite N.add_mod by lia.
- rewrite N.mod_same by lia.
- cbn.
- rewrite N.mod_mod by lia.
- apply N.mod_small.
- lia.
-Qed.
-
-Lemma a_minus_b_div_2_plus_b (a b : N) (H : a >= b) :
+Lemma a_minus_b_div_2_plus_b (a b : Z) (H : a >= b) :
(a - b) / 2 + b = (a + b) / 2.
-Proof.
- rewrite <- N.div_add by lia.
- f_equal.
- lia.
-Qed.
-
-Lemma p_bound (d bits : N) (H : 1 < d < 2^bits) :
- 1 <= N.size (d - 1) <= bits.
-Proof.
- rewrite N.size_log2 by lia.
- assert (N.log2 (d - 1) < bits) by (apply N.log2_lt_pow2; lia).
- lia.
-Qed.
+Proof. Z.div_mod_to_quot_rem; lia. Qed.
-Lemma ceil_log2_lt (d : N) (H : 1 < d) :
- 2^(N.size (d - 1)) < d + d.
+Lemma p_bound (d bits : Z) (Hbits : 0 < bits) (H : 1 < d < 2^bits) :
+ 0 < Z.log2_up d <= bits.
Proof.
- pose proof (N.size_le (d - 1)).
- rewrite N.succ_double_spec in *.
- lia.
+ split.
+ - apply Z.log2_up_pos; lia.
+ - apply Z.log2_up_le_pow2; lia.
Qed.
-Lemma ceil_log2_ge (d : N) (H : 1 < d) :
- 2^(N.size (d - 1)) >= d.
+Lemma pow2_log2_up_bounds (d : Z) (H : 1 < d) :
+ d <= 2^(Z.log2_up d) < d + d.
Proof.
- pose proof (N.size_gt (d - 1)).
+ pose proof (Z.log2_up_spec d H).
+ assert (2 ^ Z.log2_up d = 2 * (2 ^ Z.pred (Z.log2_up d))).
+ { rewrite <- Z.pow_succ_r by (apply Z.lt_le_pred; apply Z.log2_up_pos; lia).
+ f_equal.
+ lia. }
lia.
Qed.
-Lemma m_bound (d bits : N) (H : 1 < d < 2^bits) :
- 2^bits <= (2^(bits + N.size (d - 1)) + d - 1) / d < 2 * 2^bits.
+Lemma m_bound (d bits : Z) (Hbits : 0 < bits) (H : 1 < d < 2^bits) :
+ 2^bits <= (2^(bits + Z.log2_up d) + d - 1) / d < 2 * 2^bits.
Proof.
- rewrite !N.pow_add_r.
+ rewrite !Z.pow_add_r by (auto with zarith).
+ pose proof (pow2_log2_up_bounds d (proj1 H)).
split.
- { (* 2^e <= m *)
- apply N.div_le_lower_bound; try lia.
- assert (2^(N.size (d - 1)) >= d).
- { apply ceil_log2_ge; lia. }
- nia. }
- { (* m < 2 * 2^e *)
- apply N.div_lt_upper_bound; try lia.
- assert (2^(N.size (d - 1)) < d + d).
- { apply ceil_log2_lt; lia. }
- (* Subtraction is saturating, so nia needs a bit of help to group
- [d - 1] together. *)
- rewrite <- N.add_sub_assoc by lia.
- nia. }
-Qed.
-
-Lemma mod_sub_once (a b : N) (H : b <= a < 2*b) :
- a mod b = a - b.
-Proof.
- assert (1 = a / b).
- { apply (N.div_unique a b 1 (a - b)); lia. }
- rewrite N.mod_eq; nia.
+ - apply Z.div_le_lower_bound; nia.
+ - apply Z.div_lt_upper_bound; nia.
Qed.
-Lemma division_shrinks (a b c : N) (Hnz : b <> 0) (H : a < c) :
+Lemma division_shrinks (a b c : Z) (Hb : 0 < b) (H : 0 <= a < c) :
a / b < c.
-Proof.
- apply N.div_lt_upper_bound; try lia.
- nia.
-Qed.
+Proof. Z.div_mod_to_quot_rem; nia. Qed.
-Lemma mul_div_ceil_ge (a b : N) (H : b <> 0) :
+Lemma mul_div_ceil_ge (a b : Z) (Hb : 0 < b) :
a <= b * ((a + b - 1) / b).
-Proof.
- rewrite (N.div_mod a b) by lia.
- set (q := a / b).
- set (r := a mod b).
- assert (r < b) by (apply N.mod_upper_bound; lia).
- destruct (r =? 0) eqn:Hr.
- { (* r = 0 *)
- apply N.eqb_eq in Hr.
- replace r with 0.
- rewrite !N.add_0_r.
- replace ((b * q + b - 1) / b) with q by
- (rewrite <- N.add_sub_assoc by lia;
- apply (N.div_unique _ b q (b - 1));
- lia).
- lia. }
- { (* r != 0 *)
- apply N.eqb_neq in Hr.
- replace ((b * q + r + b - 1) / b) with (q + 1) by
- (replace (b * q + r + b - 1) with (b * (q + 1) + (r - 1)) by lia;
- apply (N.div_unique _ b (q + 1) (r - 1));
- lia).
- lia. }
-Qed.
+Proof. Z.div_mod_to_quot_rem; nia. Qed.
-(* This is the main theorem this algorithm relies on. *)
-Lemma division_algorithm (bits n d k : N)
- (Hn : n < 2^bits)
+(** This is the main theorem this algorithm relies on. *)
+Lemma division_algorithm (bits n d k : Z)
+ (Hbits : 0 < bits)
+ (Hn : 0 <= n < 2^bits)
(Hd : 1 < d)
- (Hk : bits + N.size (d - 1) <= k) :
+ (Hk : bits + Z.log2_up d <= k) :
n * ((2^k + d - 1) / d) / 2^k = n / d.
Proof.
+ (* Help [lia] and [nia] out. *)
+ pose proof (Z.log2_up_nonneg d).
+ pose proof (pow2_log2_up_bounds d Hd).
+ assert (0 < 2^k) by (apply Z.pow_pos_nonneg; lia).
+
assert (n/d <= n * ((2^k + d - 1) / d) / 2^k).
{ (* This direction is true independent of [Hk]. *)
- apply N.div_le_lower_bound; try apply N.pow_nonzero; try lia.
- rewrite N.div_mul_le by lia.
- apply N.div_le_upper_bound; try apply N.pow_nonzero; try lia.
- rewrite (N.mul_comm n _).
- rewrite N.mul_assoc.
assert (2^k <= d * ((2 ^ k + d - 1) / d)) by (apply mul_div_ceil_ge; lia).
+ apply Z.div_le_lower_bound; try lia.
+ rewrite Z.div_mul_le by (try auto with zarith; lia).
+ apply Z.div_le_upper_bound; try lia.
nia. }
assert (n * ((2^k + d - 1) / d) / 2^k <= n / d).
- { assert (div_mul_le_inner : n * ((2 ^ k + d - 1) / d) / 2^k <=
- (n * (2 ^ k + d - 1) / d) / 2 ^ k).
- { apply N.div_le_mono; try (apply N.pow_nonzero; lia).
- apply N.div_mul_le.
- lia. }
+ { assert (div_mul_le_inner : n * ((2^k + d - 1) / d) / 2^k <=
+ (n * (2^k + d - 1) / d) / 2^k).
+ { apply Z.div_le_mono; try apply Z.div_mul_le; lia. }
rewrite div_mul_le_inner.
replace (n * (2^k + d - 1)) with (n * 2^k + n * (d - 1)) by nia.
(* Swap [d] and [2^k] in the denominator. *)
- rewrite N.div_div by (try apply N.pow_nonzero; lia).
- rewrite (N.mul_comm d (2^k)).
- rewrite <- N.div_div by (try apply N.pow_nonzero; lia).
+ rewrite Z.div_div by lia.
+ rewrite (Z.mul_comm d (2^k)).
+ rewrite <- Z.div_div by lia.
(* We now can extract the [n * 2^k] term from the inner division. *)
- rewrite N.div_add_l by (try apply N.pow_nonzero; lia).
+ rewrite Z.div_add_l by lia.
(* We show the error term is zero, which allows the inequality to hold. *)
assert (zero_error : n * (d - 1) / 2 ^ k = 0).
- { apply N.div_small.
- assert (2^bits * 2^(N.size (d - 1)) <= 2^k) by
- (rewrite <- N.pow_add_r; apply N.pow_le_mono_r; lia).
- assert (d <= 2^(N.size (d - 1))) by (apply N.ge_le; apply ceil_log2_ge; lia).
+ { apply Z.div_small.
+ assert (2^bits * 2^(Z.log2_up d) <= 2^k) by
+ (rewrite <- Z.pow_add_r; try apply Z.pow_le_mono_r; lia).
nia. }
rewrite zero_error.
- rewrite N.add_0_r.
- lia. }
+ auto with zarith. }
lia.
Qed.
-(* Peel off one innermost [wrap], and replace it with the expected
- bounds. Before doing so, introduce a lemma into the context that
- the value is bounded. This is allows later goals to make use of
- earlier bounds. *)
+(** Peel off one innermost [wrap], and replace it with the expected
+ bounds. Before doing so, introduce a lemma into the context that the value
+ is bounded. This is allows later goals to make use of earlier bounds.
+
+ TODO(davidben): Most of the slowness comes from these intermediate
+ hypotheses. The proofs need them, but only those one or two layers down. Can
+ we be smarter about retaining them? *)
Local Ltac replace_innermost_wrap expr :=
match expr with
| context[wrap ?x ?y] =>
first [ replace_innermost_wrap y |
- assert ((wrap x y) < 2^x) by (apply wrap_bound);
+ assert (0 <= (wrap x y) < 2^x) by (apply wrap_bound; lia);
replace (wrap x y) with y in * ]
- | context[wrap' ?x ?y] =>
- first [ replace_innermost_wrap y |
- assert ((wrap' x y) < 2^x) by (unfold wrap'; apply wrap_bound);
- replace (wrap' x y) with (y - 2^x) in * ]
- | context[sub_wrapped ?x ?y ?z] =>
- first [ replace_innermost_wrap y |
- replace_innermost_wrap z |
- assert ((sub_wrapped x y z) < 2^x) by (unfold sub_wrapped, wrap';
- apply wrap_bound);
- replace (sub_wrapped x y z) with (y - z) in * ]
end.
-Theorem mod_u16_spec (n d : N) (Hn : n < 2^32) (Hd : 1 < d < 2^16) :
- to_N (mod_u16 (N_to_u32 n) (N_to_u16 d)) = n mod d.
+(** Finally, the correctness proof. *)
+Theorem mod_u16_spec (n d : Z) (Hn : 0 <= n < 2^32) (Hd : 1 < d < 2^16) :
+ to_Z (mod_u16 (Z_to_u32 n) (Z_to_u16 d)) = n mod d.
Proof.
unfold mod_u16.
(* Evaluate away everything except for exponentations, which are
- still useful, and our [wrap], etc., metadata. *)
- cbn -[N.pow wrap wrap' sub_wrapped].
+ still useful, and our [wrap] metadata. *)
+ cbn -[Z.pow num_bits wrap].
(* The solver needs some help getting these in the context. *)
- assert (1 <= N.size (d - 1) <= 16) by (apply p_bound; lia).
- assert (2^(N.size (d - 1)) < d + d) by (apply ceil_log2_lt; lia).
- assert (2^(N.size (d - 1)) < 2 ^ 17) by (unfold_exponents; lia).
- assert (2^32 <= (2^(32 + N.size (d - 1)) + d - 1) / d < 2 * 2^32) by
- (apply m_bound; unfold_exponents; lia).
+ assert (Hlog : 0 < Z.log2_up d <= 16) by (apply p_bound; lia).
+ assert (d <= 2^(Z.log2_up d) < d + d) by (apply pow2_log2_up_bounds; lia).
+ assert (2^(Z.log2_up d) < 2 ^ 17) by lia.
+ assert (2^32 <= (2^32 * 2^(Z.log2_up d) + d - 1) / d < 2 * 2^32) by
+ (rewrite <- Z.pow_add_r by lia; apply m_bound; lia).
(* Convert shifts to multiplication and divison. *)
- rewrite !N.shiftl_mul_pow2.
- rewrite !N.shiftr_div_pow2.
- rewrite !N.mul_1_l.
- rewrite !N.pow_1_r.
+ rewrite !Z.shiftl_mul_pow2 by (apply wrap_bound; lia).
+ rewrite !Z.shiftr_div_pow2 by (apply wrap_bound; lia).
+ rewrite !Z.mul_1_l.
+ rewrite !Z.pow_1_r.
(* Replace all the [to_word] calls with their expected bounds. *)
repeat match goal with
| [ |- ?x = n mod d ] =>
match goal with
- (* Evaluate away the [N.min]s. *)
- | _ => progress cbn -[N.pow N.add N.mul wrap wrap']
- (* Remove obviously pointless casts, before the solver expands too
- much. *)
- | [ |- context[wrap ?x (wrap' ?x ?y)] ] =>
- replace (wrap x (wrap' x y)) with (wrap' x y) by
- (unfold wrap'; symmetry; apply remove_outer_wrap; lia)
- | [ |- context[wrap ?x (sub_wrapped ?x ?y ?z)] ] =>
- replace (wrap x (sub_wrapped x y z)) with (sub_wrapped x y z) by
- (unfold sub_wrapped, wrap';
- symmetry;
- apply remove_outer_wrap;
- lia)
+ (* Evaluate away the [Z.min]s. *)
+ | _ => progress cbn -[Z.pow Z.add Z.mul wrap]
+ (* Remove obviously pointless casts. The solver could handle them,
+ but this cuts down on the number of generated hypotheses. *)
| [ |- context[wrap ?x (wrap ?x' ?y)] ] =>
- replace (wrap x (wrap x' y)) with (wrap (N.min x x') y) by
+ replace (wrap x (wrap x' y)) with (wrap (Z.min x x') y) by
(symmetry; apply double_wrap; lia)
- (* Perform [a_minus_b_div_2_plus_b] earlier, so the
- generated hypotheses match too. *)
+ | [ |- context[wrap ?x n] ] =>
+ replace (wrap x n) with n by (unfold wrap; rewrite Z.mod_small; lia)
+ | [ |- context[wrap ?x d] ] =>
+ replace (wrap x d) with d by (unfold wrap; rewrite Z.mod_small; lia)
+ | [ |- context[wrap ?x (Zpos ?p)] ] =>
+ replace (wrap x (Zpos p)) with (Zpos p) by reflexivity
+ (* Perform rewrites as soon as possible, so the generated hypotheses
+ match too. *)
| _ => rewrite a_minus_b_div_2_plus_b
+ | _ => rewrite num_bits_log2_up
+ | _ => rewrite Z.pow_add_r by lia
| _ => replace_innermost_wrap x
end
- | _ => rewrite sub_wrapped_noop
end;
- repeat unfold sub_wrapped, wrap', wrap in *;
+ repeat unfold wrap in *;
repeat match goal with
(* Normalize a few things. *)
- | _ => apply N.le_ge
- | _ => apply N.lt_gt
+ | _ => apply Z.le_ge
+ | _ => apply Z.lt_gt
| [ |- _ = _ mod _ ] => symmetry
(* Apply these transforms everywhere. *)
- | _ => apply N.mod_small
+ | _ => apply Z.mod_small
| _ => split
+ (* Everything is positive, so these tend to be easy. *)
+ | [ |- 0 <= _ * _ ] => nia
(* Don't touch the goals that involve the final division
expression. We'll dispatch them separately. *)
| [ |- context[d * (_ / 2^_)] ] => fail 1
- (* The solver needs some help for this one. *)
- | [ _ : ?x < ?y |- ?x / _ < ?y ] => apply division_shrinks
- (* Smash the bounds subgoals away. *)
- | _ => apply mod_sub_once
- | _ => rewrite N.pow_add_r in *
- | _ => apply N.div_le_upper_bound
- | _ => apply N.div_lt_upper_bound
- | _ => apply N.pow_nonzero
- | _ => apply N.mul_le_mono_r
- | _ => progress unfold_exponents
- (* Erase irrelevant hypotheses before we start running the
- slow solvers. *)
- | [ H : context[n] |- _ ] => match goal with
- | [ |- context[n] ] => fail 1
- | _ => clear H
- end
- | [ H : context[d] |- _ ] => match goal with
- | [ |- context[d] ] => fail 1
- | _ => clear H
- end
+ (* The solver needs some help for this one, otherwise the [Z.div_*]
+ theorems below break apart the numerator term. *)
+ | [ _ : 0 <= ?x < ?y |- ?x / _ < ?y ] => apply division_shrinks
+ (* Simplify the various inequalities. *)
+ | _ => apply Z.div_le_upper_bound
+ | _ => apply Z.div_lt_upper_bound
+ | _ => apply Z.div_le_lower_bound
+ | _ => apply Z.pow_nonneg
+ | _ => apply Z.pow_pos_nonneg
+ (* Only apply in one direction. *)
+ | [ |- 0 <= _ - _ ] => apply Z.le_0_sub
+ | _ => apply Z.mul_le_mono_nonneg_r
(* [nia] is rather slow, and only one case needs it. *)
- | [ _ : ?x < _, _ : ?y < _ |- ?x * ?y < ?z ] => nia
+ | [ _ : 0 <= ?x < _, _ : 0 <= ?y < _ |- ?x * ?y < ?z ] => nia
| _ => lia
end;
+ (* Clear the now unnecessary intermediate hypotheses. *)
+ clear - Hn Hd Hlog;
+
(* Invert the various transforms to get to the main theorem. *)
repeat match goal with
- | [ |- _ <= N.size _ ] => rewrite N.size_log2; lia
- | [ |- context[?x * ?y + ?z * ?x]] => rewrite (N.mul_comm z x)
- | [ |- context[?x + ?y - ?x]] => rewrite (N.add_comm x y)
- | _ => rewrite N.div_div
- | _ => apply N.pow_nonzero
- | _ => rewrite <- N.pow_succ_r'
- | _ => rewrite <- N.add_1_r
- | _ => rewrite <- N.pow_add_r
- | _ => rewrite <- N.mul_add_distr_l
- | _ => rewrite N.sub_add
- | _ => rewrite N.add_sub
- | _ => rewrite N.add_sub_assoc
- | _ => rewrite <- N.div_add_l
- | _ => lia
+ | _ => rewrite Z.div_div by (try apply Z.pow_pos_nonneg; lia)
+ | _ => rewrite <- Z.pow_succ_r by lia
+ | _ => rewrite <- Z.add_1_r
+ | _ => rewrite Z.sub_add
+ | _ => rewrite Zplus_minus
+ | _ => rewrite <- Z.pow_add_r by lia
+ | _ => rewrite <- Z.mul_add_distr_l
+ | _ => rewrite <- Z.div_add_l by lia
+ | [ |- context[?x * ?y + ?z * ?x] ] => rewrite (Z.mul_comm z x)
end;
(* And finish with the main theorem. *)
rewrite (division_algorithm 32) by lia;
- try rewrite <- N.mod_eq;
- try apply N.mul_div_le;
- (* [lia] needs some help on these two. *)
- try match goal with
- | [ |- ?x mod ?y < 2^_ ] =>
- assert (x mod y < y) by (apply N.mod_upper_bound; lia)
- | [ |- ?x * (?y / ?x) < 2^_ ] =>
- assert (x * (y / x) <= y) by (apply N.mul_div_le; lia)
- end;
- lia.
+ Z.div_mod_to_quot_rem;
+ nia.
Qed.