aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/Pow2BaseProofs.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/Pow2BaseProofs.v')
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v135
1 files changed, 65 insertions, 70 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index a9b9fbdfc..da9bbac0d 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -35,7 +35,7 @@ Section Pow2BaseProofs.
Lemma two_sum_firstn_limb_widths_nonzero n : 2^sum_firstn limb_widths n <> 0.
Proof. pose proof (two_sum_firstn_limb_widths_pos n); omega. Qed.
- Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length limb_widths)%nat ->
nth_error base i = Some b ->
nth_error limb_widths i = Some w ->
nth_error base (S i) = Some (two_p w * b).
@@ -43,7 +43,7 @@ Section Pow2BaseProofs.
induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
[rewrite (@nil_length0 Z) in *; omega | ].
- simpl in *; rewrite map_length in *.
+ simpl in *.
case_eq i; intros; subst.
+ subst; apply nth_error_first in nth_err_w.
apply nth_error_first in nth_err_b; subst.
@@ -61,15 +61,14 @@ Section Pow2BaseProofs.
Qed.
- Lemma nth_error_base : forall i, (i < length base)%nat ->
+ Lemma nth_error_base : forall i, (i < length limb_widths)%nat ->
nth_error base i = Some (two_p (sum_firstn limb_widths i)).
Proof.
induction i; intros.
+ unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
- + assert (i < length base)%nat as lt_i_length by omega.
+ + assert (i < length limb_widths)%nat as lt_i_length by omega.
specialize (IHi lt_i_length).
- rewrite base_from_limb_widths_length in lt_i_length.
destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
erewrite base_from_limb_widths_step; eauto.
f_equal.
@@ -87,19 +86,16 @@ Section Pow2BaseProofs.
eapply nth_error_value_In; eauto.
Qed.
- Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ Lemma nth_default_base : forall d i, (i < length limb_widths)%nat ->
nth_default d base i = 2 ^ (sum_firstn limb_widths i).
Proof.
intros ? ? i_lt_length.
- destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
- unfold nth_default.
- rewrite nth_err_x.
- rewrite nth_error_base in nth_err_x by assumption.
- rewrite two_p_correct in nth_err_x.
- congruence.
+ apply nth_error_value_eq_nth_default.
+ rewrite nth_error_base, two_p_correct by assumption.
+ reflexivity.
Qed.
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ Lemma base_succ : forall i, ((S i) < length limb_widths)%nat ->
nth_default 0 base (S i) mod nth_default 0 base i = 0.
Proof.
intros.
@@ -112,8 +108,7 @@ Section Pow2BaseProofs.
apply limb_widths_nonneg.
rewrite lw_eq.
apply in_eq.
- + assert (i < length base)%nat as i_lt_length by omega.
- rewrite base_from_limb_widths_length in *.
+ + assert (i < length limb_widths)%nat as i_lt_length by omega.
apply nth_error_length_exists_value in i_lt_length.
destruct i_lt_length as [x nth_err_x].
erewrite sum_firstn_succ; eauto.
@@ -127,6 +122,7 @@ Section Pow2BaseProofs.
Proof.
intros i b nth_err_b.
pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite base_from_limb_widths_length in *.
rewrite nth_error_base in nth_err_b by assumption.
rewrite two_p_correct in nth_err_b.
congruence.
@@ -169,19 +165,19 @@ Section Pow2BaseProofs.
Section make_base_vector.
Local Notation k := (sum_firstn limb_widths (length limb_widths)).
Context (limb_widths_match_modulus : forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i + j >= length limb_widths)%nat ->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i + j >= length base)%nat ->
let w_sum := sum_firstn limb_widths in
- k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j)
+ k + w_sum (i + j - length base)%nat <= w_sum i + w_sum j)
(limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
sum_firstn limb_widths (i + j) <=
sum_firstn limb_widths i + sum_firstn limb_widths j).
Lemma base_matches_modulus: forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i+j >= length limb_widths)%nat->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
let b := nth_default 0 base in
let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
b i * b j = r * (2^k * b (i+j-length base)%nat).
@@ -189,20 +185,20 @@ Section Pow2BaseProofs.
intros.
rewrite (Z.mul_comm r).
subst r.
+ rewrite base_from_limb_widths_length in *;
assert (i + j - length limb_widths < length limb_widths)%nat by omega.
- rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos;
- subst b; rewrite ?nth_default_base; zero_bounds; rewrite ?base_from_limb_widths_length;
- auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg).
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; subst b; rewrite ?nth_default_base; zero_bounds;
+ assumption).
rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
f_equal.
subst b.
- repeat rewrite nth_default_base by (rewrite ?base_from_limb_widths_length; auto).
+ repeat rewrite nth_default_base by auto.
do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
symmetry.
apply Z.mod_same_pow.
split.
+ apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg.
- + rewrite base_from_limb_widths_length; auto using limb_widths_nonneg, limb_widths_match_modulus.
+ + auto using limb_widths_match_modulus.
Qed.
Lemma base_good : forall i j : nat,
@@ -212,7 +208,9 @@ Section Pow2BaseProofs.
b i * b j = r * b (i + j)%nat.
Proof.
intros; subst b r.
- repeat rewrite nth_default_base by (omega || auto).
+ clear limb_widths_match_modulus.
+ rewrite base_from_limb_widths_length in *.
+ repeat rewrite nth_default_base by omega.
rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds;
auto using sum_firstn_limb_widths_nonneg).
@@ -220,10 +218,11 @@ Section Pow2BaseProofs.
rewrite Z.mod_same_pow; try ring.
split; [ auto using sum_firstn_limb_widths_nonneg | ].
apply limb_widths_good.
- rewrite <-base_from_limb_widths_length; auto using limb_widths_nonneg.
+ assumption.
Qed.
End make_base_vector.
End Pow2BaseProofs.
+Hint Rewrite @base_from_limb_widths_length : distr_length.
Section BitwiseDecodeEncode.
Context {limb_widths} (bv : BaseSystem.BaseVector (base_from_limb_widths limb_widths))
@@ -233,7 +232,7 @@ Section BitwiseDecodeEncode.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation upper_bound := (upper_bound limb_widths).
- Lemma encode'_spec : forall x i, (i <= length base)%nat ->
+ Lemma encode'_spec : forall x i, (i <= length limb_widths)%nat ->
encode' limb_widths x i = BaseSystem.encode' base x upper_bound i.
Proof.
induction i; intros.
@@ -241,13 +240,12 @@ Section BitwiseDecodeEncode.
+ rewrite encode'_succ, <-IHi by omega.
simpl; do 2 f_equal.
rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg.
- match goal with H : (S _ <= length base)%nat |- _ =>
+ match goal with H : (S _ <= length limb_widths)%nat |- _ =>
apply le_lt_or_eq in H; destruct H end.
- repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity.
- repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity].
- rewrite nth_default_out_of_bounds by omega.
+ rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- rewrite <-base_from_limb_widths_length by auto.
congruence.
Qed.
@@ -259,20 +257,17 @@ Section BitwiseDecodeEncode.
Lemma base_upper_bound_compatible : @base_max_succ_divide base upper_bound.
Proof.
unfold base_max_succ_divide; intros i lt_Si_length.
+ rewrite base_from_limb_widths_length in lt_Si_length.
rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length;
rewrite !nth_default_base by (omega || auto).
- + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
- + rewrite nth_default_out_of_bounds by omega.
+ + rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- replace (length limb_widths) with (S (pred (length limb_widths))) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- replace i with (pred (length limb_widths)) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ replace (length limb_widths) with (S (pred (length limb_widths))) by omega.
+ replace i with (pred (length limb_widths)) by omega.
+ erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
Qed.
@@ -282,7 +277,7 @@ Section BitwiseDecodeEncode.
BaseSystem.decode base (encodeZ limb_widths x) = x mod upper_bound.
Proof.
intros.
- assert (length base = length limb_widths) by auto using base_from_limb_widths_length.
+ assert (length base = length limb_widths) by distr_length.
unfold encodeZ; rewrite encode'_spec by omega.
rewrite BaseSystemProofs.encode'_spec; unfold Pow2Base.upper_bound; try zero_bounds;
auto using sum_firstn_limb_widths_nonneg.
@@ -522,7 +517,7 @@ Section carrying_helper.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat ->
+ Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (update_nth n f us) =
(let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us.
Proof.
@@ -541,17 +536,17 @@ Section carrying_helper.
erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption.
rewrite firstn_length in Heqn0.
rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite (@nth_default_out_of_bounds _ _ base) by auto.
- rewrite skipn_all by omega.
+ destruct (le_lt_dec (length limb_widths) n). {
+ rewrite (@nth_default_out_of_bounds _ _ base) by (distr_length; auto).
+ rewrite skipn_all by (rewrite base_from_limb_widths_length; omega).
do 2 rewrite decode_base_nil.
ring_simplify; auto.
} {
- rewrite (skipn_nth_default n base 0) by omega.
+ rewrite (skipn_nth_default n base 0) by (distr_length; omega).
do 2 rewrite decode'_cons.
ring_simplify; ring.
} }
- { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify.
+ { rewrite (nth_default_out_of_bounds _ base) by (distr_length; omega); ring_simplify.
etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ].
apply f_equal.
autorewrite with push_firstn simpl_update_nth.
@@ -640,12 +635,12 @@ Section carrying_helper.
Hint Rewrite @length_add_to_nth : distr_length.
- Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (set_nth n x us) =
(x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed.
- Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (add_to_nth n x us) =
x * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; rewrite add_to_nth_set_nth, set_nth_sum; try ring_simplify; auto. Qed.
@@ -697,7 +692,7 @@ Section carrying.
Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
Hint Rewrite @length_carry_simple : distr_length.
- Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ Lemma nth_default_base_succ : forall i, (S i < length limb_widths)%nat ->
nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
Proof.
intros.
@@ -706,13 +701,13 @@ Section carrying.
Qed.
Lemma carry_gen_decode_eq : forall fc fi i' us
- (i := fi (length base) i')
- (Si := fi (length base) (S i)),
- (length us = length base) ->
+ (i := fi i')
+ (Si := fi (S i)),
+ (length us = length limb_widths) ->
BaseSystem.decode base (carry_gen limb_widths fc fi i' us)
= (fc (nth_default 0 us i / 2 ^ log_cap i) *
(if eq_nat_dec Si (S i)
- then if lt_dec (S i) (length base)
+ then if lt_dec (S i) (length limb_widths)
then 2 ^ log_cap i * nth_default 0 base i
else 0
else nth_default 0 base Si)
@@ -720,29 +715,29 @@ Section carrying.
+ BaseSystem.decode base us.
Proof.
intros fc fi i' us i Si H; intros.
- destruct (eq_nat_dec 0 (length base));
+ destruct (eq_nat_dec 0 (length limb_widths));
[ destruct limb_widths, us, i; simpl in *; try congruence;
break_match;
unfold carry_gen, carry_single, add_to_nth;
autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
reflexivity
| ].
- (*assert (0 <= i < length base)%nat by (subst i; auto with arith).*)
+ (*assert (0 <= i < length limb_widths)%nat by (subst i; auto with arith).*)
assert (0 <= log_cap i) by auto using log_cap_nonneg.
assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
unfold carry_gen, carry_single.
- rewrite H; change (i' mod length base)%nat with i.
+ change (i' mod length limb_widths)%nat with i.
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
rewrite set_nth_sum by omega.
unfold Z.pow2_mod.
rewrite Z.land_ones by auto using log_cap_nonneg.
rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
- change (fi (length base) i') with i.
+ change (fi i') with i.
subst Si.
repeat first [ ring
| match goal with H : _ = _ |- _ => rewrite !H in * end
| rewrite nth_default_base_succ by omega
- | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by (distr_length; omega)
| rewrite !(nth_default_out_of_bounds _ us) by omega
| rewrite Z.mod_eq by assumption
| progress distr_length
@@ -751,8 +746,8 @@ Section carrying.
Qed.
Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
+ (length us = length limb_widths) ->
+ (i < (pred (length limb_widths)))%nat ->
BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
Proof.
unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption.
@@ -791,11 +786,11 @@ Section carrying.
Lemma nth_default_carry_gen_full fc fi d i n us
: nth_default d (carry_gen limb_widths fc fi i us) n
= if lt_dec n (length us)
- then (if eq_nat_dec n (fi (length us) i)
+ then (if eq_nat_dec n (fi i)
then Z.pow2_mod (nth_default 0 us n) (log_cap n)
else nth_default 0 us n) +
- if eq_nat_dec n (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec n (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0
else d.
Proof.
@@ -827,11 +822,11 @@ Section carrying.
: forall fc fi i us,
(0 <= i < length us)%nat
-> nth_default 0 (carry_gen limb_widths fc fi i us) i
- = (if eq_nat_dec i (fi (length us) i)
+ = (if eq_nat_dec i (fi i)
then Z.pow2_mod (nth_default 0 us i) (log_cap i)
else nth_default 0 us i) +
- if eq_nat_dec i (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec i (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0.
Proof.
intros; autorewrite with push_nth_default natsimplify; break_match; omega.
@@ -849,7 +844,7 @@ Section carrying.
Hint Rewrite @nth_default_carry_simple using (omega || distr_length; omega) : push_nth_default.
End carrying.
-Hint Rewrite @length_carry_gen @base_from_limb_widths_length : distr_length.
+Hint Rewrite @length_carry_gen : distr_length.
Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
Hint Rewrite @nth_default_carry_simple_full @nth_default_carry_gen_full : push_nth_default.
Hint Rewrite @nth_default_carry_simple @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.