diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemProofs.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 1580 |
1 files changed, 1245 insertions, 335 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 274acff5a..0462b0f37 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -1,7 +1,7 @@ Require Import Zpower ZArith. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. Require Import VerdiTactics. Require Crypto.BaseSystem. Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. @@ -22,14 +22,21 @@ Section PseudoMersenneProofs. autounfold; intuition. Qed. + Lemma rep_length : forall us x, us ~= x -> length us = length base. + Proof. + autounfold; intuition. + Qed. + Lemma encode_rep : forall x : F modulus, encode x ~= x. Proof. intros. unfold encode, rep. split. { unfold encode; simpl. - apply base_length_nonzero. + rewrite length_zeros. + pose proof base_length_nonzero; omega. } { unfold decode. + rewrite decode_highzeros. rewrite encode_rep. apply ZToField_FieldToZ. apply bv. @@ -40,8 +47,7 @@ Section PseudoMersenneProofs. Proof. autounfold; intuition. { unfold add. - rewrite add_length_le_max. - case_max; try rewrite Max.max_r; omega. + auto using add_same_length. } unfold decode in *; unfold decode in *. rewrite add_rep. @@ -49,15 +55,14 @@ Section PseudoMersenneProofs. subst; auto. Qed. - Lemma sub_rep : forall c c_0modq, (length c <= length base)%nat -> - forall u v x y, u ~= x -> v ~= y -> + Lemma sub_rep : forall c c_0modq, (length c = length base)%nat -> + forall u v x y, u ~= x -> v ~= y -> ModularBaseSystem.sub c c_0modq u v ~= (x-y)%F. Proof. autounfold; unfold ModularBaseSystem.sub; intuition. { - rewrite sub_length_le_max. + rewrite sub_length. case_max; try rewrite Max.max_r; try omega. - rewrite add_length_le_max. - case_max; try rewrite Max.max_r; omega. + auto using add_same_length. } unfold decode in *; unfold BaseSystem.decode in *. rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep. @@ -66,7 +71,7 @@ Section PseudoMersenneProofs. subst; auto. Qed. - Lemma decode_short : forall (us : BaseSystem.digits), + Lemma decode_short : forall (us : BaseSystem.digits), (length us <= length base)%nat -> BaseSystem.decode base us = BaseSystem.decode ext_base us. Proof. @@ -80,11 +85,11 @@ Section PseudoMersenneProofs. Qed. Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), - (length us <= length base)%nat -> + (length us <= length base)%nat -> (length vs <= length base)%nat -> (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs). Proof. - intros. + intros. rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). f_equal; rewrite decode_short; auto. Qed. @@ -93,7 +98,7 @@ Section PseudoMersenneProofs. pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. Qed. - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) + (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. Proof. intros. @@ -137,34 +142,16 @@ Section PseudoMersenneProofs. rewrite mul_each_rep; auto. Qed. - Lemma reduce_length : forall us, - (length us <= length ext_base)%nat -> - (length (reduce us) <= length base)%nat. + Lemma reduce_length : forall us, + (length base <= length us <= length ext_base)%nat -> + (length (reduce us) = length base)%nat. Proof. - intros. - unfold reduce. - remember (map (Z.mul c) (skipn (length base) us)) as high. - remember (firstn (length base) us) as low. - assert (length low >= length high)%nat. { - subst. rewrite firstn_length. - rewrite map_length. - rewrite skipn_length. - destruct (le_dec (length base) (length us)). { - rewrite Min.min_l by omega. - rewrite extended_base_length in H. omega. - } { - rewrite Min.min_r; omega. - } - } - assert ((length low <= length base)%nat) - by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). - assert (length high <= length base)%nat - by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; - rewrite extended_base_length in H; omega). - rewrite add_trailing_zeros; auto. - rewrite (add_same_length _ _ (length low)); auto. - rewrite app_length. - rewrite length_zeros; intuition. + rewrite extended_base_length. + unfold reduce; intros. + rewrite add_length_exact. + rewrite map_length, firstn_length, skipn_length. + rewrite Min.min_l by omega. + apply Max.max_l; omega. Qed. Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F. @@ -172,20 +159,22 @@ Section PseudoMersenneProofs. autounfold; unfold ModularBaseSystem.mul; intuition. { apply reduce_length. - rewrite mul_length, extended_base_length. - omega. + rewrite mul_length_exact, extended_base_length; try omega. + destruct u; try congruence. + rewrite @nil_length0 in *. + pose proof base_length_nonzero; omega. } { rewrite ZToField_mod, reduce_rep, <-ZToField_mod. rewrite mul_rep by (apply ExtBaseVector || rewrite extended_base_length; omega). subst. - do 2 rewrite decode_short by auto. + do 2 rewrite decode_short by omega. apply ZToField_mul. } Qed. Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (set_nth n x us) = + BaseSystem.decode base (set_nth n x us) = (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. Proof. intros. @@ -213,12 +202,27 @@ Section PseudoMersenneProofs. Qed. Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (add_to_nth n x us) = + BaseSystem.decode base (add_to_nth n x us) = x * nth_default 0 base n + BaseSystem.decode base us. Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed. + Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> + nth_default 0 (add_to_nth n x l) i = + if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Proof. + intros. + unfold add_to_nth. + rewrite set_nth_nth_default by assumption. + break_if; subst; reflexivity. + Qed. + + Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. + Proof. + unfold add_to_nth; intros; apply length_set_nth. + Qed. + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> nth_default 0 base i > 0. Proof. @@ -240,13 +244,21 @@ Section PseudoMersenneProofs. apply base_succ; auto. Qed. + Lemma Fdecode_decode_mod : forall us x, (length us = length base) -> + decode us = x -> BaseSystem.decode base us mod modulus = x. + Proof. + unfold decode; intros ? ? ? decode_us. + rewrite <-decode_us. + apply FieldToZ_ZToField. + Qed. + End PseudoMersenneProofs. Section CarryProofs. Context `{prm : PseudoMersenneBaseParams}. Local Notation "u '~=' x" := (rep u x) (at level 70). Hint Unfold log_cap. - + Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. pose proof base_length_nonzero; omega. @@ -260,7 +272,7 @@ Section CarryProofs. apply limb_widths_nonneg. eapply nth_error_value_In; eauto. Qed. - + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. Proof. @@ -342,8 +354,8 @@ Section CarryProofs. Qed. Lemma carry_length : forall i us, - (length us <= length base)%nat -> - (length (carry i us) <= length base)%nat. + (length us = length base)%nat -> + (length (carry i us) = length base)%nat. Proof. unfold carry, carry_simple, carry_and_reduce, add_to_nth. intros; break_if; subst; repeat (rewrite length_set_nth); auto. @@ -356,36 +368,19 @@ Section CarryProofs. us ~= x -> carry i us ~= x. Proof. pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. - unfold rep, decode, carry in *; intros. - intuition; break_if; subst; eauto; - apply F_eq; simpl; intuition. + intros; split; auto. + unfold rep, decode, carry in *. + intuition; break_if; subst; eauto; apply F_eq; simpl; intuition. Qed. Hint Resolve carry_rep. Lemma carry_sequence_length: forall is us, - (length us <= length base)%nat -> - (length (carry_sequence is us) <= length base)%nat. - Proof. - induction is; boring. - Qed. - Hint Resolve carry_sequence_length. - - Lemma carry_length_exact : forall i us, - (length us = length base)%nat -> - (length (carry i us) = length base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - - Lemma carry_sequence_length_exact: forall is us, (length us = length base)%nat -> (length (carry_sequence is us) = length base)%nat. Proof. induction is; boring. - apply carry_length_exact; auto. Qed. - Hint Resolve carry_sequence_length_exact. + Hint Resolve carry_sequence_length. Lemma carry_sequence_rep : forall is us x, (forall i, In i is -> (i < length base)%nat) -> @@ -395,46 +390,45 @@ Section CarryProofs. induction is; boring. Qed. -End CarryProofs. -Section CanonicalizationProofs. - Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) (c_pos : 0 < c) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B). - - (* TODO : move *) - Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> - nth_default d (set_nth n x l) i = - if (eq_nat_dec i n) then x else nth_default d l i. + (* TODO : move? *) + Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. Proof. - induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; - destruct i; break_if; try omega; intros; try apply nth_default_cons; - rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. + induction x; simpl; intuition. Qed. - (* TODO : move *) - Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> - nth_default 0 (add_to_nth n x l) i = - if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Lemma carry_full_preserves_rep : forall us x, + rep us x -> rep (carry_full us) x. Proof. - intros. - unfold add_to_nth. - rewrite set_nth_nth_default by assumption. - break_if; subst; reflexivity. + unfold carry_full; intros. + apply carry_sequence_rep; auto. + unfold full_carry_chain; rewrite base_length; apply make_chain_lt. + eauto using rep_length. Qed. - (* TODO : move *) - Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. - Proof. - unfold add_to_nth; intros; apply length_set_nth. - Qed. + Opaque carry_full. - (* TODO : move *) - Lemma singleton_list : forall {A} (l : list A), length l = 1%nat -> exists x, l = x :: nil. + Lemma carry_mul_rep : forall us vs x y, rep us x -> rep vs y -> + rep (carry_mul us vs) (x * y)%F. Proof. - intros; destruct l; simpl in *; try congruence. - eexists; f_equal. - apply length0_nil; omega. + unfold carry_mul; intros; apply carry_full_preserves_rep. + auto using mul_rep. Qed. +End CarryProofs. + +Section CanonicalizationProofs. + Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) + {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) + (c_pos : 0 < c) + (* on the first reduce step, we add at most one bit of width to the first digit *) + (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) + (* on the second reduce step, we add at most one bit of width to the first digit, + and leave room to carry c one more time after the highest bit is carried *) + (c_reduce2 : c <= max_bound 0 - c) + (* this condition is probably implied by c_reduce2, but is more straighforward to compute than to prove *) + (two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus). + (* BEGIN groundwork proofs *) Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i. @@ -451,7 +445,7 @@ Section CanonicalizationProofs. omega. Qed. - Hint Resolve log_cap_nonneg. + Local Hint Resolve log_cap_nonneg. Lemma pow2_mod_log_cap_range : forall a i, 0 <= pow2_mod a (log_cap i) <= max_bound i. Proof. intros. @@ -488,6 +482,16 @@ Section CanonicalizationProofs. omega. Qed. + Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i. + Proof. + unfold max_bound, log_cap; intros; apply Z_ones_pos_pos. + apply limb_widths_pos. + rewrite nth_default_eq. + apply nth_In. + rewrite <-base_length; assumption. + Qed. + Local Hint Resolve max_bound_pos. + Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. unfold max_bound; intros; auto using Z_ones_nonneg. @@ -501,15 +505,6 @@ Section CanonicalizationProofs. rewrite Z.land_ones; auto. Qed. - Lemma pow2_mod_upper_bound : forall a b, (0 <= a) -> (0 <= b) -> pow2_mod a b <= a. - Proof. - intros. - unfold pow2_mod. - rewrite Z.land_ones; auto. - apply Z.mod_le; auto. - apply Z.pow_pos_nonneg; omega. - Qed. - Lemma shiftr_eq_0_max_bound : forall i a, Z.shiftr a (log_cap i) = 0 -> a <= max_bound i. Proof. @@ -550,26 +545,26 @@ Section CanonicalizationProofs. omega. Qed. + Lemma log_cap_eq : forall i, log_cap i = nth_default 0 limb_widths i. + Proof. + reflexivity. + Qed. + (* END groundwork proofs *) Opaque pow2_mod log_cap max_bound. (* automation *) Ltac carry_length_conditions' := unfold carry_full, add_to_nth; - rewrite ?length_set_nth, ?carry_length_exact, ?carry_sequence_length_exact, ?carry_sequence_length_exact; + rewrite ?length_set_nth, ?carry_length, ?carry_sequence_length; try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ]. Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'. - Ltac add_set_nth := rewrite ?add_to_nth_nth_default; try solve [carry_length_conditions]; - try break_if; try omega; rewrite ?set_nth_nth_default; try solve [carry_length_conditions]; - try break_if; try omega. + Ltac add_set_nth := + rewrite ?add_to_nth_nth_default by carry_length_conditions; break_if; try omega; + rewrite ?set_nth_nth_default by carry_length_conditions; break_if; try omega. (* BEGIN defs *) - Definition c_carry_constraint : Prop := - (c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) - /\ (max_bound 0 + c < 2 ^ (log_cap 0 + 1)) - /\ (c <= max_bound 0 - c). - Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i < if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)). @@ -581,26 +576,10 @@ Section CanonicalizationProofs. specialize (PCB i). omega. Qed. - Hint Resolve pre_carry_bounds_nonzero. + Local Hint Resolve pre_carry_bounds_nonzero. - Definition carry_done us := forall i, (i < length base)%nat -> Z.shiftr (nth_default 0 us i) (log_cap i) = 0. - - Lemma carry_carry_done_done : forall i us, - (length us = length base)%nat -> - (i < length base)%nat -> - (forall i, 0 <= nth_default 0 us i) -> - carry_done us -> carry_done (carry i us). - Proof. - unfold carry_done; intros until 3. intros Hcarry_done ? ?. - unfold carry, carry_simple, carry_and_reduce; break_if; subst. - + rewrite Hcarry_done by omega. - rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). - destruct i0; add_set_nth; rewrite ?Z.mul_0_r, ?Z.add_0_l; auto. - match goal with H : S _ = pred (length base) |- _ => rewrite H; auto end. - + rewrite Hcarry_done by omega. - rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). - destruct i0; add_set_nth; subst; rewrite ?Z.add_0_l; auto. - Qed. + Definition carry_done us := forall i, (i < length base)%nat -> + 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. (* END defs *) @@ -620,6 +599,7 @@ Section CanonicalizationProofs. apply pow2_mod_log_cap_bounds_upper. - rewrite nth_default_out_of_bounds by carry_length_conditions; auto. Qed. + Local Hint Resolve nth_default_carry_bound_upper. Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) -> 0 <= nth_default 0 (carry i us) i. @@ -635,6 +615,7 @@ Section CanonicalizationProofs. apply pow2_mod_log_cap_bounds_lower. - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. Qed. + Local Hint Resolve nth_default_carry_bound_lower. Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) -> (length us = length base) -> @@ -645,18 +626,15 @@ Section CanonicalizationProofs. + subst. replace (S (pred (length base))) with (length base) by omega. rewrite nth_default_out_of_bounds; carry_length_conditions. unfold carry_and_reduce. - add_set_nth. + carry_length_conditions. + unfold carry_simple. destruct (lt_dec (S i) (length us)). - - add_set_nth. - apply Z.add_nonneg_nonneg; [ apply Z.shiftr_nonneg | ]; unfold pre_carry_bounds in PCB. - * specialize (PCB i). omega. - * specialize (PCB (S i)). omega. + - add_set_nth; zero_bounds. - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. Qed. Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)-> - (length us = length base) -> + (length us = length base) -> nth_default 0 (carry j us) i = nth_default 0 us i. Proof. intros. @@ -671,7 +649,7 @@ Section CanonicalizationProofs. (omega || rewrite length_add_to_nth; rewrite length_set_nth; pose proof base_length_nonzero; omega). reflexivity. Qed. - + Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) -> nth_default 0 (carry j us) i = nth_default 0 us i. Proof. @@ -679,7 +657,7 @@ Section CanonicalizationProofs. destruct (lt_dec i (length us)); [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity]. unfold carry, carry_simple. - break_if; add_set_nth. + break_if; [omega | add_set_nth]. Qed. Lemma carry_nothing : forall i j us, (i < length base)%nat -> @@ -688,23 +666,65 @@ Section CanonicalizationProofs. nth_default 0 (carry j us) i = nth_default 0 us i. Proof. unfold carry, carry_simple, carry_and_reduce; intros. - break_if; (add_set_nth; + break_if; (add_set_nth; [ rewrite max_bound_shiftr_eq_0 by omega; ring | subst; apply pow2_mod_log_cap_small; assumption ]). Qed. + Lemma carry_done_bounds : forall us, (length us = length base) -> + (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). + Proof. + intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. + + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. + - specialize (Hcarry_done i i_lt). + split; [ intuition | ]. + rewrite <- max_bound_log_cap. + apply Z.lt_succ_r. + apply shiftr_eq_0_max_bound; intuition. + - rewrite nth_default_out_of_bounds; try split; try omega; auto. + + specialize (Hbounds i). + split; intuition. + apply max_bound_shiftr_eq_0; auto. + rewrite <-max_bound_log_cap in *; omega. + Qed. + + Lemma carry_carry_done_done : forall i us, + (length us = length base)%nat -> + (i < length base)%nat -> + carry_done us -> carry_done (carry i us). + Proof. + unfold carry_done; intros i ? ? i_bound Hcarry_done x x_bound. + destruct (Hcarry_done x x_bound) as [lower_bound_x shiftr_0_x]. + destruct (Hcarry_done i i_bound) as [lower_bound_i shiftr_0_i]. + split. + + rewrite carry_nothing; auto. + split; [ apply Hcarry_done; auto | ]. + apply shiftr_eq_0_max_bound. + apply Hcarry_done; auto. + + unfold carry, carry_simple, carry_and_reduce; break_if; subst. + - add_set_nth; subst. + * rewrite shiftr_0_i, Z.mul_0_r, Z.add_0_l. + assumption. + * rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + assumption. + - rewrite shiftr_0_i by omega. + rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + add_set_nth; subst; rewrite ?Z.add_0_l; auto. + Qed. + + Lemma carry_sequence_chain_step : forall i us, + carry_sequence (make_chain (S i)) us = carry i (carry_sequence (make_chain i) us). + Proof. + reflexivity. + Qed. + Lemma carry_bounds_0_upper : forall us j, (length us = length base) -> (0 < j < length base)%nat -> nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_bound 0. Proof. - unfold carry_sequence; induction j; [simpl; intros; omega | ]. - intros. - simpl in *. - destruct (eq_nat_dec 0 j). - + subst. - apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain 0) us); carry_length_conditions. - + rewrite carry_unaffected_low; try omega. - fold (carry_sequence (make_chain j) us); carry_length_conditions. + induction j as [ | [ | j ] IHj ]; [simpl; intros; omega | | ]; intros. + + subst; simpl; auto. + + rewrite carry_sequence_chain_step, carry_unaffected_low; carry_length_conditions. Qed. Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) -> @@ -721,7 +741,7 @@ Section CanonicalizationProofs. fold (carry_sequence (make_chain j) us); carry_length_conditions. Qed. - Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat -> + Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat -> nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i. Proof. induction j; [simpl; intros; omega | ]. @@ -731,33 +751,41 @@ Section CanonicalizationProofs. apply IHj; omega. Qed. + (* makes omega run faster *) + Ltac clear_obvious := + match goal with + | [H : ?a <= ?a |- _] => clear H + | [H : ?a <= S ?a |- _] => clear H + | [H : ?a < S ?a |- _] => clear H + | [H : ?a = ?a |- _] => clear H + end. + Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) -> (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. Proof. - induction j; intros. - + simpl. auto. - + simpl. - destruct (lt_dec (S j) i). - - rewrite carry_unaffected_high by carry_length_conditions. - apply IHj; auto; omega. - - assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega. - destruct cases as [? | [? | ?]]. - * subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions. - intros. - eapply IHj; auto; omega. - * subst. apply nth_default_carry_bound_lower; carry_length_conditions. - * destruct (eq_nat_dec j (pred (length base))); - [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ]. - subst. - unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth; [ | apply IHj; auto; omega ]. - apply Z.add_nonneg_nonneg; [ | apply IHj; auto; omega ]. - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. - apply IHj; auto; omega. + induction j; intros; simpl; auto. + destruct (lt_dec (S j) i). + + rewrite carry_unaffected_high by carry_length_conditions. + apply IHj; auto; omega. + + assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega. + destruct cases as [? | [? | ?]]. + - subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions. + intros; eapply IHj; auto; omega. + - subst. apply nth_default_carry_bound_lower; carry_length_conditions. + - destruct (eq_nat_dec j (pred (length base))); + [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ]. + subst. + do 2 match goal with H : appcontext[S (pred (length base))] |- _ => + erewrite <-(S_pred (length base)) in H by eauto end. + unfold carry; break_if; [ unfold carry_and_reduce | omega ]. + clear_obvious. + add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega. Qed. + Ltac carry_seq_lower_bound := + repeat (intros; eapply carry_sequence_bounds_lower; eauto; carry_length_conditions). + Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) -> (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. @@ -769,13 +797,12 @@ Section CanonicalizationProofs. destruct (eq_nat_dec i (S j)). + subst. apply nth_default_carry_bound_succ_lower; auto; fold (carry_sequence (make_chain j) us); carry_length_conditions. - intros. - apply carry_sequence_bounds_lower; auto; omega. + carry_seq_lower_bound. + assert (i = j \/ i < j)%nat as cases by omega. destruct cases as [eq_j_i | lt_i_j]; subst; [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega; fold (carry_sequence (make_chain j) us); carry_length_conditions. - apply carry_sequence_bounds_lower; auto; omega. + carry_seq_lower_bound. Qed. Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) -> @@ -799,18 +826,15 @@ Section CanonicalizationProofs. unfold carry, carry_simple; break_if; try omega. add_set_nth. replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega. - split. - + apply Z.add_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg; try omega. - + apply Z.add_lt_mono; try omega. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. - apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. - rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). - replace (log_cap i + (B - log_cap i)) with B by ring. - omega. + split; [ zero_bounds | ]. + apply Z.add_lt_mono; try omega. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. + rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). + replace (log_cap i + (B - log_cap i)) with B by ring. + omega. Qed. - Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us -> (length us = length base) -> nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B. @@ -822,16 +846,13 @@ Section CanonicalizationProofs. intuition. + simpl. destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ]. - - apply carry_simple_no_overflow; carry_length_conditions. - apply carry_sequence_bounds_lower; carry_length_conditions. - apply carry_sequence_bounds_lower; carry_length_conditions. - rewrite carry_sequence_unaffected; try omega. + - apply carry_simple_no_overflow; carry_length_conditions; carry_seq_lower_bound. + rewrite carry_sequence_unaffected; try omega. specialize (PCB (S i)); rewrite Nat.pred_succ in PCB. break_if; intuition. - unfold carry; break_if; try omega. rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. - subst. - unfold carry_and_reduce. + subst; unfold carry_and_reduce. carry_length_conditions. - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. carry_length_conditions. @@ -843,25 +864,20 @@ Section CanonicalizationProofs. Proof. unfold carry_full, full_carry_chain; intros. rewrite <- base_length. - replace (length base) with (S (pred (length base))) at 1 2 by omega. + replace (length base) with (S (pred (length base))) by omega. simpl. unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth. - split. - + apply Z.add_nonneg_nonneg. - - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. - apply carry_sequence_bounds_lower; auto; omega. - - apply carry_sequence_bounds_lower; auto; omega. - + rewrite Z.add_comm. - apply Z.add_le_mono. - - apply carry_bounds_0_upper; auto; omega. - - apply Z.mul_le_mono_pos_l; auto. - apply Z_shiftr_ones; auto; - [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. - split. - * apply carry_bounds_lower; auto; try omega. - * apply carry_sequence_no_overflow; auto. + clear_obvious; add_set_nth. + split; [zero_bounds; carry_seq_lower_bound | ]. + rewrite Z.add_comm. + apply Z.add_le_mono. + + apply carry_bounds_0_upper; auto; omega. + + apply Z.mul_le_mono_pos_l; auto. + apply Z_shiftr_ones; auto; + [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. + split. + - apply carry_bounds_lower; auto; omega. + - apply carry_sequence_no_overflow; auto. Qed. Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us -> @@ -874,12 +890,12 @@ Section CanonicalizationProofs. - apply carry_bounds_lower; carry_length_conditions. - rewrite nth_default_out_of_bounds; carry_length_conditions. Qed. - + (* END proofs about first carry loop *) - + (* BEGIN proofs about second carry loop *) - Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i. Proof. @@ -888,12 +904,9 @@ Section CanonicalizationProofs. unfold carry, carry_simple; break_if; try omega. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - - apply Z.shiftr_nonneg. - destruct (eq_nat_dec i 0); subst. - * simpl. - apply carry_full_bounds_0; auto. - * apply IHi; auto; omega. + + zero_bounds; [destruct (eq_nat_dec i 0); subst | ]. + - simpl; apply carry_full_bounds_0; auto. + - apply IHi; auto; omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; auto; omega. + rewrite <-max_bound_log_cap, <-Z.add_1_l. @@ -905,16 +918,14 @@ Section CanonicalizationProofs. eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring. rewrite <-max_bound_log_cap, <-Z.add_1_l. - apply Z.add_lt_le_mono; try omega. - unfold c_carry_constraint in *. - intuition. + apply Z.add_lt_le_mono; omega. * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ]. - apply Z.lt_mul_diag_r; auto; omega. + apply Z.lt_mul_diag_r; auto; omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; auto; omega. Qed. - Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> (length us = length base)%nat -> (1 < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_bound 0 + c. Proof. @@ -924,19 +935,14 @@ Section CanonicalizationProofs. replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega). simpl. unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth. + clear_obvious; add_set_nth. split. - + apply Z.add_nonneg_nonneg. - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. + + zero_bounds; [ | carry_seq_lower_bound]. apply carry_sequence_carry_full_bounds_same; auto; omega. - eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. - intros. - eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. + rewrite Z.add_comm. apply Z.add_le_mono. - apply carry_bounds_0_upper; carry_length_conditions. - - replace c with (c * 1) at 2 by ring. + - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try omega. rewrite Z.shiftr_div_pow2 by auto. apply Z.div_le_upper_bound; auto. @@ -945,7 +951,7 @@ Section CanonicalizationProofs. omega. Qed. - Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < pred (length base))%nat -> ((0 < i < length base)%nat -> 0 <= nth_default 0 @@ -954,20 +960,14 @@ Section CanonicalizationProofs. 0 <= nth_default 0 (carry_simple i (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i). Proof. - unfold carry_simple; intros ? ? PCB CCC length_eq ? IH. + unfold carry_simple; intros ? ? PCB length_eq ? IH. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - apply Z.shiftr_nonneg. - destruct i; - [ simpl; pose proof (carry_full_2_bounds_0 us PCB CCC length_eq); omega | ]. - - assert (0 < S i < length base)%nat as IHpre by omega. - specialize (IH IHpre). - omega. - - rewrite carry_sequence_unaffected by carry_length_conditions. - apply carry_full_bounds; carry_length_conditions. - intros. - apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + zero_bounds. destruct i; + [ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ]. + rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; carry_length_conditions. + carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. @@ -975,10 +975,10 @@ Section CanonicalizationProofs. ring_simplify. apply IH. omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + carry_seq_lower_bound. Qed. - Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i. Proof. @@ -988,36 +988,33 @@ Section CanonicalizationProofs. split; (destruct (eq_nat_dec i 0); subst; [ cbv [make_chain carry_sequence fold_right carry_simple]; add_set_nth | eapply carry_full_2_bounds_succ; eauto; omega]). - + apply Z.add_nonneg_nonneg. - apply Z.shiftr_nonneg. - eapply carry_full_2_bounds_0; eauto. - eapply carry_full_bounds; eauto; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + zero_bounds. + - eapply carry_full_2_bounds_0; eauto. + - eapply carry_full_bounds; eauto; carry_length_conditions. + carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. - apply Z_div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. replace (Z.succ 1) with (2 ^ 1) by ring. - rewrite <-Z.pow_add_r by (omega || auto). - unfold c_carry_constraint in *. - intuition. - - apply carry_full_bounds; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + rewrite <-max_bound_log_cap. + ring_simplify. omega. + - apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound. Qed. - Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_bound i. Proof. induction j; intros; try omega. - split; (destruct j; [ rewrite Nat.add_1_r; simpl + split; (destruct j; [ rewrite Nat.add_1_r; simpl | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]). + apply nth_default_carry_bound_lower; carry_length_conditions. + apply nth_default_carry_bound_upper; carry_length_conditions. Qed. - Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_bound i. Proof. @@ -1026,12 +1023,12 @@ Section CanonicalizationProofs. eapply carry_full_2_bounds'; eauto; omega. Qed. - Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + unfold carry. break_if; @@ -1041,91 +1038,82 @@ Section CanonicalizationProofs. add_set_nth. apply pow2_mod_log_cap_bounds_lower. + rewrite carry_unaffected_low by carry_length_conditions. - assert (0 < S i < length base)%nat by omega. + assert (0 < S i < length base)%nat by omega. intuition. Qed. - Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> (length us = length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) i. Proof. - intros. - destruct i. + intros; destruct i. + apply carry_full_2_bounds_0; auto. + apply carry_full_bounds; try solve [carry_length_conditions]. - intro j. - destruct j. + intro j; destruct j. - apply carry_full_bounds_0; auto. - apply carry_full_bounds; carry_length_conditions. Qed. - Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_length : forall us, (length us = length base)%nat -> + length (carry_full us) = length us. + Proof. + intros; carry_length_conditions. + Qed. + Local Hint Resolve carry_full_length. + + Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_bound 0 - c) \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_bound 0)). - right. - unfold carry_done. + apply carry_carry_done_done; try solve [carry_length_conditions]. + apply carry_done_bounds; try solve [carry_length_conditions]. intros. - apply max_bound_shiftr_eq_0; simpl; rewrite carry_nothing; try solve [carry_length_conditions]. - * apply carry_full_2_bounds_lower; auto. - * split; try apply carry_full_2_bounds_lower; auto. - * destruct i; auto. - apply carry_full_bounds; try solve [carry_length_conditions]. - auto using carry_full_bounds_lower. - * split; auto. - apply carry_full_2_bounds_lower; auto. - - unfold carry. + simpl. + split; [ auto using carry_full_2_bounds_lower | ]. + * destruct i; rewrite <-max_bound_log_cap, Z.lt_succ_r; auto. + apply carry_full_bounds; auto using carry_full_bounds_lower. + rewrite carry_full_length; auto. + - left; unfold carry, carry_simple. break_if; [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ]. - simpl. - unfold carry_simple. - add_set_nth. left. + add_set_nth. simpl. remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x. - apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)). - * replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. - rewrite pow2_mod_spec by auto. - rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). - rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small. - apply Z.sub_le_mono_r. - subst; apply carry_full_2_bounds_0; auto. - split; try omega. - pose proof carry_full_2_bounds_0. - apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); - [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; - ring_simplify; unfold c_carry_constraint in *; omega | ]. - ring_simplify; unfold c_carry_constraint in *; omega. - * ring_simplify; unfold c_carry_constraint in *; omega. + apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); try omega. + replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. + rewrite pow2_mod_spec by auto. + cbv [make_chain carry_sequence fold_right]. + rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). + rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small; + [ apply Z.sub_le_mono_r; subst; apply carry_full_2_bounds_0; auto | ]. + split; try omega. + pose proof carry_full_2_bounds_0. + apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); + [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; + ring_simplify | ]; omega. + rewrite carry_unaffected_low by carry_length_conditions. - assert (0 < S i < length base)%nat by omega. - intuition. - right. + assert (0 < S i < length base)%nat by omega. + intuition; right. apply carry_carry_done_done; try solve [carry_length_conditions]. - intro j. - destruct j. - - apply carry_carry_full_2_bounds_0_lower; auto. - - destruct (lt_eq_lt_dec j i) as [[? | ?] | ?]. - * apply carry_full_2_bounds; auto; omega. - * subst. apply carry_full_2_bounds_same; auto; omega. - * rewrite carry_sequence_unaffected; try solve [carry_length_conditions]. - apply carry_full_2_bounds_lower; auto; omega. - Qed. - + assumption. + Qed. + (* END proofs about second carry loop *) - + (* BEGIN proofs about third carry loop *) - Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> (length us = length base)%nat ->(i < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_bound i. Proof. intros. destruct i; [ | apply carry_full_bounds; carry_length_conditions; - do 2 (intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions) ]. + carry_seq_lower_bound ]. unfold carry_full at 1 4, full_carry_chain. case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ]. simpl. @@ -1135,45 +1123,967 @@ Section CanonicalizationProofs. unfold carry, carry_and_reduce; break_if; try omega; intros. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - - apply Z.mul_nonneg_nonneg; auto; try omega. - apply Z.shiftr_nonneg. - eapply carry_full_2_bounds_same; eauto; omega. + + zero_bounds. + - eapply carry_full_2_bounds_same; eauto; omega. - eapply carry_carry_full_2_bounds_0_lower; eauto; omega. + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))). assert (0 < pred (length base) < length base)%nat by omega. intuition. - replace (max_bound 0) with (c + (max_bound 0 - c)) by ring. apply Z.add_le_mono; try assumption. - replace c with (c * 1) at 2 by ring. + etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try omega. rewrite Z.shiftr_div_pow2 by auto. apply Z.div_le_upper_bound; auto. ring_simplify. apply carry_full_2_bounds_same; auto. - - match goal with H : carry_done _ |- _ => unfold carry_done in H; rewrite H by omega end. + - match goal with H0 : (pred (length base) < length base)%nat, + H : carry_done _ |- _ => + destruct (H (pred (length base)) H0) as [Hcd1 Hcd2]; rewrite Hcd2 by omega end. ring_simplify. - apply shiftr_eq_0_max_bound; auto; omega. + apply shiftr_eq_0_max_bound; auto. + assert (0 < length base)%nat as zero_lt_length by omega. + match goal with H : carry_done _ |- _ => + destruct (H 0%nat zero_lt_length) end. + assumption. Qed. - Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x -> - nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x'). - Admitted. - - Lemma nth_error_range : forall {A} i (l : list A), (i < length l)%nat -> - nth_error (range (length l)) i = Some i. - Admitted. + Lemma carry_full_3_done : forall us, pre_carry_bounds us -> + (length us = length base)%nat -> + carry_done (carry_full (carry_full (carry_full us))). + Proof. + intros. + apply carry_done_bounds; [ carry_length_conditions | intros ]. + destruct (lt_dec i (length base)). + + rewrite <-max_bound_log_cap, Z.lt_succ_r. + auto using carry_full_3_bounds. + + rewrite nth_default_out_of_bounds; carry_length_conditions. + Qed. (* END proofs about third carry loop *) - Opaque carry_full. - Lemma freeze_in_bounds : forall us i, (us <> nil)%nat -> - 0 <= nth_default 0 (freeze us) i < 2 ^ log_cap i. + Lemma isFull'_false : forall us n, isFull' us false n = false. Proof. - Admitted. + unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto. + Qed. - Lemma freeze_canonical : forall us vs x, rep us x -> rep vs x -> + Lemma isFull'_last : forall us b j, (j <> 0)%nat -> isFull' us b j = true -> + max_bound j = nth_default 0 us j. + Proof. + induction j; simpl; intros; try omega. + match goal with + | [H : isFull' _ ((?comp ?a ?b) && _) _ = true |- _ ] => + case_eq (comp a b); rewrite ?Z.eqb_eq; intro comp_eq; try assumption; + rewrite comp_eq, Bool.andb_false_l, isFull'_false in H; congruence + end. + Qed. + + Lemma isFull'_lower_bound_0 : forall j us b, isFull' us b j = true -> + max_bound 0 - c < nth_default 0 us 0. + Proof. + induction j; intros. + + match goal with H : isFull' _ _ 0 = _ |- _ => cbv [isFull'] in H; + apply Bool.andb_true_iff in H; destruct H end. + apply Z.ltb_lt; assumption. + + eauto. + Qed. + + Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true -> + max_bound i = nth_default 0 us i. + Proof. + induction j; intros; try omega. + assert (i = S j \/ i <= j)%nat as cases by omega. + destruct cases. + + subst. eapply isFull'_last; eauto. + + eapply IHj; eauto. + Qed. + + Lemma max_ones_nonneg : 0 <= max_ones. + Proof. + unfold max_ones. + apply Z_ones_nonneg. + pose proof limb_widths_nonneg. + induction limb_widths. + cbv; congruence. + simpl. + apply Z.max_le_iff. + right. + apply IHl; auto using in_cons. + Qed. + + Lemma land_max_ones_noop : forall x i, 0 <= x < 2 ^ log_cap i -> Z.land max_ones x = x. + Proof. + unfold max_ones. + intros ? ? x_range. + rewrite Z.land_comm. + rewrite Z.land_ones by apply Z_le_fold_right_max_initial. + apply Z.mod_small. + split; try omega. + eapply Z.lt_le_trans; try eapply x_range. + apply Z.pow_le_mono_r; try omega. + rewrite log_cap_eq. + destruct (lt_dec i (length limb_widths)). + + apply Z_le_fold_right_max. + - apply limb_widths_nonneg. + - rewrite nth_default_eq. + auto using nth_In. + + rewrite nth_default_out_of_bounds by omega. + apply Z_le_fold_right_max_initial. + Qed. + + Lemma full_isFull'_true : forall j us, (length us = length base) -> + ( max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)) -> + isFull' us true j = true. + Proof. + induction j; intros. + + cbv [isFull']; apply Bool.andb_true_iff. + rewrite Z.ltb_lt; intuition. + + intuition. + simpl. + match goal with H : forall j, _ -> ?b j = ?a j |- appcontext[?a ?i =? ?b ?i] => + replace (a i =? b i) with true by (symmetry; apply Z.eqb_eq; symmetry; apply H; omega) end. + apply IHj; auto; intuition. + Qed. + + Lemma isFull'_true_iff : forall j us, (length us = length base) -> (isFull' us true j = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)). + Proof. + intros; split; intros; auto using full_isFull'_true. + split; eauto using isFull'_lower_bound_0. + intros. + symmetry; eapply isFull'_true_full; [ omega | | eauto]. + omega. + Qed. + + Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true -> + isFull' us true j = true. + Proof. + simpl; intros ? ? succ_true. + destruct (max_bound (S j) =? nth_default 0 us (S j)); auto. + rewrite isFull'_false in succ_true. + congruence. + Qed. + + Opaque isFull' max_ones. + + Lemma carry_full_3_length : forall us, (length us = length base) -> + length (carry_full (carry_full (carry_full us))) = length us. + Proof. + intros. + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); auto. + Qed. + Local Hint Resolve carry_full_3_length. + + Lemma nth_default_map2 : forall {A B C} (f : A -> B -> C) ls1 ls2 i d d1 d2, + nth_default d (map2 f ls1 ls2) i = + if lt_dec i (min (length ls1) (length ls2)) + then f (nth_default d1 ls1 i) (nth_default d2 ls2 i) + else d. + Proof. + induction ls1, ls2. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + simpl. + destruct i. + - intros. rewrite !nth_default_cons. + break_if; auto; omega. + - intros. rewrite !nth_default_cons_S. + rewrite IHls1 with (d1 := d1) (d2 := d2). + repeat break_if; auto; omega. + Qed. + + Lemma map2_cons : forall A B C (f : A -> B -> C) ls1 ls2 a b, + map2 f (a :: ls1) (b :: ls2) = f a b :: map2 f ls1 ls2. + Proof. + reflexivity. + Qed. + + Lemma map2_nil_l : forall A B C (f : A -> B -> C) ls2, + map2 f nil ls2 = nil. + Proof. + reflexivity. + Qed. + + Lemma map2_nil_r : forall A B C (f : A -> B -> C) ls1, + map2 f ls1 nil = nil. + Proof. + destruct ls1; reflexivity. + Qed. + Local Hint Resolve map2_nil_r map2_nil_l. + + Opaque map2. + + Lemma map2_length : forall A B C (f : A -> B -> C) ls1 ls2, + length (map2 f ls1 ls2) = min (length ls1) (length ls2). + Proof. + induction ls1, ls2; intros; try solve [cbv; auto]. + rewrite map2_cons, !length_cons, IHls1. + auto. + Qed. + + Lemma modulus_digits'_length : forall i, length (modulus_digits' i) = S i. + Proof. + induction i; intros; [ cbv; congruence | ]. + unfold modulus_digits'; fold modulus_digits'. + rewrite app_length, IHi. + cbv [length]; omega. + Qed. + + Lemma modulus_digits_length : length modulus_digits = length base. + Proof. + unfold modulus_digits. + rewrite modulus_digits'_length; omega. + Qed. + + (* Helps with solving goals of the form [x = y -> min x y = x] or [x = y -> min x y = y] *) + Local Hint Resolve Nat.eq_le_incl eq_le_incl_rev. + + Hint Rewrite app_length cons_length map2_length modulus_digits_length length_zeros + map_length combine_length firstn_length map_app : lengths. + Ltac simpl_lengths := autorewrite with lengths; + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); + auto using Min.min_l; auto using Min.min_r. + + Lemma freeze_length : forall us, (length us = length base) -> + length (freeze us) = length us. + Proof. + unfold freeze; intros; simpl_lengths. + Qed. + + Lemma decode_firstn_succ : forall n us, (length us = length base) -> + (n < length base)%nat -> + BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) = + BaseSystem.decode' (firstn n base) (firstn n us) + + nth_default 0 base n * nth_default 0 us n. + Proof. + intros. + rewrite !firstn_succ with (d := 0) by omega. + rewrite base_app, firstn_app. + autorewrite with lengths; rewrite !Min.min_l by omega. + rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega. + rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + rewrite decode'_cons, decode_nil, Z.add_0_r. + reflexivity. + Qed. + + Local Hint Resolve sum_firstn_limb_widths_nonneg. + Local Hint Resolve limb_widths_nonneg. + Local Hint Resolve nth_error_value_In. + + (* TODO : move *) + Lemma sum_firstn_all_succ : forall n l, (length l <= n)%nat -> + sum_firstn l (S n) = sum_firstn l n. + Proof. + unfold sum_firstn; intros. + rewrite !firstn_all_strong by omega. + congruence. + Qed. + + Lemma decode_carry_done_upper_bound' : forall n us, carry_done us -> + (length us = length base) -> + BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n). + Proof. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ; auto. + rewrite base_length in n_lt_length. + destruct (nth_error_length_exists_value _ _ n_lt_length). + erewrite sum_firstn_succ; eauto. + rewrite Z.pow_add_r; eauto. + rewrite nth_default_base by (rewrite base_length; assumption). + rewrite Z.lt_add_lt_sub_r. + eapply Z.lt_le_trans; eauto. + rewrite Z.mul_comm at 1. + rewrite <-Z.mul_sub_distr_l. + rewrite <-Z.mul_1_r at 1. + apply Z.mul_le_mono_nonneg_l; [ apply Z.pow_nonneg; omega | ]. + replace 1 with (Z.succ 0) by reflexivity. + rewrite Z.le_succ_l, Z.lt_0_sub. + match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end. + replace x with (log_cap n); try intuition. + rewrite log_cap_eq. + apply nth_error_value_eq_nth_default; auto. + + repeat erewrite firstn_all_strong by omega. + rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). + eapply Z.le_lt_trans; [ | eauto]. + repeat erewrite firstn_all_strong by omega. + omega. + Qed. + + Lemma decode_carry_done_upper_bound : forall us, carry_done us -> + (length us = length base) -> BaseSystem.decode base us < 2 ^ k. + Proof. + unfold k; intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_upper_bound'. + Qed. + + Lemma decode_carry_done_lower_bound' : forall n us, carry_done us -> + (length us = length base) -> + 0 <= BaseSystem.decode (firstn n base) (firstn n us). + Proof. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ by auto. + zero_bounds. + - rewrite nth_default_base by assumption. + apply Z.pow_nonneg; omega. + - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end. + intuition. + + eapply Z.le_trans; [ apply IHn; eauto | ]. + repeat rewrite firstn_all_strong by omega. + omega. + Qed. + + Lemma decode_carry_done_lower_bound : forall us, carry_done us -> + (length us = length base) -> 0 <= BaseSystem.decode base us. + Proof. + intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_lower_bound'. + Qed. + + + Lemma nth_default_modulus_digits' : forall d j i, + nth_default d (modulus_digits' j) i = + if lt_dec i (S j) + then (if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i) + else d. + Proof. + induction j; intros; (break_if; [| apply nth_default_out_of_bounds; rewrite modulus_digits'_length; omega]). + + replace i with 0%nat by omega. + apply nth_default_cons. + + simpl. rewrite nth_default_app. + rewrite modulus_digits'_length. + break_if. + - rewrite IHj; break_if; try omega; reflexivity. + - replace i with (S j) by omega. + rewrite Nat.sub_diag, nth_default_cons. + reflexivity. + Qed. + + Lemma nth_default_modulus_digits : forall d i, + nth_default d modulus_digits i = + if lt_dec i (length base) + then (if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i) + else d. + Proof. + unfold modulus_digits; intros. + rewrite nth_default_modulus_digits'. + replace (S (length base - 1)) with (length base) by omega. + reflexivity. + Qed. + + Lemma carry_done_modulus_digits : carry_done modulus_digits. + Proof. + apply carry_done_bounds; [apply modulus_digits_length | ]. + intros. + rewrite nth_default_modulus_digits. + break_if; [ | split; auto; omega]. + break_if; subst; split; auto; try rewrite <- max_bound_log_cap; omega. + Qed. + Local Hint Resolve carry_done_modulus_digits. + + (* TODO : move *) + Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) -> + decode us = x -> + BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus -> + decode vs = x. + Proof. + unfold decode; intros until 2; intros decode_us_x BSdecode_eq. + rewrite ZToField_mod in decode_us_x |- *. + rewrite <-BSdecode_eq. + assumption. + Qed. + + Ltac simpl_list_lengths := repeat match goal with + | H : appcontext[length (@nil ?A)] |- _ => rewrite (@nil_length0 A) in H + | H : appcontext[length (_ :: _)] |- _ => rewrite length_cons in H + | |- appcontext[length (@nil ?A)] => rewrite (@nil_length0 A) + | |- appcontext[length (_ :: _)] => rewrite length_cons + end. + + Lemma map2_app : forall A B C (f : A -> B -> C) ls1 ls2 ls1' ls2', + (length ls1 = length ls2) -> + map2 f (ls1 ++ ls1') (ls2 ++ ls2') = map2 f ls1 ls2 ++ map2 f ls1' ls2'. + Proof. + induction ls1, ls2; intros; rewrite ?map2_nil_r, ?app_nil_l; try congruence; + simpl_list_lengths; try omega. + rewrite <-!app_comm_cons, !map2_cons. + rewrite IHls1; auto. + Qed. + + Lemma decode_map2_sub : forall us vs, + (length us = length vs) -> + BaseSystem.decode' base (map2 (fun x y => x - y) us vs) + = BaseSystem.decode' base us - BaseSystem.decode' base vs. + Proof. + induction us using rev_ind; induction vs using rev_ind; + intros; autorewrite with lengths in *; simpl_list_lengths; + rewrite ?decode_nil; try omega. + rewrite map2_app by omega. + rewrite map2_cons, map2_nil_l. + rewrite !set_higher. + autorewrite with lengths. + rewrite Min.min_l by omega. + rewrite IHus by omega. + replace (length vs) with (length us) by omega. + ring. + Qed. + + Lemma decode_modulus_digits' : forall i, (i <= length base)%nat -> + BaseSystem.decode' base (modulus_digits' i) = 2 ^ (sum_firstn limb_widths (S i)) - c. + Proof. + induction i; intros; unfold modulus_digits'; fold modulus_digits'. + + case_eq base; + [ intro base_eq; rewrite base_eq, (@nil_length0 Z) in lt_1_length_base; omega | ]. + intros z ? base_eq. + rewrite decode'_cons, decode_nil, Z.add_0_r. + replace z with (nth_default 0 base 0) by (rewrite base_eq; auto). + rewrite nth_default_base by omega. + replace (max_bound 0 - c + 1) with (Z.succ (max_bound 0) - c) by ring. + rewrite max_bound_log_cap. + rewrite sum_firstn_succ with (x := log_cap 0) by (rewrite log_cap_eq; + apply nth_error_Some_nth_default; rewrite <-base_length; omega). + rewrite Z.pow_add_r by auto. + cbv [sum_firstn fold_right firstn]. + ring. + + assert (S i < length base \/ S i = length base)%nat as cases by omega. + destruct cases. + - rewrite sum_firstn_succ with (x := log_cap (S i)) by + (rewrite log_cap_eq; apply nth_error_Some_nth_default; + rewrite <-base_length; omega). + rewrite Z.pow_add_r, <-max_bound_log_cap, set_higher by auto. + rewrite IHi, modulus_digits'_length, nth_default_base by omega. + ring. + - rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). + rewrite decode'_splice, modulus_digits'_length, firstn_all by auto. + rewrite skipn_all, decode_base_nil, Z.add_0_r by omega. + apply IHi. + omega. + Qed. + + Lemma decode_modulus_digits : BaseSystem.decode' base modulus_digits = modulus. + Proof. + unfold modulus_digits; rewrite decode_modulus_digits' by omega. + replace (S (length base - 1)) with (length base) by omega. + rewrite base_length. + fold k. unfold c. + ring. + Qed. + + Lemma map_land_max_ones_modulus_digits' : forall i, + map (Z.land max_ones) (modulus_digits' i) = (modulus_digits' i). + Proof. + induction i; intros. + + cbv [modulus_digits' map]. + f_equal. + apply land_max_ones_noop with (i := 0%nat). + rewrite <-max_bound_log_cap. + omega. + + unfold modulus_digits'; fold modulus_digits'. + rewrite map_app. + f_equal; [ apply IHi; omega | ]. + cbv [map]; f_equal. + apply land_max_ones_noop with (i := S i). + rewrite <-max_bound_log_cap. + split; auto; omega. + Qed. + + Lemma map_land_max_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits. + Proof. + apply map_land_max_ones_modulus_digits'. + Qed. + + Opaque modulus_digits. + + Lemma map_land_zero : forall ls, map (Z.land 0) ls = BaseSystem.zeros (length ls). + Proof. + induction ls; boring. + Qed. + + Lemma carry_full_preserves_Fdecode : forall us x, (length us = length base) -> + decode us = x -> decode (carry_full us) = x. + Proof. + intros. + apply carry_full_preserves_rep; auto. + unfold rep; auto. + Qed. + + Lemma freeze_preserves_rep : forall us x, rep us x -> rep (freeze us) x. + Proof. + unfold rep; intros. + intuition; rewrite ?freeze_length; auto. + unfold freeze, and_term. + break_if. + + apply decode_mod with (us := carry_full (carry_full (carry_full us))). + - rewrite carry_full_3_length; auto. + - autorewrite with lengths. + apply Min.min_r. + simpl_lengths; omega. + - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto. + unfold rep; intuition. + - rewrite decode_map2_sub by (simpl_lengths; omega). + rewrite map_land_max_ones_modulus_digits. + rewrite decode_modulus_digits. + destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ]. + rewrite <-Z.add_opp_r. + replace (-modulus) with (-1 * modulus) by ring. + symmetry; auto using Z.mod_add. + + eapply decode_mod; eauto. + simpl_lengths. + rewrite map_land_zero, decode_map2_sub, zeros_rep, Z.sub_0_r by simpl_lengths. + match goal with H : decode ?us = ?x |- _ => erewrite Fdecode_decode_mod; eauto; + do 3 apply carry_full_preserves_Fdecode in H; simpl_lengths + end. + erewrite Fdecode_decode_mod; eauto; simpl_lengths. + Qed. + Hint Resolve freeze_preserves_rep. + + Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= length base - 1)%nat -> nth_default 0 us i = max_bound i)). + Proof. + unfold isFull; intros; auto using isFull'_true_iff. + Qed. + + Definition minimal_rep us := BaseSystem.decode base us = (BaseSystem.decode base us) mod modulus. + + Fixpoint compare' us vs i := + match i with + | O => Eq + | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i') + then compare' us vs i' + else Z.compare (nth_default 0 us i') (nth_default 0 vs i') + end. + + (* Lexicographically compare two vectors of equal length, starting from the END of the list + (in our context, this is the most significant end). NOT constant time. *) + Definition compare us vs := compare' us vs (length us). + + Lemma compare'_Eq : forall us vs i, (length us = length vs) -> + compare' us vs i = Eq -> firstn i us = firstn i vs. + Proof. + induction i; intros; [ cbv; congruence | ]. + destruct (lt_dec i (length us)). + + repeat rewrite firstn_succ with (d := 0) by omega. + match goal with H : compare' _ _ (S _) = Eq |- _ => + inversion H end. + break_if; f_equal; auto. + - f_equal; auto. + - rewrite Z.compare_eq_iff in *. congruence. + - rewrite Z.compare_eq_iff in *. congruence. + + rewrite !firstn_all_strong in IHi by omega. + match goal with H : compare' _ _ (S _) = Eq |- _ => + inversion H end. + rewrite (nth_default_out_of_bounds i us) in * by omega. + rewrite (nth_default_out_of_bounds i vs) in * by omega. + break_if; try congruence. + f_equal; auto. + Qed. + + Lemma compare_Eq : forall us vs, (length us = length vs) -> + compare us vs = Eq -> us = vs. + Proof. + intros. + erewrite <-(firstn_all _ us); eauto. + erewrite <-(firstn_all _ vs); eauto. + apply compare'_Eq; auto. + Qed. + + Lemma decode_lt_next_digit : forall us n, (length us = length base) -> + (n < length base)%nat -> (n < length us)%nat -> + carry_done us -> + BaseSystem.decode' (firstn n base) (firstn n us) < + (nth_default 0 base n). + Proof. + induction n; intros ? ? ? bounded. + + cbv [firstn]. + rewrite decode_base_nil. + apply Z.gt_lt; auto using nth_default_base_positive. + + rewrite decode_firstn_succ by (auto || omega). + rewrite nth_default_base_succ by omega. + eapply Z.lt_le_trans. + - apply Z.add_lt_mono_r. + apply IHn; auto; omega. + - rewrite <-(Z.mul_1_r (nth_default 0 base n)) at 1. + rewrite <-Z.mul_add_distr_l, Z.mul_comm. + apply Z.mul_le_mono_pos_r. + * apply Z.gt_lt. apply nth_default_base_positive; omega. + * rewrite Z.add_1_l. + apply Z.le_succ_l. + rewrite carry_done_bounds in bounded by assumption. + apply bounded. + Qed. + + Lemma highest_digit_determines : forall us vs n x, (x < 0) -> + (length us = length base) -> + (length vs = length base) -> + (n < length us)%nat -> carry_done us -> + (n < length vs)%nat -> carry_done vs -> + BaseSystem.decode (firstn n base) (firstn n us) + + nth_default 0 base n * x - + BaseSystem.decode (firstn n base) (firstn n vs) < 0. + Proof. + intros. + eapply Z.le_lt_trans. + + apply Z.le_sub_nonneg. + apply decode_carry_done_lower_bound'; auto. + + eapply Z.le_lt_trans. + - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. + apply Z.mul_le_mono_nonneg_l; try omega. + rewrite nth_default_base by omega; apply Z.pow_nonneg; omega. + - ring_simplify. + apply Z.lt_sub_0. + apply decode_lt_next_digit; auto. + omega. + Qed. + + Lemma Z_compare_decode_step_eq : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + (nth_default 0 us n = nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (BaseSystem.decode (firstn n base) us ?= + BaseSystem.decode (firstn n base) vs). + Proof. + intros until 3; intro nth_default_eq. + destruct (lt_dec n (length us)); try omega. + rewrite firstn_succ with (d := 0), !base_app by omega. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub, nth_default_eq, Z.add_add_simpl_r_r. + rewrite BaseSystem.decode'_truncate with (us := us). + rewrite BaseSystem.decode'_truncate with (us := vs). + rewrite firstn_length, Min.min_l, <-Z.compare_sub by omega. + reflexivity. + Qed. + + Lemma Z_compare_decode_step_lt : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n < nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = Lt. + Proof. + intros until 5; intro nth_default_lt. + destruct (lt_dec n (length us)). + + rewrite firstn_succ with (d := 0) by omega. + rewrite !base_app. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub. + apply Z.compare_lt_iff. + ring_simplify. + rewrite <-Z.add_sub_assoc. + rewrite <-Z.mul_sub_distr_l. + apply highest_digit_determines; auto; omega. + + rewrite !nth_default_out_of_bounds in nth_default_lt; omega. + Qed. + + Lemma Z_compare_decode_step_neq : forall n us vs, + (length us = length base) -> (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n <> nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (nth_default 0 us n ?= nth_default 0 vs n). + Proof. + intros. + destruct (Z_dec (nth_default 0 us n) (nth_default 0 vs n)) as [[?|Hgt]|?]; try congruence. + + etransitivity; try apply Z_compare_decode_step_lt; auto. + + match goal with |- (?a ?= ?b) = (?c ?= ?d) => + rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end. + apply CompOpp_inj; rewrite !CompOpp_involutive. + apply gt_lt_symmetry in Hgt. + etransitivity; try apply Z_compare_decode_step_lt; auto; omega. + Qed. + + Lemma decode_compare' : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (n <= length base)%nat -> + carry_done us -> carry_done vs -> + (BaseSystem.decode (firstn n base) us ?= BaseSystem.decode (firstn n base) vs) + = compare' us vs n. + Proof. + induction n; intros. + + cbv [firstn compare']; rewrite !decode_base_nil; auto. + + unfold compare'; fold compare'. + break_if. + - rewrite Z_compare_decode_step_eq by (auto || omega). + apply IHn; auto; omega. + - rewrite Z_compare_decode_step_neq; (auto || omega). + Qed. + + Lemma decode_compare : forall us vs, + (length us = length base) -> carry_done us -> + (length vs = length base) -> carry_done vs -> + Z.compare (BaseSystem.decode base us) (BaseSystem.decode base vs) = compare us vs. + Proof. + unfold compare; intros. + erewrite <-(firstn_all _ base). + + apply decode_compare'; auto; omega. + + assumption. + Qed. + + Lemma compare'_succ : forall us j vs, compare' us vs (S j) = + if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j) + then compare' us vs j + else nth_default 0 us j ?= nth_default 0 vs j. + Proof. + reflexivity. + Qed. + + Lemma compare'_firstn_r_small_index : forall us j vs, (j <= length vs)%nat -> + compare' us vs j = compare' us (firstn j vs) j. + Proof. + induction j; intros; auto. + rewrite !compare'_succ by omega. + rewrite firstn_succ with (d := 0) by omega. + rewrite nth_default_app. + simpl_lengths. + rewrite Min.min_l by omega. + destruct (lt_dec j j); try omega. + rewrite Nat.sub_diag. + rewrite nth_default_cons. + break_if; try reflexivity. + rewrite IHj with (vs := firstn j vs ++ nth_default 0 vs j :: nil) by + (autorewrite with lengths; rewrite Min.min_l; omega). + rewrite firstn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + apply IHj; omega. + Qed. + + Lemma compare'_firstn_r : forall us j vs, + compare' us vs j = compare' us (firstn j vs) j. + Proof. + intros. + destruct (le_dec j (length vs)). + + auto using compare'_firstn_r_small_index. + + f_equal. symmetry. + apply firstn_all_strong. + omega. + Qed. + + Lemma compare'_not_Lt : forall us vs j, j <> 0%nat -> + (forall i, (0 < i < j)%nat -> 0 <= nth_default 0 us i <= nth_default 0 vs i) -> + compare' us vs j <> Lt -> + nth_default 0 vs 0 <= nth_default 0 us 0 /\ + (forall i : nat, (0 < i < j)%nat -> nth_default 0 us i = nth_default 0 vs i). + Proof. + induction j; try congruence. + rewrite compare'_succ. + intros; destruct (eq_nat_dec j 0). + + break_if; subst; split; intros; try omega. + rewrite Z.compare_ge_iff in *; omega. + + break_if. + - split; intros; [ | destruct (eq_nat_dec i j); subst; auto ]; + apply IHj; auto; intros; try omega; + match goal with H : forall i, _ -> 0 <= ?f i <= ?g i |- 0 <= ?f _ <= ?g _ => + apply H; omega end. + - exfalso. rewrite Z.compare_ge_iff in *. + match goal with H : forall i, ?P -> 0 <= ?f i <= ?g i |- _ => + specialize (H j) end; omega. + Qed. + + Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) -> + (j <= length base)%nat -> carry_done us -> + (isFull' us true (j - 1) = true <-> compare' us modulus_digits j <> Lt). + Proof. + unfold compare; induction j; intros; try congruence. + replace (S j - 1)%nat with j by omega. + split; intros. + + simpl. + break_if; [destruct (eq_nat_dec j 0) | ]. + - subst. cbv; congruence. + - apply IHj; auto; try omega. + apply isFull'_true_step. + replace (S (j - 1)) with j by omega; auto. + - rewrite nth_default_modulus_digits in *. + repeat (break_if; try omega). + * subst. + match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_lower_bound_0 in H end. + apply Z.compare_ge_iff. + omega. + * match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_true_iff in H; try assumption; destruct H as [? eq_max_bound] end. + specialize (eq_max_bound j). + omega. + + apply isFull'_true_iff; try assumption. + match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end. + - split; [ | intros i i_range; assert (0 < i < S j)%nat as i_range' by omega; + specialize (Hnonzero i i_range')]; + rewrite nth_default_modulus_digits in *; + repeat (break_if; try omega). + - congruence. + - intros. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + rewrite <-Z.lt_succ_r with (m := max_bound i). + rewrite max_bound_log_cap; apply carry_done_bounds; assumption. + Qed. + + Lemma isFull_compare : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> compare us modulus_digits <> Lt). + Proof. + unfold compare, isFull; intros ? lengths_eq. intros. + rewrite lengths_eq. + apply isFull'_compare'; try omega. + assumption. + Qed. + + Lemma isFull_decode : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> + (BaseSystem.decode base us ?= BaseSystem.decode base modulus_digits <> Lt)). + Proof. + intros. + rewrite decode_compare; autorewrite with lengths; auto. + apply isFull_compare; auto. + Qed. + + Lemma isFull_false_upper_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = false -> + BaseSystem.decode base us < modulus. + Proof. + intros. + destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus]; + [assumption | exfalso]. + apply Z.compare_nlt_iff in nlt_modulus. + rewrite <-decode_modulus_digits in nlt_modulus at 2. + apply isFull_decode in nlt_modulus; try assumption; congruence. + Qed. + + Lemma isFull_true_lower_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = true -> + modulus <= BaseSystem.decode base us. + Proof. + intros. + rewrite <-decode_modulus_digits at 1. + apply Z.compare_ge_iff. + apply isFull_decode; auto. + Qed. + + Lemma freeze_in_bounds : forall us, + pre_carry_bounds us -> (length us = length base) -> + carry_done (freeze us). + Proof. + unfold freeze, and_term; intros ? PCB lengths_eq. + rewrite carry_done_bounds by simpl_lengths; intro i. + rewrite nth_default_map2 with (d1 := 0) (d2 := 0). + simpl_lengths. + break_if; [ | split; (omega || auto)]. + break_if. + + rewrite map_land_max_ones_modulus_digits. + apply isFull_true_iff in Heqb; [ | simpl_lengths]. + destruct Heqb as [first_digit high_digits]. + destruct (eq_nat_dec i 0). + - subst. + clear high_digits. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done. + rewrite carry_done_bounds in cf3_done by simpl_lengths. + specialize (cf3_done 0%nat). + omega. + - assert ((0 < i <= length base - 1)%nat) as i_range by + (simpl_lengths; apply lt_min_l in l; omega). + specialize (high_digits i i_range). + clear first_digit i_range. + rewrite high_digits. + rewrite <-max_bound_log_cap. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + * rewrite Z.sub_diag. + split; try omega. + apply Z.lt_succ_r; auto. + * rewrite Z.lt_succ_r, Z.sub_0_r. split; (omega || auto). + + rewrite map_land_zero, nth_default_zeros. + rewrite Z.sub_0_r. + apply carry_done_bounds; [ simpl_lengths | ]. + auto using carry_full_3_done. + Qed. + Local Hint Resolve freeze_in_bounds. + + Local Hint Resolve carry_full_3_done. + + Lemma freeze_minimal_rep : forall us, pre_carry_bounds us -> (length us = length base) -> + minimal_rep (freeze us). + Proof. + unfold minimal_rep, freeze, and_term. + intros. + symmetry. apply Z.mod_small. + split; break_if; rewrite decode_map2_sub; simpl_lengths. + + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits. + apply Z.le_0_sub. + apply isFull_true_lower_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply decode_carry_done_lower_bound; simpl_lengths. + + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits. + rewrite Z.lt_sub_lt_add_r. + apply Z.lt_le_trans with (m := 2 * modulus); try omega. + eapply Z.lt_le_trans; [ | apply two_pow_k_le_2modulus ]. + apply decode_carry_done_upper_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply isFull_false_upper_bound; simpl_lengths. + Qed. + Local Hint Resolve freeze_minimal_rep. + + Lemma rep_decode_mod : forall us vs x, rep us x -> rep vs x -> + (BaseSystem.decode base us) mod modulus = (BaseSystem.decode base vs) mod modulus. + Proof. + unfold rep, decode; intros. + intuition. + repeat rewrite <-FieldToZ_ZToField. + congruence. + Qed. + + Lemma minimal_rep_unique : forall us vs x, + rep us x -> minimal_rep us -> carry_done us -> + rep vs x -> minimal_rep vs -> carry_done vs -> + us = vs. + Proof. + intros. + match goal with Hrep1 : rep _ ?x, Hrep2 : rep _ ?x |- _ => + pose proof (rep_decode_mod _ _ _ Hrep1 Hrep2) as eqmod end. + repeat match goal with Hmin : minimal_rep ?us |- _ => unfold minimal_rep in Hmin; + rewrite <- Hmin in eqmod; clear Hmin end. + apply Z.compare_eq_iff in eqmod. + rewrite decode_compare in eqmod; unfold rep in *; auto; intuition; try congruence. + apply compare_Eq; auto. + congruence. + Qed. + + Lemma freeze_canonical : forall us vs x, + pre_carry_bounds us -> rep us x -> + pre_carry_bounds vs -> rep vs x -> freeze us = freeze vs. - Admitted. + Proof. + intros. + assert (length us = length base) by (unfold rep in *; intuition). + assert (length vs = length base) by (unfold rep in *; intuition). + eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption. + Qed. End CanonicalizationProofs.
\ No newline at end of file |