aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/Pow2BaseProofs.v
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-25 21:06:07 -0400
commit39a6c95de8a900c859726d875cc40ea96298d31b (patch)
tree750571dc101f477c34340716db87a3697cca41eb /src/ModularArithmetic/Pow2BaseProofs.v
parentea9397e3da37f35d088488be141cb18cc38ea11b (diff)
Put ModularBaseSystem carries in terms of [carry_gen], and pushed this change through the pipeline. Also began the process of redoing canonicalization proofs, attempting to put the messy case analysis in theorem statements rather than separate lemmas.
Diffstat (limited to 'src/ModularArithmetic/Pow2BaseProofs.v')
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v135
1 files changed, 65 insertions, 70 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index 9255f033f..4b616c288 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -34,7 +34,7 @@ Section Pow2BaseProofs.
Lemma two_sum_firstn_limb_widths_nonzero n : 2^sum_firstn limb_widths n <> 0.
Proof. pose proof (two_sum_firstn_limb_widths_pos n); omega. Qed.
- Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length limb_widths)%nat ->
nth_error base i = Some b ->
nth_error limb_widths i = Some w ->
nth_error base (S i) = Some (two_p w * b).
@@ -42,7 +42,7 @@ Section Pow2BaseProofs.
induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
[rewrite (@nil_length0 Z) in *; omega | ].
- simpl in *; rewrite map_length in *.
+ simpl in *.
case_eq i; intros; subst.
+ subst; apply nth_error_first in nth_err_w.
apply nth_error_first in nth_err_b; subst.
@@ -60,15 +60,14 @@ Section Pow2BaseProofs.
Qed.
- Lemma nth_error_base : forall i, (i < length base)%nat ->
+ Lemma nth_error_base : forall i, (i < length limb_widths)%nat ->
nth_error base i = Some (two_p (sum_firstn limb_widths i)).
Proof.
induction i; intros.
+ unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
- + assert (i < length base)%nat as lt_i_length by omega.
+ + assert (i < length limb_widths)%nat as lt_i_length by omega.
specialize (IHi lt_i_length).
- rewrite base_from_limb_widths_length in lt_i_length.
destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
erewrite base_from_limb_widths_step; eauto.
f_equal.
@@ -86,19 +85,16 @@ Section Pow2BaseProofs.
eapply nth_error_value_In; eauto.
Qed.
- Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ Lemma nth_default_base : forall d i, (i < length limb_widths)%nat ->
nth_default d base i = 2 ^ (sum_firstn limb_widths i).
Proof.
intros ? ? i_lt_length.
- destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
- unfold nth_default.
- rewrite nth_err_x.
- rewrite nth_error_base in nth_err_x by assumption.
- rewrite two_p_correct in nth_err_x.
- congruence.
+ apply nth_error_value_eq_nth_default.
+ rewrite nth_error_base, two_p_correct by assumption.
+ reflexivity.
Qed.
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ Lemma base_succ : forall i, ((S i) < length limb_widths)%nat ->
nth_default 0 base (S i) mod nth_default 0 base i = 0.
Proof.
intros.
@@ -111,8 +107,7 @@ Section Pow2BaseProofs.
apply limb_widths_nonneg.
rewrite lw_eq.
apply in_eq.
- + assert (i < length base)%nat as i_lt_length by omega.
- rewrite base_from_limb_widths_length in *.
+ + assert (i < length limb_widths)%nat as i_lt_length by omega.
apply nth_error_length_exists_value in i_lt_length.
destruct i_lt_length as [x nth_err_x].
erewrite sum_firstn_succ; eauto.
@@ -126,6 +121,7 @@ Section Pow2BaseProofs.
Proof.
intros i b nth_err_b.
pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite base_from_limb_widths_length in *.
rewrite nth_error_base in nth_err_b by assumption.
rewrite two_p_correct in nth_err_b.
congruence.
@@ -168,19 +164,19 @@ Section Pow2BaseProofs.
Section make_base_vector.
Local Notation k := (sum_firstn limb_widths (length limb_widths)).
Context (limb_widths_match_modulus : forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i + j >= length limb_widths)%nat ->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i + j >= length base)%nat ->
let w_sum := sum_firstn limb_widths in
- k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j)
+ k + w_sum (i + j - length base)%nat <= w_sum i + w_sum j)
(limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
sum_firstn limb_widths (i + j) <=
sum_firstn limb_widths i + sum_firstn limb_widths j).
Lemma base_matches_modulus: forall i j,
- (i < length limb_widths)%nat ->
- (j < length limb_widths)%nat ->
- (i+j >= length limb_widths)%nat->
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
let b := nth_default 0 base in
let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
b i * b j = r * (2^k * b (i+j-length base)%nat).
@@ -188,20 +184,20 @@ Section Pow2BaseProofs.
intros.
rewrite (Z.mul_comm r).
subst r.
+ rewrite base_from_limb_widths_length in *;
assert (i + j - length limb_widths < length limb_widths)%nat by omega.
- rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos;
- subst b; rewrite ?nth_default_base; zero_bounds; rewrite ?base_from_limb_widths_length;
- auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg).
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; subst b; rewrite ?nth_default_base; zero_bounds;
+ assumption).
rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
f_equal.
subst b.
- repeat rewrite nth_default_base by (rewrite ?base_from_limb_widths_length; auto).
+ repeat rewrite nth_default_base by auto.
do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
symmetry.
apply Z.mod_same_pow.
split.
+ apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg.
- + rewrite base_from_limb_widths_length; auto using limb_widths_nonneg, limb_widths_match_modulus.
+ + auto using limb_widths_match_modulus.
Qed.
Lemma base_good : forall i j : nat,
@@ -211,7 +207,9 @@ Section Pow2BaseProofs.
b i * b j = r * b (i + j)%nat.
Proof.
intros; subst b r.
- repeat rewrite nth_default_base by (omega || auto).
+ clear limb_widths_match_modulus.
+ rewrite base_from_limb_widths_length in *.
+ repeat rewrite nth_default_base by omega.
rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds;
auto using sum_firstn_limb_widths_nonneg).
@@ -219,10 +217,11 @@ Section Pow2BaseProofs.
rewrite Z.mod_same_pow; try ring.
split; [ auto using sum_firstn_limb_widths_nonneg | ].
apply limb_widths_good.
- rewrite <-base_from_limb_widths_length; auto using limb_widths_nonneg.
+ assumption.
Qed.
End make_base_vector.
End Pow2BaseProofs.
+Hint Rewrite @base_from_limb_widths_length : distr_length.
Section BitwiseDecodeEncode.
Context {limb_widths} (bv : BaseSystem.BaseVector (base_from_limb_widths limb_widths))
@@ -232,7 +231,7 @@ Section BitwiseDecodeEncode.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation upper_bound := (upper_bound limb_widths).
- Lemma encode'_spec : forall x i, (i <= length base)%nat ->
+ Lemma encode'_spec : forall x i, (i <= length limb_widths)%nat ->
encode' limb_widths x i = BaseSystem.encode' base x upper_bound i.
Proof.
induction i; intros.
@@ -240,13 +239,12 @@ Section BitwiseDecodeEncode.
+ rewrite encode'_succ, <-IHi by omega.
simpl; do 2 f_equal.
rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg.
- match goal with H : (S _ <= length base)%nat |- _ =>
+ match goal with H : (S _ <= length limb_widths)%nat |- _ =>
apply le_lt_or_eq in H; destruct H end.
- repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity.
- repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity].
- rewrite nth_default_out_of_bounds by omega.
+ rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- rewrite <-base_from_limb_widths_length by auto.
congruence.
Qed.
@@ -258,20 +256,17 @@ Section BitwiseDecodeEncode.
Lemma base_upper_bound_compatible : @base_max_succ_divide base upper_bound.
Proof.
unfold base_max_succ_divide; intros i lt_Si_length.
+ rewrite base_from_limb_widths_length in lt_Si_length.
rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length;
rewrite !nth_default_base by (omega || auto).
- + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
- + rewrite nth_default_out_of_bounds by omega.
+ + rewrite nth_default_out_of_bounds by (distr_length; omega).
unfold Pow2Base.upper_bound.
- replace (length limb_widths) with (S (pred (length limb_widths))) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- replace i with (pred (length limb_widths)) by
- (rewrite base_from_limb_widths_length in H by auto; omega).
- erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
- rewrite <-base_from_limb_widths_length by auto; omega).
+ replace (length limb_widths) with (S (pred (length limb_widths))) by omega.
+ replace i with (pred (length limb_widths)) by omega.
+ erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
Qed.
@@ -281,7 +276,7 @@ Section BitwiseDecodeEncode.
BaseSystem.decode base (encodeZ limb_widths x) = x mod upper_bound.
Proof.
intros.
- assert (length base = length limb_widths) by auto using base_from_limb_widths_length.
+ assert (length base = length limb_widths) by distr_length.
unfold encodeZ; rewrite encode'_spec by omega.
rewrite BaseSystemProofs.encode'_spec; unfold Pow2Base.upper_bound; try zero_bounds;
auto using sum_firstn_limb_widths_nonneg.
@@ -521,7 +516,7 @@ Section carrying_helper.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
- Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat ->
+ Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (update_nth n f us) =
(let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us.
Proof.
@@ -540,17 +535,17 @@ Section carrying_helper.
erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption.
rewrite firstn_length in Heqn0.
rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite (@nth_default_out_of_bounds _ _ base) by auto.
- rewrite skipn_all by omega.
+ destruct (le_lt_dec (length limb_widths) n). {
+ rewrite (@nth_default_out_of_bounds _ _ base) by (distr_length; auto).
+ rewrite skipn_all by (rewrite base_from_limb_widths_length; omega).
do 2 rewrite decode_base_nil.
ring_simplify; auto.
} {
- rewrite (skipn_nth_default n base 0) by omega.
+ rewrite (skipn_nth_default n base 0) by (distr_length; omega).
do 2 rewrite decode'_cons.
ring_simplify; ring.
} }
- { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify.
+ { rewrite (nth_default_out_of_bounds _ base) by (distr_length; omega); ring_simplify.
etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ].
apply f_equal.
autorewrite with push_firstn simpl_update_nth.
@@ -639,12 +634,12 @@ Section carrying_helper.
Hint Rewrite @length_add_to_nth : distr_length.
- Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (set_nth n x us) =
(x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed.
- Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length limb_widths)%nat ->
BaseSystem.decode base (add_to_nth n x us) =
x * nth_default 0 base n + BaseSystem.decode base us.
Proof. intros; rewrite add_to_nth_set_nth, set_nth_sum; try ring_simplify; auto. Qed.
@@ -696,7 +691,7 @@ Section carrying.
Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
Hint Rewrite @length_carry_simple : distr_length.
- Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ Lemma nth_default_base_succ : forall i, (S i < length limb_widths)%nat ->
nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
Proof.
intros.
@@ -705,13 +700,13 @@ Section carrying.
Qed.
Lemma carry_gen_decode_eq : forall fc fi i' us
- (i := fi (length base) i')
- (Si := fi (length base) (S i)),
- (length us = length base) ->
+ (i := fi i')
+ (Si := fi (S i)),
+ (length us = length limb_widths) ->
BaseSystem.decode base (carry_gen limb_widths fc fi i' us)
= (fc (nth_default 0 us i / 2 ^ log_cap i) *
(if eq_nat_dec Si (S i)
- then if lt_dec (S i) (length base)
+ then if lt_dec (S i) (length limb_widths)
then 2 ^ log_cap i * nth_default 0 base i
else 0
else nth_default 0 base Si)
@@ -719,29 +714,29 @@ Section carrying.
+ BaseSystem.decode base us.
Proof.
intros fc fi i' us i Si H; intros.
- destruct (eq_nat_dec 0 (length base));
+ destruct (eq_nat_dec 0 (length limb_widths));
[ destruct limb_widths, us, i; simpl in *; try congruence;
break_match;
unfold carry_gen, carry_single, add_to_nth;
autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
reflexivity
| ].
- (*assert (0 <= i < length base)%nat by (subst i; auto with arith).*)
+ (*assert (0 <= i < length limb_widths)%nat by (subst i; auto with arith).*)
assert (0 <= log_cap i) by auto using log_cap_nonneg.
assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
unfold carry_gen, carry_single.
- rewrite H; change (i' mod length base)%nat with i.
+ change (i' mod length limb_widths)%nat with i.
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
rewrite set_nth_sum by omega.
unfold Z.pow2_mod.
rewrite Z.land_ones by auto using log_cap_nonneg.
rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
- change (fi (length base) i') with i.
+ change (fi i') with i.
subst Si.
repeat first [ ring
| match goal with H : _ = _ |- _ => rewrite !H in * end
| rewrite nth_default_base_succ by omega
- | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by (distr_length; omega)
| rewrite !(nth_default_out_of_bounds _ us) by omega
| rewrite Z.mod_eq by assumption
| progress distr_length
@@ -750,8 +745,8 @@ Section carrying.
Qed.
Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
+ (length us = length limb_widths) ->
+ (i < (pred (length limb_widths)))%nat ->
BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
Proof.
unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption.
@@ -790,11 +785,11 @@ Section carrying.
Lemma nth_default_carry_gen_full fc fi d i n us
: nth_default d (carry_gen limb_widths fc fi i us) n
= if lt_dec n (length us)
- then (if eq_nat_dec n (fi (length us) i)
+ then (if eq_nat_dec n (fi i)
then Z.pow2_mod (nth_default 0 us n) (log_cap n)
else nth_default 0 us n) +
- if eq_nat_dec n (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec n (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0
else d.
Proof.
@@ -826,11 +821,11 @@ Section carrying.
: forall fc fi i us,
(0 <= i < length us)%nat
-> nth_default 0 (carry_gen limb_widths fc fi i us) i
- = (if eq_nat_dec i (fi (length us) i)
+ = (if eq_nat_dec i (fi i)
then Z.pow2_mod (nth_default 0 us i) (log_cap i)
else nth_default 0 us i) +
- if eq_nat_dec i (fi (length us) (S (fi (length us) i)))
- then fc (nth_default 0 us (fi (length us) i) >> log_cap (fi (length us) i))
+ if eq_nat_dec i (fi (S (fi i)))
+ then fc (nth_default 0 us (fi i) >> log_cap (fi i))
else 0.
Proof.
intros; autorewrite with push_nth_default natsimplify; break_match; omega.
@@ -848,7 +843,7 @@ Section carrying.
Hint Rewrite @nth_default_carry_simple using (omega || distr_length; omega) : push_nth_default.
End carrying.
-Hint Rewrite @length_carry_gen @base_from_limb_widths_length : distr_length.
+Hint Rewrite @length_carry_gen : distr_length.
Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
Hint Rewrite @nth_default_carry_simple_full @nth_default_carry_gen_full : push_nth_default.
Hint Rewrite @nth_default_carry_simple @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.