diff options
Diffstat (limited to 'src/ModularArithmetic/Pow2BaseProofs.v')
-rw-r--r-- | src/ModularArithmetic/Pow2BaseProofs.v | 408 |
1 files changed, 401 insertions, 7 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index e06df9328..a7d7da800 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -1,11 +1,14 @@ -Require Import Zpower ZArith. +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.Util.Tactics. Require Import Crypto.ModularArithmetic.Pow2Base Crypto.BaseSystemProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. +Create HintDb simpl_add_to_nth discriminated. + Section Pow2BaseProofs. Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). Local Notation base := (base_from_limb_widths limb_widths). @@ -185,7 +188,7 @@ Section BitwiseDecodeEncode. unfold base_max_succ_divide; intros i lt_Si_length. rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length; rewrite !nth_default_base by (omega || auto). - + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); rewrite <-base_from_limb_widths_length by auto; omega). rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. apply Z.divide_factor_r. @@ -195,10 +198,10 @@ Section BitwiseDecodeEncode. (rewrite base_from_limb_widths_length in H by auto; omega). replace i with (pred (length limb_widths)) by (rewrite base_from_limb_widths_length in H by auto; omega). - erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); rewrite <-base_from_limb_widths_length by auto; omega). rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. - apply Z.divide_factor_r. + apply Z.divide_factor_r. Qed. Hint Resolve base_upper_bound_compatible. @@ -236,7 +239,7 @@ Section BitwiseDecodeEncode. Proof. intros. simpl; f_equal. - match goal with H : bounded _ _ |- _ => + match goal with H : bounded _ _ |- _ => rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end. rewrite Z.shiftl_mul_pow2 by auto. ring. @@ -439,4 +442,395 @@ Section UniformBase. replace w with width by (symmetry; auto). assumption. Qed. -End UniformBase.
\ No newline at end of file +End UniformBase. + +Section carrying_helper. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation base := (base_from_limb_widths limb_widths). + Local Notation log_cap i := (nth_default 0 limb_widths i). + + Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (update_nth n f us) = + (let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us. + Proof. + intros. + unfold BaseSystem.decode. + destruct H as [H|H]. + { nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) + erewrite nth_error_value_eq_nth_default by eassumption. + unfold splice_nth. + rewrite <- (firstn_skipn n us) at 3. + do 2 rewrite decode'_splice. + remember (length (firstn n us)) as n0. + ring_simplify. + remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). + rewrite (skipn_nth_default n us 0) by omega. + erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption. + rewrite firstn_length in Heqn0. + rewrite Min.min_l in Heqn0 by omega; subst n0. + destruct (le_lt_dec (length base) n). { + rewrite (@nth_default_out_of_bounds _ _ base) by auto. + rewrite skipn_all by omega. + do 2 rewrite decode_base_nil. + ring_simplify; auto. + } { + rewrite (skipn_nth_default n base 0) by omega. + do 2 rewrite decode'_cons. + ring_simplify; ring. + } } + { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify. + etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ]. + apply f_equal. + autorewrite with push_firstn simpl_update_nth. + rewrite update_nth_out_of_bounds by (distr_length; omega * ). + reflexivity. } + Qed. + + Lemma unfold_add_to_nth n x + : forall xs, + add_to_nth n x xs + = match n with + | O => match xs with + | nil => nil + | x'::xs' => x + x'::xs' + end + | S n' => match xs with + | nil => nil + | x'::xs' => x'::add_to_nth n' x xs' + end + end. + Proof. + induction n; destruct xs; reflexivity. + Qed. + + Lemma simpl_add_to_nth_0 x + : forall xs, + add_to_nth 0 x xs + = match xs with + | nil => nil + | x'::xs' => x + x'::xs' + end. + Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed. + + Lemma simpl_add_to_nth_S x n + : forall xs, + add_to_nth (S n) x xs + = match xs with + | nil => nil + | x'::xs' => x'::add_to_nth n x xs' + end. + Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed. + + Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth. + + Lemma add_to_nth_cons : forall x u0 us, add_to_nth 0 x (u0 :: us) = x + u0 :: us. + Proof. reflexivity. Qed. + + Hint Rewrite @add_to_nth_cons : simpl_add_to_nth. + + Lemma cons_add_to_nth : forall n f y us, + y :: add_to_nth n f us = add_to_nth (S n) f (y :: us). + Proof. + induction n; boring. + Qed. + + Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth. + + Lemma add_to_nth_nil : forall n f, add_to_nth n f nil = nil. + Proof. + induction n; boring. + Qed. + + Hint Rewrite @add_to_nth_nil : simpl_add_to_nth. + + Lemma add_to_nth_set_nth n x xs + : add_to_nth n x xs + = set_nth n (x + nth_default 0 xs n) xs. + Proof. + revert xs; induction n; destruct xs; + autorewrite with simpl_set_nth simpl_add_to_nth; + try rewrite IHn; + reflexivity. + Qed. + Lemma add_to_nth_update_nth n x xs + : add_to_nth n x xs + = update_nth n (fun y => x + y) xs. + Proof. + revert xs; induction n; destruct xs; + autorewrite with simpl_update_nth simpl_add_to_nth; + try rewrite IHn; + reflexivity. + Qed. + + Lemma length_add_to_nth i x xs : length (add_to_nth i x xs) = length xs. + Proof. unfold add_to_nth; distr_length; reflexivity. Qed. + + Hint Rewrite @length_add_to_nth : distr_length. + + Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (set_nth n x us) = + (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. + Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed. + + Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (add_to_nth n x us) = + x * nth_default 0 base n + BaseSystem.decode base us. + Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed. + + Lemma add_to_nth_nth_default_full : forall n x l i d, + nth_default d (add_to_nth n x l) i = + if lt_dec i (length l) then + if (eq_nat_dec i n) then x + nth_default d l i + else nth_default d l i + else d. + Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default_full; assumption. Qed. + Hint Rewrite @add_to_nth_nth_default_full : push_nth_default. + + Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> + nth_default 0 (add_to_nth n x l) i = + if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default; assumption. Qed. + Hint Rewrite @add_to_nth_nth_default using omega : push_nth_default. + + Lemma log_cap_nonneg : forall i, 0 <= log_cap i. + Proof. + unfold nth_default; intros. + case_eq (nth_error limb_widths i); intros; try omega. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. Local Hint Resolve log_cap_nonneg. +End carrying_helper. + +Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth. +Hint Rewrite @add_to_nth_cons : simpl_add_to_nth. +Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth. +Hint Rewrite @add_to_nth_nil : simpl_add_to_nth. +Hint Rewrite @length_add_to_nth : distr_length. +Hint Rewrite @add_to_nth_nth_default_full : push_nth_default. +Hint Rewrite @add_to_nth_nth_default using (omega || distr_length; omega) : push_nth_default. + +Section carrying. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation base := (base_from_limb_widths limb_widths). + Local Notation log_cap i := (nth_default 0 limb_widths i). + Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. + + (* + Lemma length_carry_gen : forall f i us, length (carry_gen limb_widths f i us) = length us. + Proof. intros; unfold carry_gen, carry_and_reduce_single; distr_length; reflexivity. Qed. + + Hint Rewrite @length_carry_gen : distr_length. + *) + + Lemma length_carry_simple : forall i us, length (carry_simple limb_widths i us) = length us. + Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed. + Hint Rewrite @length_carry_simple : distr_length. + + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> + nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. + Proof. + intros. + rewrite !nth_default_base, <- Z.pow_add_r by (omega || eauto using log_cap_nonneg). + autorewrite with simpl_sum_firstn; reflexivity. + Qed. + + (* + Lemma carry_gen_decode_eq : forall f i' us (i := (i' mod length base)%nat), + (length us = length base) -> + BaseSystem.decode base (carry_gen limb_widths f i' us) + = ((f (nth_default 0 us i / 2 ^ log_cap i)) + * (if eq_nat_dec (S i mod length base) 0 + then nth_default 0 base 0 + else (2 ^ log_cap i) * (nth_default 0 base i)) + - (nth_default 0 us i / 2 ^ log_cap i) * 2 ^ log_cap i * nth_default 0 base i + ) + + BaseSystem.decode base us. + Proof. + intros f i' us i H; intros. + destruct (eq_nat_dec 0 (length base)); + [ destruct limb_widths, us, i; simpl in *; try congruence; + unfold carry_gen, carry_and_reduce_single, add_to_nth; + autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length; + reflexivity + | ]. + assert (0 <= i < length base)%nat by (subst i; auto with arith). + assert (0 <= log_cap i) by auto using log_cap_nonneg. + assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia). + unfold carry_gen, carry_and_reduce_single. + rewrite H; change (i' mod length base)%nat with i. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold Z.pow2_mod. + rewrite Z.land_ones by auto using log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg. + destruct (eq_nat_dec (S i mod length base) 0); + repeat first [ ring + | congruence + | match goal with H : _ = _ |- _ => rewrite !H in * end + | rewrite nth_default_base_succ by omega + | rewrite !(nth_default_out_of_bounds _ base) by omega + | rewrite !(nth_default_out_of_bounds _ us) by omega + | rewrite Z.mod_eq by assumption + | progress distr_length + | progress autorewrite with natsimplify zsimplify in * + | progress break_match ]. + Qed. + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption. + autorewrite with natsimplify. + break_match; lia. + Qed. +*) + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple. intros. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold Z.pow2_mod. + rewrite Z.land_ones by eauto using log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. + rewrite nth_default_base_succ by omega. + rewrite Z.mul_assoc. + rewrite (Z.mul_comm _ (2 ^ log_cap i)). + rewrite Z.mul_div_eq; try ring. + apply Z.gt_lt_iff. + apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg. + Qed. + + Lemma length_carry_simple_sequence : forall is us, length (carry_simple_sequence limb_widths is us) = length us. + Proof. + unfold carry_simple_sequence. + induction is; [ reflexivity | simpl; intros ]. + distr_length. + congruence. + Qed. + Hint Rewrite @length_carry_simple_sequence : distr_length. + + Lemma length_make_chain : forall i, length (make_chain i) = i. + Proof. induction i; simpl; congruence. Qed. + Hint Rewrite @length_make_chain : distr_length. + + Lemma length_full_carry_chain : length (full_carry_chain limb_widths) = length limb_widths. + Proof. unfold full_carry_chain; distr_length; reflexivity. Qed. + Hint Rewrite @length_full_carry_chain : distr_length. + + Lemma length_carry_simple_full us : length (carry_simple_full limb_widths us) = length us. + Proof. unfold carry_simple_full; distr_length; reflexivity. Qed. + Hint Rewrite @length_carry_simple_full : distr_length. + + (* TODO : move? *) + Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. + Proof. + induction x; simpl; intuition. + Qed. +(* + Lemma nth_default_carry_gen_full : forall f d i n us, + nth_default d (carry_gen limb_widths f i us) n + = if lt_dec n (length us) + then if eq_nat_dec n (i mod length us) + then (if eq_nat_dec (S n) (length us) + then (if eq_nat_dec n 0 + then f ((nth_default 0 us n) >> log_cap n) + else 0) + else 0) + + Z.pow2_mod (nth_default 0 us n) (log_cap n) + else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us)) + then f (nth_default 0 us (i mod length us) >> log_cap (i mod length us)) + else 0) + + nth_default d us n + else d. + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify distr_length. + edestruct lt_dec; [ | reflexivity ]. + change (S ?x) with (1 + x)%nat. + rewrite (Nat.add_mod_idemp_r 1 i (length us)) by omega. + autorewrite with natsimplify. + change (1 + ?x)%nat with (S x). + destruct (eq_nat_dec n (i mod length us)); + subst; repeat break_match; omega. + Qed. + + Hint Rewrite @nth_default_carry_gen_full : push_nth_default. + + Lemma nth_default_carry_simple_full : forall d i n us, + nth_default d (carry_simple limb_widths i us) n + = if lt_dec n (length us) + then if eq_nat_dec n (i mod length us) + then (if eq_nat_dec (S n) (length us) + then (if eq_nat_dec n 0 + then (nth_default 0 us n >> log_cap n + Z.pow2_mod (nth_default 0 us n) (log_cap n)) + (* FIXME: The above is just [nth_default 0 us n], but do we really care about the case of [n = 0], [length us = 1]? *) + else Z.pow2_mod (nth_default 0 us n) (log_cap n)) + else Z.pow2_mod (nth_default 0 us n) (log_cap n)) + else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us)) + then nth_default 0 us (i mod length us) >> log_cap (i mod length us) + else 0) + + nth_default d us n + else d. + Proof. + intros; unfold carry_simple; autorewrite with push_nth_default; + repeat break_match; reflexivity. + Qed. + + Hint Rewrite @nth_default_carry_simple_full : push_nth_default. + + Lemma nth_default_carry_gen + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. + + Lemma nth_default_carry_simple + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. + + + Lemma nth_default_carry_gen + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. +*) +End carrying. + +(* +Hint Rewrite @length_carry_gen : distr_length. +*) +Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length. +(* +Hint Rewrite @nth_default_carry_gen_full : push_nth_default. +Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. +*) |