aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/Pow2BaseProofs.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jasongross9@gmail.com>2016-07-18 19:09:46 +0200
committerGravatar GitHub <noreply@github.com>2016-07-18 19:09:46 +0200
commit07ca661557d86b96d1ee0a9b9013d0834158571f (patch)
tree78980ce7dbbf836f1d109159332600370ed224e6 /src/ModularArithmetic/Pow2BaseProofs.v
parent0fd535b57b93bada6cc00c2e372f2f94d2768567 (diff)
Move some definitions to Pow2Base (#24)
* Move some definitions to Pow2Base These definitions don't depend on PseudoMersenneBaseParams, only on limb_widths, and we'll want them for BarrettReduction / P256. * Fix for Coq 8.4
Diffstat (limited to 'src/ModularArithmetic/Pow2BaseProofs.v')
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v466
1 files changed, 430 insertions, 36 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index 7538781c0..ed9b58ccc 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -1,16 +1,19 @@
-Require Import Zpower ZArith.
+Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Lists.List.
-Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil.
+Require Import Crypto.Util.Tactics.
Require Import Crypto.ModularArithmetic.Pow2Base Crypto.BaseSystemProofs.
Require Crypto.BaseSystem.
Local Open Scope Z_scope.
+Create HintDb simpl_add_to_nth discriminated.
+
Section Pow2BaseProofs.
Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w).
- Local Notation "{base}" := (base_from_limb_widths limb_widths).
+ Local Notation base := (base_from_limb_widths limb_widths).
- Lemma base_from_limb_widths_length : length {base} = length limb_widths.
+ Lemma base_from_limb_widths_length : length base = length limb_widths.
Proof.
induction limb_widths; try reflexivity.
simpl; rewrite map_length.
@@ -28,10 +31,10 @@ Section Pow2BaseProofs.
eapply In_firstn; eauto.
Qed. Hint Resolve sum_firstn_limb_widths_nonneg.
- Lemma base_from_limb_widths_step : forall i b w, (S i < length {base})%nat ->
- nth_error {base} i = Some b ->
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ nth_error base i = Some b ->
nth_error limb_widths i = Some w ->
- nth_error {base} (S i) = Some (two_p w * b).
+ nth_error base (S i) = Some (two_p w * b).
Proof.
induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
@@ -54,13 +57,13 @@ Section Pow2BaseProofs.
Qed.
- Lemma nth_error_base : forall i, (i < length {base})%nat ->
- nth_error {base} i = Some (two_p (sum_firstn limb_widths i)).
+ Lemma nth_error_base : forall i, (i < length base)%nat ->
+ nth_error base i = Some (two_p (sum_firstn limb_widths i)).
Proof.
induction i; intros.
+ unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
- + assert (i < length {base})%nat as lt_i_length by omega.
+ + assert (i < length base)%nat as lt_i_length by omega.
specialize (IHi lt_i_length).
rewrite base_from_limb_widths_length in lt_i_length.
destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
@@ -80,8 +83,8 @@ Section Pow2BaseProofs.
eapply nth_error_value_In; eauto.
Qed.
- Lemma nth_default_base : forall d i, (i < length {base})%nat ->
- nth_default d {base} i = 2 ^ (sum_firstn limb_widths i).
+ Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ nth_default d base i = 2 ^ (sum_firstn limb_widths i).
Proof.
intros ? ? i_lt_length.
destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
@@ -92,8 +95,8 @@ Section Pow2BaseProofs.
congruence.
Qed.
- Lemma base_succ : forall i, ((S i) < length {base})%nat ->
- nth_default 0 {base} (S i) mod nth_default 0 {base} i = 0.
+ Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) mod nth_default 0 base i = 0.
Proof.
intros.
repeat rewrite nth_default_base by omega.
@@ -105,7 +108,7 @@ Section Pow2BaseProofs.
apply limb_widths_nonneg.
rewrite lw_eq.
apply in_eq.
- + assert (i < length {base})%nat as i_lt_length by omega.
+ + assert (i < length base)%nat as i_lt_length by omega.
rewrite base_from_limb_widths_length in *.
apply nth_error_length_exists_value in i_lt_length.
destruct i_lt_length as [x nth_err_x].
@@ -115,7 +118,7 @@ Section Pow2BaseProofs.
omega.
Qed.
- Lemma nth_error_subst : forall i b, nth_error {base} i = Some b ->
+ Lemma nth_error_subst : forall i b, nth_error base i = Some b ->
b = 2 ^ (sum_firstn limb_widths i).
Proof.
intros i b nth_err_b.
@@ -125,7 +128,7 @@ Section Pow2BaseProofs.
congruence.
Qed.
- Lemma base_positive : forall b : Z, In b {base} -> b > 0.
+ Lemma base_positive : forall b : Z, In b base -> b > 0.
Proof.
intros b In_b_base.
apply In_nth_error_value in In_b_base.
@@ -136,7 +139,7 @@ Section Pow2BaseProofs.
apply Z.pow_pos_nonneg; omega || auto using sum_firstn_limb_widths_nonneg.
Qed.
- Lemma b0_1 : forall x : Z, limb_widths <> nil -> nth_default x {base} 0 = 1.
+ Lemma b0_1 : forall x : Z, limb_widths <> nil -> nth_default x base 0 = 1.
Proof.
case_eq limb_widths; intros; [congruence | reflexivity].
Qed.
@@ -154,18 +157,18 @@ Section BitwiseDecodeEncode.
(limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w).
Local Hint Resolve limb_widths_nonneg.
Local Notation "w[ i ]" := (nth_default 0 limb_widths i).
- Local Notation "{base}" := (base_from_limb_widths limb_widths).
- Local Notation "{max}" := (upper_bound limb_widths).
+ Local Notation base := (base_from_limb_widths limb_widths).
+ Local Notation max := (upper_bound limb_widths).
- Lemma encode'_spec : forall x i, (i <= length {base})%nat ->
- encode' limb_widths x i = BaseSystem.encode' {base} x {max} i.
+ Lemma encode'_spec : forall x i, (i <= length base)%nat ->
+ encode' limb_widths x i = BaseSystem.encode' base x max i.
Proof.
induction i; intros.
+ rewrite encode'_zero. reflexivity.
+ rewrite encode'_succ, <-IHi by omega.
simpl; do 2 f_equal.
rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg.
- match goal with H : (S _ <= length {base})%nat |- _ =>
+ match goal with H : (S _ <= length base)%nat |- _ =>
apply le_lt_or_eq in H; destruct H end.
- repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity.
- repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity].
@@ -180,12 +183,12 @@ Section BitwiseDecodeEncode.
intros; apply nth_default_preserves_properties; auto; omega.
Qed. Hint Resolve nth_default_limb_widths_nonneg.
- Lemma base_upper_bound_compatible : @base_max_succ_divide {base} {max}.
+ Lemma base_upper_bound_compatible : @base_max_succ_divide base max.
Proof.
unfold base_max_succ_divide; intros i lt_Si_length.
rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length;
rewrite !nth_default_base by (omega || auto).
- + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
+ + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
rewrite <-base_from_limb_widths_length by auto; omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
apply Z.divide_factor_r.
@@ -195,18 +198,18 @@ Section BitwiseDecodeEncode.
(rewrite base_from_limb_widths_length in H by auto; omega).
replace i with (pred (length limb_widths)) by
(rewrite base_from_limb_widths_length in H by auto; omega).
- erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
+ erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0);
rewrite <-base_from_limb_widths_length by auto; omega).
rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg.
- apply Z.divide_factor_r.
+ apply Z.divide_factor_r.
Qed.
Hint Resolve base_upper_bound_compatible.
Lemma encodeZ_spec : forall x,
- BaseSystem.decode {base} (encodeZ limb_widths x) = x mod {max}.
+ BaseSystem.decode base (encodeZ limb_widths x) = x mod max.
Proof.
intros.
- assert (length {base} = length limb_widths) by auto using base_from_limb_widths_length.
+ assert (length base = length limb_widths) by auto using base_from_limb_widths_length.
unfold encodeZ; rewrite encode'_spec by omega.
rewrite BaseSystemProofs.encode'_spec; unfold upper_bound; try zero_bounds;
auto using sum_firstn_limb_widths_nonneg.
@@ -236,7 +239,7 @@ Section BitwiseDecodeEncode.
Proof.
intros.
simpl; f_equal.
- match goal with H : bounded _ _ |- _ =>
+ match goal with H : bounded _ _ |- _ =>
rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end.
rewrite Z.shiftl_mul_pow2 by auto.
ring.
@@ -316,7 +319,7 @@ Section BitwiseDecodeEncode.
Lemma decode_bitwise'_spec : forall us i, (i <= length limb_widths)%nat ->
bounded limb_widths us -> length us = length limb_widths ->
decode_bitwise' limb_widths us i (partial_decode us i (length us - i)) =
- BaseSystem.decode {base} us.
+ BaseSystem.decode base us.
Proof.
induction i; intros.
+ rewrite partial_decode_intermediate by auto.
@@ -328,7 +331,7 @@ Section BitwiseDecodeEncode.
Lemma decode_bitwise_spec : forall us, bounded limb_widths us ->
length us = length limb_widths ->
- decode_bitwise limb_widths us = BaseSystem.decode {base} us.
+ decode_bitwise limb_widths us = BaseSystem.decode base us.
Proof.
unfold decode_bitwise; intros.
replace 0 with (partial_decode us (length us) (length us - length us)) by
@@ -361,7 +364,7 @@ Section UniformBase.
Context {width : Z} (limb_width_pos : 0 < width).
Context (limb_widths : list Z) (limb_widths_nonnil : limb_widths <> nil)
(limb_widths_uniform : forall w, In w limb_widths -> w = width).
- Local Notation "{base}" := (base_from_limb_widths limb_widths).
+ Local Notation base := (base_from_limb_widths limb_widths).
Lemma bounded_uniform : forall us, (length us <= length limb_widths)%nat ->
(bounded limb_widths us <-> (forall u, In u us -> 0 <= u < 2 ^ width)).
@@ -409,7 +412,7 @@ Section UniformBase.
Qed.
Lemma decode_tl_base_shift : forall us, (length us < length limb_widths)%nat ->
- BaseSystem.decode (tl {base}) us = BaseSystem.decode {base} us << width.
+ BaseSystem.decode (tl base) us = BaseSystem.decode base us << width.
Proof.
intros ? Hlength.
edestruct (destruct_repeat limb_widths) as [? | [tl_lw [Heq_lw tl_lw_uniform]]];
@@ -422,7 +425,7 @@ Section UniformBase.
Qed.
Lemma decode_shift : forall us u0, (length (u0 :: us) <= length limb_widths)%nat ->
- BaseSystem.decode {base} (u0 :: us) = u0 + ((BaseSystem.decode {base} us) << width).
+ BaseSystem.decode base (u0 :: us) = u0 + ((BaseSystem.decode base us) << width).
Proof.
intros.
rewrite <-decode_tl_base_shift by (simpl in *; omega).
@@ -439,4 +442,395 @@ Section UniformBase.
replace w with width by (symmetry; auto).
assumption.
Qed.
-End UniformBase. \ No newline at end of file
+End UniformBase.
+
+Section carrying_helper.
+ Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w).
+ Local Notation base := (base_from_limb_widths limb_widths).
+ Local Notation log_cap i := (nth_default 0 limb_widths i).
+
+ Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat ->
+ BaseSystem.decode base (update_nth n f us) =
+ (let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode.
+ destruct H as [H|H].
+ { nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
+ erewrite nth_error_value_eq_nth_default by eassumption.
+ unfold splice_nth.
+ rewrite <- (firstn_skipn n us) at 3.
+ do 2 rewrite decode'_splice.
+ remember (length (firstn n us)) as n0.
+ ring_simplify.
+ remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
+ rewrite (skipn_nth_default n us 0) by omega.
+ erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption.
+ rewrite firstn_length in Heqn0.
+ rewrite Min.min_l in Heqn0 by omega; subst n0.
+ destruct (le_lt_dec (length base) n). {
+ rewrite (@nth_default_out_of_bounds _ _ base) by auto.
+ rewrite skipn_all by omega.
+ do 2 rewrite decode_base_nil.
+ ring_simplify; auto.
+ } {
+ rewrite (skipn_nth_default n base 0) by omega.
+ do 2 rewrite decode'_cons.
+ ring_simplify; ring.
+ } }
+ { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify.
+ etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ].
+ apply f_equal.
+ autorewrite with push_firstn simpl_update_nth.
+ rewrite update_nth_out_of_bounds by (distr_length; omega * ).
+ reflexivity. }
+ Qed.
+
+ Lemma unfold_add_to_nth n x
+ : forall xs,
+ add_to_nth n x xs
+ = match n with
+ | O => match xs with
+ | nil => nil
+ | x'::xs' => x + x'::xs'
+ end
+ | S n' => match xs with
+ | nil => nil
+ | x'::xs' => x'::add_to_nth n' x xs'
+ end
+ end.
+ Proof.
+ induction n; destruct xs; reflexivity.
+ Qed.
+
+ Lemma simpl_add_to_nth_0 x
+ : forall xs,
+ add_to_nth 0 x xs
+ = match xs with
+ | nil => nil
+ | x'::xs' => x + x'::xs'
+ end.
+ Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed.
+
+ Lemma simpl_add_to_nth_S x n
+ : forall xs,
+ add_to_nth (S n) x xs
+ = match xs with
+ | nil => nil
+ | x'::xs' => x'::add_to_nth n x xs'
+ end.
+ Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed.
+
+ Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth.
+
+ Lemma add_to_nth_cons : forall x u0 us, add_to_nth 0 x (u0 :: us) = x + u0 :: us.
+ Proof. reflexivity. Qed.
+
+ Hint Rewrite @add_to_nth_cons : simpl_add_to_nth.
+
+ Lemma cons_add_to_nth : forall n f y us,
+ y :: add_to_nth n f us = add_to_nth (S n) f (y :: us).
+ Proof.
+ induction n; boring.
+ Qed.
+
+ Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth.
+
+ Lemma add_to_nth_nil : forall n f, add_to_nth n f nil = nil.
+ Proof.
+ induction n; boring.
+ Qed.
+
+ Hint Rewrite @add_to_nth_nil : simpl_add_to_nth.
+
+ Lemma add_to_nth_set_nth n x xs
+ : add_to_nth n x xs
+ = set_nth n (x + nth_default 0 xs n) xs.
+ Proof.
+ revert xs; induction n; destruct xs;
+ autorewrite with simpl_set_nth simpl_add_to_nth;
+ try rewrite IHn;
+ reflexivity.
+ Qed.
+ Lemma add_to_nth_update_nth n x xs
+ : add_to_nth n x xs
+ = update_nth n (fun y => x + y) xs.
+ Proof.
+ revert xs; induction n; destruct xs;
+ autorewrite with simpl_update_nth simpl_add_to_nth;
+ try rewrite IHn;
+ reflexivity.
+ Qed.
+
+ Lemma length_add_to_nth i x xs : length (add_to_nth i x xs) = length xs.
+ Proof. unfold add_to_nth; distr_length; reflexivity. Qed.
+
+ Hint Rewrite @length_add_to_nth : distr_length.
+
+ Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ BaseSystem.decode base (set_nth n x us) =
+ (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
+ Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed.
+
+ Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat ->
+ BaseSystem.decode base (add_to_nth n x us) =
+ x * nth_default 0 base n + BaseSystem.decode base us.
+ Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed.
+
+ Lemma add_to_nth_nth_default_full : forall n x l i d,
+ nth_default d (add_to_nth n x l) i =
+ if lt_dec i (length l) then
+ if (eq_nat_dec i n) then x + nth_default d l i
+ else nth_default d l i
+ else d.
+ Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default_full; assumption. Qed.
+ Hint Rewrite @add_to_nth_nth_default_full : push_nth_default.
+
+ Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat ->
+ nth_default 0 (add_to_nth n x l) i =
+ if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i.
+ Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default; assumption. Qed.
+ Hint Rewrite @add_to_nth_nth_default using omega : push_nth_default.
+
+ Lemma log_cap_nonneg : forall i, 0 <= log_cap i.
+ Proof.
+ unfold nth_default; intros.
+ case_eq (nth_error limb_widths i); intros; try omega.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed. Local Hint Resolve log_cap_nonneg.
+End carrying_helper.
+
+Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth.
+Hint Rewrite @add_to_nth_cons : simpl_add_to_nth.
+Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth.
+Hint Rewrite @add_to_nth_nil : simpl_add_to_nth.
+Hint Rewrite @length_add_to_nth : distr_length.
+Hint Rewrite @add_to_nth_nth_default_full : push_nth_default.
+Hint Rewrite @add_to_nth_nth_default using (omega || distr_length; omega) : push_nth_default.
+
+Section carrying.
+ Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w).
+ Local Notation base := (base_from_limb_widths limb_widths).
+ Local Notation log_cap i := (nth_default 0 limb_widths i).
+ Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg.
+
+ (*
+ Lemma length_carry_gen : forall f i us, length (carry_gen limb_widths f i us) = length us.
+ Proof. intros; unfold carry_gen, carry_and_reduce_single; distr_length; reflexivity. Qed.
+
+ Hint Rewrite @length_carry_gen : distr_length.
+ *)
+
+ Lemma length_carry_simple : forall i us, length (carry_simple limb_widths i us) = length us.
+ Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed.
+ Hint Rewrite @length_carry_simple : distr_length.
+
+ Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
+ Proof.
+ intros.
+ rewrite !nth_default_base, <- Z.pow_add_r by (omega || eauto using log_cap_nonneg).
+ autorewrite with simpl_sum_firstn; reflexivity.
+ Qed.
+
+ (*
+ Lemma carry_gen_decode_eq : forall f i' us (i := (i' mod length base)%nat),
+ (length us = length base) ->
+ BaseSystem.decode base (carry_gen limb_widths f i' us)
+ = ((f (nth_default 0 us i / 2 ^ log_cap i))
+ * (if eq_nat_dec (S i mod length base) 0
+ then nth_default 0 base 0
+ else (2 ^ log_cap i) * (nth_default 0 base i))
+ - (nth_default 0 us i / 2 ^ log_cap i) * 2 ^ log_cap i * nth_default 0 base i
+ )
+ + BaseSystem.decode base us.
+ Proof.
+ intros f i' us i H; intros.
+ destruct (eq_nat_dec 0 (length base));
+ [ destruct limb_widths, us, i; simpl in *; try congruence;
+ unfold carry_gen, carry_and_reduce_single, add_to_nth;
+ autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length;
+ reflexivity
+ | ].
+ assert (0 <= i < length base)%nat by (subst i; auto with arith).
+ assert (0 <= log_cap i) by auto using log_cap_nonneg.
+ assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia).
+ unfold carry_gen, carry_and_reduce_single.
+ rewrite H; change (i' mod length base)%nat with i.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ unfold Z.pow2_mod.
+ rewrite Z.land_ones by auto using log_cap_nonneg.
+ rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg.
+ destruct (eq_nat_dec (S i mod length base) 0);
+ repeat first [ ring
+ | congruence
+ | match goal with H : _ = _ |- _ => rewrite !H in * end
+ | rewrite nth_default_base_succ by omega
+ | rewrite !(nth_default_out_of_bounds _ base) by omega
+ | rewrite !(nth_default_out_of_bounds _ us) by omega
+ | rewrite Z.mod_eq by assumption
+ | progress distr_length
+ | progress autorewrite with natsimplify zsimplify in *
+ | progress break_match ].
+ Qed.
+
+ Lemma carry_simple_decode_eq : forall i us,
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
+ Proof.
+ unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption.
+ autorewrite with natsimplify.
+ break_match; lia.
+ Qed.
+*)
+
+ Lemma carry_simple_decode_eq : forall i us,
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us.
+ Proof.
+ unfold carry_simple. intros.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ unfold Z.pow2_mod.
+ rewrite Z.land_ones by eauto using log_cap_nonneg.
+ rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg.
+ rewrite nth_default_base_succ by omega.
+ rewrite Z.mul_assoc.
+ rewrite (Z.mul_comm _ (2 ^ log_cap i)).
+ rewrite Z.mul_div_eq; try ring.
+ apply Z.gt_lt_iff.
+ apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg.
+ Qed.
+
+ Lemma length_carry_simple_sequence : forall is us, length (carry_simple_sequence limb_widths is us) = length us.
+ Proof.
+ unfold carry_simple_sequence.
+ induction is; [ reflexivity | simpl; intros ].
+ distr_length.
+ congruence.
+ Qed.
+ Hint Rewrite @length_carry_simple_sequence : distr_length.
+
+ Lemma length_make_chain : forall i, length (make_chain i) = i.
+ Proof. induction i; simpl; congruence. Qed.
+ Hint Rewrite @length_make_chain : distr_length.
+
+ Lemma length_full_carry_chain : length (full_carry_chain limb_widths) = length limb_widths.
+ Proof. unfold full_carry_chain; distr_length; reflexivity. Qed.
+ Hint Rewrite @length_full_carry_chain : distr_length.
+
+ Lemma length_carry_simple_full us : length (carry_simple_full limb_widths us) = length us.
+ Proof. unfold carry_simple_full; distr_length; reflexivity. Qed.
+ Hint Rewrite @length_carry_simple_full : distr_length.
+
+ (* TODO : move? *)
+ Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat.
+ Proof.
+ induction x; simpl; intuition.
+ Qed.
+(*
+ Lemma nth_default_carry_gen_full : forall f d i n us,
+ nth_default d (carry_gen limb_widths f i us) n
+ = if lt_dec n (length us)
+ then if eq_nat_dec n (i mod length us)
+ then (if eq_nat_dec (S n) (length us)
+ then (if eq_nat_dec n 0
+ then f ((nth_default 0 us n) >> log_cap n)
+ else 0)
+ else 0)
+ + Z.pow2_mod (nth_default 0 us n) (log_cap n)
+ else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us))
+ then f (nth_default 0 us (i mod length us) >> log_cap (i mod length us))
+ else 0)
+ + nth_default d us n
+ else d.
+ Proof.
+ unfold carry_gen, carry_and_reduce_single.
+ intros; autorewrite with push_nth_default natsimplify distr_length.
+ edestruct lt_dec; [ | reflexivity ].
+ change (S ?x) with (1 + x)%nat.
+ rewrite (Nat.add_mod_idemp_r 1 i (length us)) by omega.
+ autorewrite with natsimplify.
+ change (1 + ?x)%nat with (S x).
+ destruct (eq_nat_dec n (i mod length us));
+ subst; repeat break_match; omega.
+ Qed.
+
+ Hint Rewrite @nth_default_carry_gen_full : push_nth_default.
+
+ Lemma nth_default_carry_simple_full : forall d i n us,
+ nth_default d (carry_simple limb_widths i us) n
+ = if lt_dec n (length us)
+ then if eq_nat_dec n (i mod length us)
+ then (if eq_nat_dec (S n) (length us)
+ then (if eq_nat_dec n 0
+ then (nth_default 0 us n >> log_cap n + Z.pow2_mod (nth_default 0 us n) (log_cap n))
+ (* FIXME: The above is just [nth_default 0 us n], but do we really care about the case of [n = 0], [length us = 1]? *)
+ else Z.pow2_mod (nth_default 0 us n) (log_cap n))
+ else Z.pow2_mod (nth_default 0 us n) (log_cap n))
+ else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us))
+ then nth_default 0 us (i mod length us) >> log_cap (i mod length us)
+ else 0)
+ + nth_default d us n
+ else d.
+ Proof.
+ intros; unfold carry_simple; autorewrite with push_nth_default;
+ repeat break_match; reflexivity.
+ Qed.
+
+ Hint Rewrite @nth_default_carry_simple_full : push_nth_default.
+
+ Lemma nth_default_carry_gen
+ : forall f i us,
+ (0 <= i < length us)%nat
+ -> nth_default 0 (carry_gen limb_widths f i us) i
+ = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
+ then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
+ else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
+ Proof.
+ unfold carry_gen, carry_and_reduce_single.
+ intros; autorewrite with push_nth_default natsimplify; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
+
+ Lemma nth_default_carry_simple
+ : forall f i us,
+ (0 <= i < length us)%nat
+ -> nth_default 0 (carry_gen limb_widths f i us) i
+ = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
+ then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
+ else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
+ Proof.
+ unfold carry_gen, carry_and_reduce_single.
+ intros; autorewrite with push_nth_default natsimplify; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
+
+
+ Lemma nth_default_carry_gen
+ : forall f i us,
+ (0 <= i < length us)%nat
+ -> nth_default 0 (carry_gen limb_widths f i us) i
+ = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i)
+ then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i)
+ else Z.pow2_mod (nth_default 0 us i) (log_cap i)).
+ Proof.
+ unfold carry_gen, carry_and_reduce_single.
+ intros; autorewrite with push_nth_default natsimplify; reflexivity.
+ Qed.
+ Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
+*)
+End carrying.
+
+(*
+Hint Rewrite @length_carry_gen : distr_length.
+*)
+Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length.
+(*
+Hint Rewrite @nth_default_carry_gen_full : push_nth_default.
+Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default.
+*)