From 07ca661557d86b96d1ee0a9b9013d0834158571f Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Mon, 18 Jul 2016 19:09:46 +0200 Subject: Move some definitions to Pow2Base (#24) * Move some definitions to Pow2Base These definitions don't depend on PseudoMersenneBaseParams, only on limb_widths, and we'll want them for BarrettReduction / P256. * Fix for Coq 8.4 --- src/ModularArithmetic/ModularBaseSystem.v | 24 +- src/ModularArithmetic/ModularBaseSystemOpt.v | 12 +- src/ModularArithmetic/ModularBaseSystemProofs.v | 188 +++------- src/ModularArithmetic/Pow2Base.v | 52 +++ src/ModularArithmetic/Pow2BaseProofs.v | 466 ++++++++++++++++++++++-- 5 files changed, 541 insertions(+), 201 deletions(-) (limited to 'src/ModularArithmetic') diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index b6138381e..23b0c2ef6 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -42,14 +42,10 @@ Section CarryBasePow2. Local Notation base := (base_from_limb_widths limb_widths). Local Notation log_cap i := (nth_default 0 limb_widths i). - Definition add_to_nth n (x:Z) xs := - set_nth n (x + nth_default 0 xs n) xs. - - Definition carry_simple i := fun us => - let di := nth_default 0 us i in - let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in - add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'. - + (* + Definition carry_and_reduce := + carry_gen limb_widths (fun ci => c * ci). + *) Definition carry_and_reduce i := fun us => let di := nth_default 0 us i in let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in @@ -58,19 +54,11 @@ Section CarryBasePow2. Definition carry i : digits -> digits := if eq_nat_dec i (pred (length base)) then carry_and_reduce i - else carry_simple i. + else carry_simple limb_widths i. Definition carry_sequence is us := fold_right carry us is. - Fixpoint make_chain i := - match i with - | O => nil - | S i' => i' :: make_chain i' - end. - - Definition full_carry_chain := make_chain (length limb_widths). - - Definition carry_full := carry_sequence full_carry_chain. + Definition carry_full := carry_sequence (full_carry_chain limb_widths). Definition carry_mul us vs := carry_full (mul us vs). diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 80e2f58ce..7c7004dce 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -33,7 +33,7 @@ Definition nth_default_opt {A} := Eval compute in @nth_default A. Definition set_nth_opt {A} := Eval compute in @set_nth A. Definition update_nth_opt {A} := Eval compute in @update_nth A. Definition map_opt {A B} := Eval compute in @map A B. -Definition full_carry_chain_opt := Eval compute in @full_carry_chain. +Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain. Definition length_opt := Eval compute in length. Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths. Definition max_ones_opt := Eval compute in @max_ones. @@ -118,7 +118,7 @@ Section Carries. cbv [carry]. rewrite <- pull_app_if_sumbool. cbv beta delta - [carry carry_and_reduce carry_simple add_to_nth + [carry carry_and_reduce Pow2Base.carry_simple Pow2Base.add_to_nth Z.pow2_mod Z.ones Z.pred PseudoMersenneBaseParams.limb_widths]. change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt. @@ -272,17 +272,17 @@ Section Carries. apply carry_sequence_rep; eauto using rep_length. Qed. - Lemma full_carry_chain_bounds : forall i, In i full_carry_chain -> (i < length base)%nat. + Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat. Proof. - unfold full_carry_chain; rewrite <-base_length; intros. - apply make_chain_lt; auto. + unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros. + apply Pow2BaseProofs.make_chain_lt; auto. Qed. Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. Proof. eexists. cbv [carry_full]. - change @full_carry_chain with full_carry_chain_opt. + change @Pow2Base.full_carry_chain with full_carry_chain_opt. rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds). reflexivity. Defined. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 7e33ab20f..e5ae285de 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -10,6 +10,7 @@ Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. +Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Notations. Local Open Scope Z_scope. @@ -22,7 +23,8 @@ Section PseudoMersenneProofs. Local Notation "u .+ x" := (add u x). Local Notation "u .* x" := (ModularBaseSystem.mul u x). Local Hint Unfold rep. - Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. + Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. + Local Hint Resolve log_cap_nonneg. Local Notation base := (base_from_limb_widths limb_widths). Local Notation log_cap i := (nth_default 0 limb_widths i). @@ -166,6 +168,11 @@ Section PseudoMersenneProofs. rewrite <- Zplus_mod; auto. Qed. + Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus. + Proof. + intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega. + Qed. + Lemma extended_shiftadd: forall (us : BaseSystem.digits), BaseSystem.decode ext_base us = BaseSystem.decode base (firstn (length base) us) @@ -207,7 +214,7 @@ Section PseudoMersenneProofs. apply Max.max_l; omega. Qed. - Lemma length_mul : forall u v, + Lemma length_mul : forall u v, length u = length base -> length v = length base -> length (u .* v) = length base. @@ -231,56 +238,6 @@ Section PseudoMersenneProofs. apply ZToField_mul. } Qed. - Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (set_nth n x us) = - (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. - Proof. - intros. - unfold BaseSystem.decode. - nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) - unfold splice_nth. - rewrite <- (firstn_skipn n us) at 4. - do 2 rewrite decode'_splice. - remember (length (firstn n us)) as n0. - ring_simplify. - remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). - rewrite (skipn_nth_default n us 0) by omega. - rewrite firstn_length in Heqn0. - rewrite Min.min_l in Heqn0 by omega; subst n0. - destruct (le_lt_dec (length base) n). { - rewrite nth_default_out_of_bounds by auto. - rewrite skipn_all by omega. - do 2 rewrite decode_base_nil. - ring_simplify; auto. - } { - rewrite (skipn_nth_default n base 0) by omega. - do 2 rewrite decode'_cons. - ring_simplify; ring. - } - Qed. - - Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (add_to_nth n x us) = - x * nth_default 0 base n + BaseSystem.decode base us. - Proof. - unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. - Qed. - - Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> - nth_default 0 (add_to_nth n x l) i = - if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. - Proof. - intros. - unfold add_to_nth. - rewrite set_nth_nth_default by assumption. - break_if; subst; reflexivity. - Qed. - - Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. - Proof. - unfold add_to_nth; intros; apply length_set_nth. - Qed. - Lemma nth_default_base_positive : forall i, (i < length base)%nat -> nth_default 0 base i > 0. Proof. @@ -310,14 +267,6 @@ Section PseudoMersenneProofs. apply FieldToZ_ZToField. Qed. - Lemma log_cap_nonneg : forall i, 0 <= log_cap i. - Proof. - unfold nth_default; intros. - case_eq (nth_error limb_widths i); intros; try omega. - apply limb_widths_nonneg. - eapply nth_error_value_In; eauto. - Qed. Local Hint Resolve log_cap_nonneg. - Definition carry_done us := forall i, (i < length base)%nat -> 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. @@ -374,7 +323,7 @@ Section CarryProofs. Local Notation base := (base_from_limb_widths limb_widths). Local Notation log_cap i := (nth_default 0 limb_widths i). Local Notation "u ~= x" := (rep u x). - Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. + Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. @@ -382,40 +331,6 @@ Section CarryProofs. Qed. Hint Resolve base_length_lt_pred. - Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> - nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. - Proof. - intros. - repeat rewrite nth_default_base by (omega || eauto). - rewrite <- Z.pow_add_r by eauto using log_cap_nonneg. - destruct (NPeano.Nat.eq_dec i 0). - + subst; f_equal. - unfold sum_firstn. - destruct limb_widths; auto. - + erewrite sum_firstn_succ; eauto. - apply nth_error_Some_nth_default. - rewrite <- base_length; omega. - Qed. - - Lemma carry_simple_decode_eq : forall i us, - (length us = length base) -> - (i < (pred (length base)))%nat -> - BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us. - Proof. - unfold carry_simple. intros. - rewrite add_to_nth_sum by (rewrite length_set_nth; omega). - rewrite set_nth_sum by omega. - unfold Z.pow2_mod. - rewrite Z.land_ones by apply log_cap_nonneg. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. - rewrite nth_default_base_succ by omega. - rewrite Z.mul_assoc. - rewrite (Z.mul_comm _ (2 ^ log_cap i)). - rewrite Z.mul_div_eq; try ring. - apply Z.gt_lt_iff. - apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg. - Qed. - Lemma carry_decode_eq_reduce : forall us, (length us = length base) -> BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus @@ -442,9 +357,9 @@ Section CarryProofs. + rewrite nth_default_base by (omega || eauto). rewrite <- Z.add_opp_l, <- Z.opp_sub_distr. unfold Z.pow2_mod. - rewrite Z.land_ones by apply log_cap_nonneg. - rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Z.land_ones by eauto using log_cap_nonneg. + rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg). + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. rewrite Zopp_mult_distr_r. rewrite Z.mul_comm. rewrite Z.mul_assoc. @@ -460,21 +375,23 @@ Section CarryProofs. ring. Qed. - Lemma carry_length : forall i us, - (length us = length base)%nat -> - (length (carry i us) = length base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - Hint Resolve carry_length. + Lemma length_carry_and_reduce : forall i us, length (carry_and_reduce i us) = length us. + Proof. intros; unfold carry_and_reduce; autorewrite with distr_length; reflexivity. Qed. + Hint Rewrite @length_carry_and_reduce : distr_length. + + Lemma length_carry : forall i us, length (carry i us) = length us. + Proof. intros; unfold carry; break_if; autorewrite with distr_length; reflexivity. Qed. + Hint Rewrite @length_carry : distr_length. + + Local Hint Extern 1 (length _ = _) => progress autorewrite with distr_length. Lemma carry_rep : forall i us x, (length us = length base) -> (i < length base)%nat -> us ~= x -> carry i us ~= x. Proof. - pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. + pose proof length_carry. pose proof carry_decode_eq_reduce. pose proof (@carry_simple_decode_eq limb_widths). + specialize_by eauto. intros; split; auto. unfold rep, decode, carry in *. intuition; break_if; subst; eauto; apply F_eq; simpl; intuition. @@ -497,13 +414,6 @@ Section CarryProofs. induction is; boring. Qed. - (* TODO : move? *) - Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. - Proof. - induction x; simpl; intuition. - Qed. - - Lemma carry_full_length : forall us, (length us = length base)%nat -> length (carry_full us) = length base. Proof. @@ -529,7 +439,7 @@ Section CarryProofs. Qed. Lemma carry_mul_length : forall us vs, - length us = length base -> length vs = length base -> + length us = length base -> length vs = length base -> length (carry_mul us vs) = length base. Proof. intros; cbv [carry_mul]. @@ -538,6 +448,8 @@ Section CarryProofs. End CarryProofs. +Hint Rewrite @length_carry_and_reduce @length_carry : distr_length. + Section CanonicalizationProofs. Context `{prm : PseudoMersenneBaseParams}. Local Notation base := (base_from_limb_widths limb_widths). @@ -553,10 +465,11 @@ Section CanonicalizationProofs. (two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus). (* BEGIN groundwork proofs *) + Local Hint Resolve (@log_cap_nonneg limb_widths) limb_widths_nonneg. Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i. Proof. - intros; apply Z.pow_pos_nonneg; auto using log_cap_nonneg; omega. + intros; apply Z.pow_pos_nonneg; eauto using log_cap_nonneg; omega. Qed. Local Hint Resolve pow_2_log_cap_pos. @@ -568,12 +481,11 @@ Section CanonicalizationProofs. omega. Qed. - Local Hint Resolve log_cap_nonneg. Lemma pow2_mod_log_cap_range : forall a i, 0 <= Z.pow2_mod a (log_cap i) <= max_bound i. Proof. intros. unfold Z.pow2_mod. - rewrite Z.land_ones by apply log_cap_nonneg. + rewrite Z.land_ones by eauto using log_cap_nonneg. unfold max_bound, Z.ones. rewrite Z.shiftl_1_l, <-Z.lt_le_pred. apply Z_mod_lt. @@ -598,7 +510,7 @@ Section CanonicalizationProofs. Proof. intros. unfold Z.pow2_mod. - rewrite Z.land_ones by apply log_cap_nonneg. + rewrite Z.land_ones by eauto using log_cap_nonneg. apply Z.mod_small. split; try omega. rewrite <- max_bound_log_cap. @@ -617,17 +529,10 @@ Section CanonicalizationProofs. Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. - unfold max_bound; intros; auto using Z.ones_nonneg. + unfold max_bound; intros; eauto using Z.ones_nonneg. Qed. Local Hint Resolve max_bound_nonneg. - Lemma pow2_mod_spec : forall a b, (0 <= b) -> Z.pow2_mod a b = a mod (2 ^ b). - Proof. - intros. - unfold Z.pow2_mod. - rewrite Z.land_ones; auto. - Qed. - Lemma shiftr_eq_0_max_bound : forall i a, Z.shiftr a (log_cap i) = 0 -> a <= max_bound i. Proof. @@ -678,7 +583,7 @@ Section CanonicalizationProofs. (* automation *) Ltac carry_length_conditions' := unfold carry_full, add_to_nth; - rewrite ?length_set_nth, ?carry_length, ?carry_sequence_length; + rewrite ?length_set_nth, ?length_carry, ?carry_sequence_length; try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ]. Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'. @@ -931,9 +836,9 @@ Section CanonicalizationProofs. replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega. split; [ zero_bounds | ]. apply Z.add_lt_mono; try omega. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. - rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). + rewrite <-Z.pow_add_r by (eauto using log_cap_nonneg || apply B_compat_log_cap). replace (log_cap i + (B - log_cap i)) with B by ring. omega. Qed. @@ -976,7 +881,7 @@ Section CanonicalizationProofs. apply Z.add_le_mono. + apply carry_bounds_0_upper; auto; omega. + apply Z.mul_le_mono_pos_l; auto using c_pos. - apply Z.shiftr_ones; auto; + apply Z.shiftr_ones; eauto; [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. split. - apply carry_bounds_lower; auto; omega. @@ -1014,7 +919,7 @@ Section CanonicalizationProofs. apply carry_full_bounds; auto; omega. + rewrite <-max_bound_log_cap, <-Z.add_1_l. apply Z.add_le_mono. - - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + - rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. apply Z.div_floor; auto. destruct i. * simpl. @@ -1047,7 +952,7 @@ Section CanonicalizationProofs. - apply carry_bounds_0_upper; carry_length_conditions. - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try (pose proof c_pos; omega). - rewrite Z.shiftr_div_pow2 by auto. + rewrite Z.shiftr_div_pow2 by eauto. apply Z.div_le_upper_bound; auto. ring_simplify. apply carry_sequence_carry_full_bounds_same; auto. @@ -1060,7 +965,7 @@ Section CanonicalizationProofs. 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i) -> - 0 <= nth_default 0 (carry_simple i + 0 <= nth_default 0 (carry_simple limb_widths i (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i). Proof. unfold carry_simple; intros ? ? PCB length_eq ? IH. @@ -1072,7 +977,7 @@ Section CanonicalizationProofs. apply carry_full_bounds; carry_length_conditions. carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. apply Z.add_le_mono. - apply Z.div_le_upper_bound; auto. ring_simplify. apply IH. omega. @@ -1096,7 +1001,7 @@ Section CanonicalizationProofs. - eapply carry_full_bounds; eauto; carry_length_conditions. carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. apply Z.add_le_mono. - apply Z.div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. @@ -1183,7 +1088,7 @@ Section CanonicalizationProofs. remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x. apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); try omega. replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. - rewrite pow2_mod_spec by auto. + rewrite Z.pow2_mod_spec by eauto. cbv [make_chain carry_sequence fold_right]. rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small; @@ -1200,6 +1105,7 @@ Section CanonicalizationProofs. assumption. Qed. + (* END proofs about second carry loop *) (* BEGIN proofs about third carry loop *) @@ -1230,7 +1136,7 @@ Section CanonicalizationProofs. apply Z.add_le_mono; try assumption. etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try omega. - rewrite Z.shiftr_div_pow2 by auto. + rewrite Z.shiftr_div_pow2 by eauto. apply Z.div_le_upper_bound; auto. ring_simplify. apply carry_full_2_bounds_same; auto. @@ -1299,8 +1205,8 @@ Section CanonicalizationProofs. Proof. unfold max_ones. apply Z.ones_nonneg. + clear. pose proof limb_widths_nonneg. - clear c_reduce1 lt_1_length_base. induction limb_widths as [|?? IHl]. cbv; congruence. simpl. @@ -1732,7 +1638,7 @@ Section CanonicalizationProofs. rewrite decode_base_nil. apply Z.gt_lt; auto using nth_default_base_positive. + rewrite decode_firstn_succ by (auto || omega). - rewrite nth_default_base_succ by omega. + rewrite nth_default_base_succ by (eauto || omega). eapply Z.lt_le_trans. - apply Z.add_lt_mono_r. apply IHn; auto; omega. @@ -2103,4 +2009,4 @@ Section CanonicalizationProofs. eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption. Qed. -End CanonicalizationProofs. \ No newline at end of file +End CanonicalizationProofs. diff --git a/src/ModularArithmetic/Pow2Base.v b/src/ModularArithmetic/Pow2Base.v index 7d0495ef3..f434a0c9f 100644 --- a/src/ModularArithmetic/Pow2Base.v +++ b/src/ModularArithmetic/Pow2Base.v @@ -1,5 +1,7 @@ Require Import Zpower ZArith. Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.ZUtil. +Require Crypto.BaseSystem. Require Import Coq.Lists.List. Local Open Scope Z_scope. @@ -16,6 +18,7 @@ Section Pow2Base. Local Notation base := (base_from_limb_widths limb_widths). + Definition bounded us := forall i, 0 <= nth_default 0 us i < 2 ^ w[i]. Definition upper_bound := 2 ^ (sum_firstn limb_widths (length limb_widths)). @@ -39,4 +42,53 @@ Section Pow2Base. (* max must be greater than input; this is used to truncate last digit *) Definition encodeZ x:= encode' x (length limb_widths). + (** ** Carrying *) + Section carrying. + (** Here we implement addition and multiplication with simple + carrying. *) + Notation log_cap i := (nth_default 0 limb_widths i). + + + Definition add_to_nth n (x:Z) xs := + set_nth n (x + nth_default 0 xs n) xs. + (* TODO: Maybe we should use this instead? *) + (* + Definition add_to_nth n (x:Z) xs := + update_nth n (fun y => x + y) xs. + + Definition carry_and_reduce_single i := fun di => + (Z.pow2_mod di (log_cap i), + Z.shiftr di (log_cap i)). + + Definition carry_gen f i := fun us => + let i := (i mod length us)%nat in + let di := nth_default 0 us i in + let '(di', ci) := carry_and_reduce_single i di in + let us' := set_nth i di' us in + add_to_nth ((S i) mod (length us)) (f ci) us'. + + Definition carry_simple := carry_gen (fun ci => ci). + *) + Definition carry_simple i := fun us => + let di := nth_default 0 us i in + let us' := set_nth i (Z.pow2_mod di (log_cap i)) us in + add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'. + + Definition carry_simple_sequence is us := fold_right carry_simple us is. + + Fixpoint make_chain i := + match i with + | O => nil + | S i' => i' :: make_chain i' + end. + + Definition full_carry_chain := make_chain (length limb_widths). + + Definition carry_simple_full := carry_simple_sequence full_carry_chain. + + Definition carry_simple_add us vs := carry_simple_full (BaseSystem.add us vs). + + Definition carry_simple_mul out_base us vs := carry_simple_full (BaseSystem.mul out_base us vs). + End carrying. + End Pow2Base. diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index 7538781c0..ed9b58ccc 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -1,16 +1,19 @@ -Require Import Zpower ZArith. +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.Util.Tactics. Require Import Crypto.ModularArithmetic.Pow2Base Crypto.BaseSystemProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. +Create HintDb simpl_add_to_nth discriminated. + Section Pow2BaseProofs. Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). - Local Notation "{base}" := (base_from_limb_widths limb_widths). + Local Notation base := (base_from_limb_widths limb_widths). - Lemma base_from_limb_widths_length : length {base} = length limb_widths. + Lemma base_from_limb_widths_length : length base = length limb_widths. Proof. induction limb_widths; try reflexivity. simpl; rewrite map_length. @@ -28,10 +31,10 @@ Section Pow2BaseProofs. eapply In_firstn; eauto. Qed. Hint Resolve sum_firstn_limb_widths_nonneg. - Lemma base_from_limb_widths_step : forall i b w, (S i < length {base})%nat -> - nth_error {base} i = Some b -> + Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat -> + nth_error base i = Some b -> nth_error limb_widths i = Some w -> - nth_error {base} (S i) = Some (two_p w * b). + nth_error base (S i) = Some (two_p w * b). Proof. induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; unfold base_from_limb_widths in *; fold base_from_limb_widths in *; @@ -54,13 +57,13 @@ Section Pow2BaseProofs. Qed. - Lemma nth_error_base : forall i, (i < length {base})%nat -> - nth_error {base} i = Some (two_p (sum_firstn limb_widths i)). + Lemma nth_error_base : forall i, (i < length base)%nat -> + nth_error base i = Some (two_p (sum_firstn limb_widths i)). Proof. induction i; intros. + unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. - + assert (i < length {base})%nat as lt_i_length by omega. + + assert (i < length base)%nat as lt_i_length by omega. specialize (IHi lt_i_length). rewrite base_from_limb_widths_length in lt_i_length. destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. @@ -80,8 +83,8 @@ Section Pow2BaseProofs. eapply nth_error_value_In; eauto. Qed. - Lemma nth_default_base : forall d i, (i < length {base})%nat -> - nth_default d {base} i = 2 ^ (sum_firstn limb_widths i). + Lemma nth_default_base : forall d i, (i < length base)%nat -> + nth_default d base i = 2 ^ (sum_firstn limb_widths i). Proof. intros ? ? i_lt_length. destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x]. @@ -92,8 +95,8 @@ Section Pow2BaseProofs. congruence. Qed. - Lemma base_succ : forall i, ((S i) < length {base})%nat -> - nth_default 0 {base} (S i) mod nth_default 0 {base} i = 0. + Lemma base_succ : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) mod nth_default 0 base i = 0. Proof. intros. repeat rewrite nth_default_base by omega. @@ -105,7 +108,7 @@ Section Pow2BaseProofs. apply limb_widths_nonneg. rewrite lw_eq. apply in_eq. - + assert (i < length {base})%nat as i_lt_length by omega. + + assert (i < length base)%nat as i_lt_length by omega. rewrite base_from_limb_widths_length in *. apply nth_error_length_exists_value in i_lt_length. destruct i_lt_length as [x nth_err_x]. @@ -115,7 +118,7 @@ Section Pow2BaseProofs. omega. Qed. - Lemma nth_error_subst : forall i b, nth_error {base} i = Some b -> + Lemma nth_error_subst : forall i b, nth_error base i = Some b -> b = 2 ^ (sum_firstn limb_widths i). Proof. intros i b nth_err_b. @@ -125,7 +128,7 @@ Section Pow2BaseProofs. congruence. Qed. - Lemma base_positive : forall b : Z, In b {base} -> b > 0. + Lemma base_positive : forall b : Z, In b base -> b > 0. Proof. intros b In_b_base. apply In_nth_error_value in In_b_base. @@ -136,7 +139,7 @@ Section Pow2BaseProofs. apply Z.pow_pos_nonneg; omega || auto using sum_firstn_limb_widths_nonneg. Qed. - Lemma b0_1 : forall x : Z, limb_widths <> nil -> nth_default x {base} 0 = 1. + Lemma b0_1 : forall x : Z, limb_widths <> nil -> nth_default x base 0 = 1. Proof. case_eq limb_widths; intros; [congruence | reflexivity]. Qed. @@ -154,18 +157,18 @@ Section BitwiseDecodeEncode. (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). Local Hint Resolve limb_widths_nonneg. Local Notation "w[ i ]" := (nth_default 0 limb_widths i). - Local Notation "{base}" := (base_from_limb_widths limb_widths). - Local Notation "{max}" := (upper_bound limb_widths). + Local Notation base := (base_from_limb_widths limb_widths). + Local Notation max := (upper_bound limb_widths). - Lemma encode'_spec : forall x i, (i <= length {base})%nat -> - encode' limb_widths x i = BaseSystem.encode' {base} x {max} i. + Lemma encode'_spec : forall x i, (i <= length base)%nat -> + encode' limb_widths x i = BaseSystem.encode' base x max i. Proof. induction i; intros. + rewrite encode'_zero. reflexivity. + rewrite encode'_succ, <-IHi by omega. simpl; do 2 f_equal. rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg. - match goal with H : (S _ <= length {base})%nat |- _ => + match goal with H : (S _ <= length base)%nat |- _ => apply le_lt_or_eq in H; destruct H end. - repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity. - repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity]. @@ -180,12 +183,12 @@ Section BitwiseDecodeEncode. intros; apply nth_default_preserves_properties; auto; omega. Qed. Hint Resolve nth_default_limb_widths_nonneg. - Lemma base_upper_bound_compatible : @base_max_succ_divide {base} {max}. + Lemma base_upper_bound_compatible : @base_max_succ_divide base max. Proof. unfold base_max_succ_divide; intros i lt_Si_length. rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length; rewrite !nth_default_base by (omega || auto). - + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); rewrite <-base_from_limb_widths_length by auto; omega). rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. apply Z.divide_factor_r. @@ -195,18 +198,18 @@ Section BitwiseDecodeEncode. (rewrite base_from_limb_widths_length in H by auto; omega). replace i with (pred (length limb_widths)) by (rewrite base_from_limb_widths_length in H by auto; omega). - erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); rewrite <-base_from_limb_widths_length by auto; omega). rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. - apply Z.divide_factor_r. + apply Z.divide_factor_r. Qed. Hint Resolve base_upper_bound_compatible. Lemma encodeZ_spec : forall x, - BaseSystem.decode {base} (encodeZ limb_widths x) = x mod {max}. + BaseSystem.decode base (encodeZ limb_widths x) = x mod max. Proof. intros. - assert (length {base} = length limb_widths) by auto using base_from_limb_widths_length. + assert (length base = length limb_widths) by auto using base_from_limb_widths_length. unfold encodeZ; rewrite encode'_spec by omega. rewrite BaseSystemProofs.encode'_spec; unfold upper_bound; try zero_bounds; auto using sum_firstn_limb_widths_nonneg. @@ -236,7 +239,7 @@ Section BitwiseDecodeEncode. Proof. intros. simpl; f_equal. - match goal with H : bounded _ _ |- _ => + match goal with H : bounded _ _ |- _ => rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end. rewrite Z.shiftl_mul_pow2 by auto. ring. @@ -316,7 +319,7 @@ Section BitwiseDecodeEncode. Lemma decode_bitwise'_spec : forall us i, (i <= length limb_widths)%nat -> bounded limb_widths us -> length us = length limb_widths -> decode_bitwise' limb_widths us i (partial_decode us i (length us - i)) = - BaseSystem.decode {base} us. + BaseSystem.decode base us. Proof. induction i; intros. + rewrite partial_decode_intermediate by auto. @@ -328,7 +331,7 @@ Section BitwiseDecodeEncode. Lemma decode_bitwise_spec : forall us, bounded limb_widths us -> length us = length limb_widths -> - decode_bitwise limb_widths us = BaseSystem.decode {base} us. + decode_bitwise limb_widths us = BaseSystem.decode base us. Proof. unfold decode_bitwise; intros. replace 0 with (partial_decode us (length us) (length us - length us)) by @@ -361,7 +364,7 @@ Section UniformBase. Context {width : Z} (limb_width_pos : 0 < width). Context (limb_widths : list Z) (limb_widths_nonnil : limb_widths <> nil) (limb_widths_uniform : forall w, In w limb_widths -> w = width). - Local Notation "{base}" := (base_from_limb_widths limb_widths). + Local Notation base := (base_from_limb_widths limb_widths). Lemma bounded_uniform : forall us, (length us <= length limb_widths)%nat -> (bounded limb_widths us <-> (forall u, In u us -> 0 <= u < 2 ^ width)). @@ -409,7 +412,7 @@ Section UniformBase. Qed. Lemma decode_tl_base_shift : forall us, (length us < length limb_widths)%nat -> - BaseSystem.decode (tl {base}) us = BaseSystem.decode {base} us << width. + BaseSystem.decode (tl base) us = BaseSystem.decode base us << width. Proof. intros ? Hlength. edestruct (destruct_repeat limb_widths) as [? | [tl_lw [Heq_lw tl_lw_uniform]]]; @@ -422,7 +425,7 @@ Section UniformBase. Qed. Lemma decode_shift : forall us u0, (length (u0 :: us) <= length limb_widths)%nat -> - BaseSystem.decode {base} (u0 :: us) = u0 + ((BaseSystem.decode {base} us) << width). + BaseSystem.decode base (u0 :: us) = u0 + ((BaseSystem.decode base us) << width). Proof. intros. rewrite <-decode_tl_base_shift by (simpl in *; omega). @@ -439,4 +442,395 @@ Section UniformBase. replace w with width by (symmetry; auto). assumption. Qed. -End UniformBase. \ No newline at end of file +End UniformBase. + +Section carrying_helper. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation base := (base_from_limb_widths limb_widths). + Local Notation log_cap i := (nth_default 0 limb_widths i). + + Lemma update_nth_sum : forall n f us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (update_nth n f us) = + (let v := nth_default 0 us n in f v - v) * nth_default 0 base n + BaseSystem.decode base us. + Proof. + intros. + unfold BaseSystem.decode. + destruct H as [H|H]. + { nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) + erewrite nth_error_value_eq_nth_default by eassumption. + unfold splice_nth. + rewrite <- (firstn_skipn n us) at 3. + do 2 rewrite decode'_splice. + remember (length (firstn n us)) as n0. + ring_simplify. + remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). + rewrite (skipn_nth_default n us 0) by omega. + erewrite (nth_error_value_eq_nth_default _ _ us) by eassumption. + rewrite firstn_length in Heqn0. + rewrite Min.min_l in Heqn0 by omega; subst n0. + destruct (le_lt_dec (length base) n). { + rewrite (@nth_default_out_of_bounds _ _ base) by auto. + rewrite skipn_all by omega. + do 2 rewrite decode_base_nil. + ring_simplify; auto. + } { + rewrite (skipn_nth_default n base 0) by omega. + do 2 rewrite decode'_cons. + ring_simplify; ring. + } } + { rewrite (nth_default_out_of_bounds _ base) by omega; ring_simplify. + etransitivity; rewrite BaseSystem.decode'_truncate; [ reflexivity | ]. + apply f_equal. + autorewrite with push_firstn simpl_update_nth. + rewrite update_nth_out_of_bounds by (distr_length; omega * ). + reflexivity. } + Qed. + + Lemma unfold_add_to_nth n x + : forall xs, + add_to_nth n x xs + = match n with + | O => match xs with + | nil => nil + | x'::xs' => x + x'::xs' + end + | S n' => match xs with + | nil => nil + | x'::xs' => x'::add_to_nth n' x xs' + end + end. + Proof. + induction n; destruct xs; reflexivity. + Qed. + + Lemma simpl_add_to_nth_0 x + : forall xs, + add_to_nth 0 x xs + = match xs with + | nil => nil + | x'::xs' => x + x'::xs' + end. + Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed. + + Lemma simpl_add_to_nth_S x n + : forall xs, + add_to_nth (S n) x xs + = match xs with + | nil => nil + | x'::xs' => x'::add_to_nth n x xs' + end. + Proof. intro; rewrite unfold_add_to_nth; reflexivity. Qed. + + Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth. + + Lemma add_to_nth_cons : forall x u0 us, add_to_nth 0 x (u0 :: us) = x + u0 :: us. + Proof. reflexivity. Qed. + + Hint Rewrite @add_to_nth_cons : simpl_add_to_nth. + + Lemma cons_add_to_nth : forall n f y us, + y :: add_to_nth n f us = add_to_nth (S n) f (y :: us). + Proof. + induction n; boring. + Qed. + + Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth. + + Lemma add_to_nth_nil : forall n f, add_to_nth n f nil = nil. + Proof. + induction n; boring. + Qed. + + Hint Rewrite @add_to_nth_nil : simpl_add_to_nth. + + Lemma add_to_nth_set_nth n x xs + : add_to_nth n x xs + = set_nth n (x + nth_default 0 xs n) xs. + Proof. + revert xs; induction n; destruct xs; + autorewrite with simpl_set_nth simpl_add_to_nth; + try rewrite IHn; + reflexivity. + Qed. + Lemma add_to_nth_update_nth n x xs + : add_to_nth n x xs + = update_nth n (fun y => x + y) xs. + Proof. + revert xs; induction n; destruct xs; + autorewrite with simpl_update_nth simpl_add_to_nth; + try rewrite IHn; + reflexivity. + Qed. + + Lemma length_add_to_nth i x xs : length (add_to_nth i x xs) = length xs. + Proof. unfold add_to_nth; distr_length; reflexivity. Qed. + + Hint Rewrite @length_add_to_nth : distr_length. + + Lemma set_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (set_nth n x us) = + (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. + Proof. intros; unfold set_nth; rewrite update_nth_sum by assumption; reflexivity. Qed. + + Lemma add_to_nth_sum : forall n x us, (n < length us \/ n >= length base)%nat -> + BaseSystem.decode base (add_to_nth n x us) = + x * nth_default 0 base n + BaseSystem.decode base us. + Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed. + + Lemma add_to_nth_nth_default_full : forall n x l i d, + nth_default d (add_to_nth n x l) i = + if lt_dec i (length l) then + if (eq_nat_dec i n) then x + nth_default d l i + else nth_default d l i + else d. + Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default_full; assumption. Qed. + Hint Rewrite @add_to_nth_nth_default_full : push_nth_default. + + Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> + nth_default 0 (add_to_nth n x l) i = + if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Proof. intros; rewrite add_to_nth_update_nth; apply update_nth_nth_default; assumption. Qed. + Hint Rewrite @add_to_nth_nth_default using omega : push_nth_default. + + Lemma log_cap_nonneg : forall i, 0 <= log_cap i. + Proof. + unfold nth_default; intros. + case_eq (nth_error limb_widths i); intros; try omega. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. Local Hint Resolve log_cap_nonneg. +End carrying_helper. + +Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_add_to_nth. +Hint Rewrite @add_to_nth_cons : simpl_add_to_nth. +Hint Rewrite <- @cons_add_to_nth : simpl_add_to_nth. +Hint Rewrite @add_to_nth_nil : simpl_add_to_nth. +Hint Rewrite @length_add_to_nth : distr_length. +Hint Rewrite @add_to_nth_nth_default_full : push_nth_default. +Hint Rewrite @add_to_nth_nth_default using (omega || distr_length; omega) : push_nth_default. + +Section carrying. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation base := (base_from_limb_widths limb_widths). + Local Notation log_cap i := (nth_default 0 limb_widths i). + Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. + + (* + Lemma length_carry_gen : forall f i us, length (carry_gen limb_widths f i us) = length us. + Proof. intros; unfold carry_gen, carry_and_reduce_single; distr_length; reflexivity. Qed. + + Hint Rewrite @length_carry_gen : distr_length. + *) + + Lemma length_carry_simple : forall i us, length (carry_simple limb_widths i us) = length us. + Proof. intros; unfold carry_simple; distr_length; reflexivity. Qed. + Hint Rewrite @length_carry_simple : distr_length. + + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> + nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. + Proof. + intros. + rewrite !nth_default_base, <- Z.pow_add_r by (omega || eauto using log_cap_nonneg). + autorewrite with simpl_sum_firstn; reflexivity. + Qed. + + (* + Lemma carry_gen_decode_eq : forall f i' us (i := (i' mod length base)%nat), + (length us = length base) -> + BaseSystem.decode base (carry_gen limb_widths f i' us) + = ((f (nth_default 0 us i / 2 ^ log_cap i)) + * (if eq_nat_dec (S i mod length base) 0 + then nth_default 0 base 0 + else (2 ^ log_cap i) * (nth_default 0 base i)) + - (nth_default 0 us i / 2 ^ log_cap i) * 2 ^ log_cap i * nth_default 0 base i + ) + + BaseSystem.decode base us. + Proof. + intros f i' us i H; intros. + destruct (eq_nat_dec 0 (length base)); + [ destruct limb_widths, us, i; simpl in *; try congruence; + unfold carry_gen, carry_and_reduce_single, add_to_nth; + autorewrite with zsimplify simpl_nth_default simpl_set_nth simpl_update_nth distr_length; + reflexivity + | ]. + assert (0 <= i < length base)%nat by (subst i; auto with arith). + assert (0 <= log_cap i) by auto using log_cap_nonneg. + assert (2 ^ log_cap i <> 0) by (apply Z.pow_nonzero; lia). + unfold carry_gen, carry_and_reduce_single. + rewrite H; change (i' mod length base)%nat with i. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold Z.pow2_mod. + rewrite Z.land_ones by auto using log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by auto using log_cap_nonneg. + destruct (eq_nat_dec (S i mod length base) 0); + repeat first [ ring + | congruence + | match goal with H : _ = _ |- _ => rewrite !H in * end + | rewrite nth_default_base_succ by omega + | rewrite !(nth_default_out_of_bounds _ base) by omega + | rewrite !(nth_default_out_of_bounds _ us) by omega + | rewrite Z.mod_eq by assumption + | progress distr_length + | progress autorewrite with natsimplify zsimplify in * + | progress break_match ]. + Qed. + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple; intros; rewrite carry_gen_decode_eq by assumption. + autorewrite with natsimplify. + break_match; lia. + Qed. +*) + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple limb_widths i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple. intros. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold Z.pow2_mod. + rewrite Z.land_ones by eauto using log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by eauto using log_cap_nonneg. + rewrite nth_default_base_succ by omega. + rewrite Z.mul_assoc. + rewrite (Z.mul_comm _ (2 ^ log_cap i)). + rewrite Z.mul_div_eq; try ring. + apply Z.gt_lt_iff. + apply Z.pow_pos_nonneg; omega || eauto using log_cap_nonneg. + Qed. + + Lemma length_carry_simple_sequence : forall is us, length (carry_simple_sequence limb_widths is us) = length us. + Proof. + unfold carry_simple_sequence. + induction is; [ reflexivity | simpl; intros ]. + distr_length. + congruence. + Qed. + Hint Rewrite @length_carry_simple_sequence : distr_length. + + Lemma length_make_chain : forall i, length (make_chain i) = i. + Proof. induction i; simpl; congruence. Qed. + Hint Rewrite @length_make_chain : distr_length. + + Lemma length_full_carry_chain : length (full_carry_chain limb_widths) = length limb_widths. + Proof. unfold full_carry_chain; distr_length; reflexivity. Qed. + Hint Rewrite @length_full_carry_chain : distr_length. + + Lemma length_carry_simple_full us : length (carry_simple_full limb_widths us) = length us. + Proof. unfold carry_simple_full; distr_length; reflexivity. Qed. + Hint Rewrite @length_carry_simple_full : distr_length. + + (* TODO : move? *) + Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. + Proof. + induction x; simpl; intuition. + Qed. +(* + Lemma nth_default_carry_gen_full : forall f d i n us, + nth_default d (carry_gen limb_widths f i us) n + = if lt_dec n (length us) + then if eq_nat_dec n (i mod length us) + then (if eq_nat_dec (S n) (length us) + then (if eq_nat_dec n 0 + then f ((nth_default 0 us n) >> log_cap n) + else 0) + else 0) + + Z.pow2_mod (nth_default 0 us n) (log_cap n) + else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us)) + then f (nth_default 0 us (i mod length us) >> log_cap (i mod length us)) + else 0) + + nth_default d us n + else d. + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify distr_length. + edestruct lt_dec; [ | reflexivity ]. + change (S ?x) with (1 + x)%nat. + rewrite (Nat.add_mod_idemp_r 1 i (length us)) by omega. + autorewrite with natsimplify. + change (1 + ?x)%nat with (S x). + destruct (eq_nat_dec n (i mod length us)); + subst; repeat break_match; omega. + Qed. + + Hint Rewrite @nth_default_carry_gen_full : push_nth_default. + + Lemma nth_default_carry_simple_full : forall d i n us, + nth_default d (carry_simple limb_widths i us) n + = if lt_dec n (length us) + then if eq_nat_dec n (i mod length us) + then (if eq_nat_dec (S n) (length us) + then (if eq_nat_dec n 0 + then (nth_default 0 us n >> log_cap n + Z.pow2_mod (nth_default 0 us n) (log_cap n)) + (* FIXME: The above is just [nth_default 0 us n], but do we really care about the case of [n = 0], [length us = 1]? *) + else Z.pow2_mod (nth_default 0 us n) (log_cap n)) + else Z.pow2_mod (nth_default 0 us n) (log_cap n)) + else (if eq_nat_dec n (if eq_nat_dec (S (i mod length us)) (length us) then 0%nat else S (i mod length us)) + then nth_default 0 us (i mod length us) >> log_cap (i mod length us) + else 0) + + nth_default d us n + else d. + Proof. + intros; unfold carry_simple; autorewrite with push_nth_default; + repeat break_match; reflexivity. + Qed. + + Hint Rewrite @nth_default_carry_simple_full : push_nth_default. + + Lemma nth_default_carry_gen + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. + + Lemma nth_default_carry_simple + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. + + + Lemma nth_default_carry_gen + : forall f i us, + (0 <= i < length us)%nat + -> nth_default 0 (carry_gen limb_widths f i us) i + = (if PeanoNat.Nat.eq_dec i (if PeanoNat.Nat.eq_dec (S i) (length us) then 0%nat else S i) + then f (nth_default 0 us i >> log_cap i) + Z.pow2_mod (nth_default 0 us i) (log_cap i) + else Z.pow2_mod (nth_default 0 us i) (log_cap i)). + Proof. + unfold carry_gen, carry_and_reduce_single. + intros; autorewrite with push_nth_default natsimplify; reflexivity. + Qed. + Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. +*) +End carrying. + +(* +Hint Rewrite @length_carry_gen : distr_length. +*) +Hint Rewrite @length_carry_simple @length_carry_simple_sequence @length_make_chain @length_full_carry_chain @length_carry_simple_full : distr_length. +(* +Hint Rewrite @nth_default_carry_gen_full : push_nth_default. +Hint Rewrite @nth_default_carry_gen using (omega || distr_length; omega) : push_nth_default. +*) -- cgit v1.2.3