diff options
Diffstat (limited to 'src/ModularArithmetic/ExtendedBaseVector.v')
-rw-r--r-- | src/ModularArithmetic/ExtendedBaseVector.v | 205 |
1 files changed, 100 insertions, 105 deletions
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v index 0afd6b484..fcd871aae 100644 --- a/src/ModularArithmetic/ExtendedBaseVector.v +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -1,18 +1,18 @@ -Require Import Zpower ZArith. -Require Import List. +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. +Require Import Coq.Lists.List. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import VerdiTactics. +Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.BaseSystemProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. Section ExtendedBaseVector. - Context `{prm : PseudoMersenneBaseParams}. + Context (limb_widths : list Z) + (limb_widths_nonnegative : forall x, In x limb_widths -> 0 <= x). + Local Notation k := (sum_firstn limb_widths (length limb_widths)). Local Notation base := (base_from_limb_widths limb_widths). (* This section defines a new BaseVector that has double the length of the BaseVector @@ -43,50 +43,21 @@ Section ExtendedBaseVector. Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base). Proof. unfold ext_base, ext_limb_widths. - rewrite base_from_limb_widths_app by auto using limb_widths_pos, Z.lt_le_incl. + rewrite base_from_limb_widths_app by auto. rewrite two_p_equiv. reflexivity. Qed. Lemma ext_base_positive : forall b, In b ext_base -> b > 0. Proof. - rewrite ext_base_alt. intros b In_b_base. - rewrite in_app_iff in In_b_base. - destruct In_b_base as [In_b_base | In_b_extbase]. - + eapply BaseSystem.base_positive. - eapply In_b_base. - + eapply in_map_iff in In_b_extbase. - destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. - subst. - specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. - replace 0 with (2 ^ k * 0) by ring. - apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. - rewrite Z.gt_lt_iff. - apply Z.pow_pos_nonneg; intuition. - pose proof k_nonneg; omega. + apply base_positive; unfold ext_limb_widths. + intros ? H. apply in_app_or in H; destruct H; auto. Qed. - Lemma base_length_nonzero : (0 < length base)%nat. + Lemma b0_1 : forall x, nth_default x base 0 = 1 -> nth_default x ext_base 0 = 1. Proof. - assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). - unfold nth_default in H. - case_eq (nth_error base 0); intros; - try (rewrite H0 in H; omega). - apply (nth_error_value_length _ 0 base z); auto. - Qed. - - Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. - Proof. - intros. rewrite ext_base_alt. - rewrite nth_default_app. - assert (0 < length base)%nat by (apply base_length_nonzero). - destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. - Qed. - - Lemma two_k_nonzero : 2^k <> 0. - Proof. - pose proof (Z.pow_eq_0 2 k k_nonneg). - intuition. + intros. rewrite ext_base_alt, nth_default_app. + destruct base; assumption. Qed. Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> @@ -97,76 +68,85 @@ Section ExtendedBaseVector. erewrite map_nth_default; auto. Qed. - Lemma base_good_over_boundary : forall - (i : nat) - (l : (i < length base)%nat) - (j' : nat) - (Hj': (i + j' < length base)%nat) - , - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / - (2 ^ k * nth_default 0 base (i + j')) * - (2 ^ k * nth_default 0 base (i + j')) - . - Proof. - intros. - remember (nth_default 0 base) as b. - rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). - replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) - with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by (exact two_k_nonzero). - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (BaseSystem.base_good i j'); omega. - Qed. + Section base_good. + Context (two_k_nonzero : 2^k <> 0) + (base_good : forall i j, (i+j < length base)%nat -> + let b := nth_default 0 base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat) + (limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). - Lemma ext_base_good : - forall i j, (i+j < length ext_base)%nat -> - let b := nth_default 0 ext_base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - intros. - subst b. subst r. - rewrite ext_base_alt in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - repeat break_if; try omega. - { (* i < length base, j < length base, i + j < length base *) - auto using BaseSystem.base_good. - } { (* i < length base, j < length base, i + j >= length base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - apply (base_matches_modulus i j); rewrite <-base_length by auto using limb_widths_nonneg; omega. - } { (* i < length base, j >= length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length base)%nat as j'. - replace (i + j - length base)%nat with (i + j')%nat by omega. - replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) - with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length base, j < length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length base)%nat as i'. - replace (i + j - length base)%nat with (j + i')%nat by omega. - replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) - with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. + Lemma base_good_over_boundary + : forall (i : nat) + (l : (i < length base)%nat) + (j' : nat) + (Hj': (i + j' < length base)%nat), + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = + (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + / (2 ^ k * nth_default 0 base (i + j')) * + (2 ^ k * nth_default 0 base (i + j')). + Proof. + clear limb_widths_match_modulus. + intros. + remember (nth_default 0 base) as b. + rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). + replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) + with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. + rewrite Z.mul_cancel_l by (exact two_k_nonzero). + replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) + with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. + subst b. + apply (base_good i j'); omega. + Qed. - Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { - base_positive := ext_base_positive; - b0_1 := b0_1; - base_good := ext_base_good - }. + Lemma ext_base_good : + forall i j, (i+j < length ext_base)%nat -> + let b := nth_default 0 ext_base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + intros. + subst b. subst r. + rewrite ext_base_alt in *. + rewrite app_length in H; rewrite map_length in H. + repeat rewrite nth_default_app. + repeat break_if; try omega. + { (* i < length base, j < length base, i + j < length base *) + auto using BaseSystem.base_good. + } { (* i < length base, j < length base, i + j >= length base *) + rewrite (map_nth_default _ _ _ _ 0) by omega. + apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus; + distr_length. + } { (* i < length base, j >= length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (j - length base)%nat as j'. + replace (i + j - length base)%nat with (i + j')%nat by omega. + replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) + with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } { (* i >= length base, j < length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (i - length base)%nat as i'. + replace (i + j - length base)%nat with (j + i')%nat by omega. + replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) + with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } + Qed. + End base_good. Lemma extended_base_length: length ext_base = (length base + length base)%nat. Proof. - rewrite ext_base_alt, app_length, map_length; auto. + clear limb_widths_nonnegative. + unfold ext_base, ext_limb_widths; autorewrite with distr_length; reflexivity. Qed. Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits), @@ -181,6 +161,21 @@ Section ExtendedBaseVector. (length us <= length base)%nat -> BaseSystem.decode base us = BaseSystem.decode ext_base us. Proof. auto using decode_short_initial, firstn_us_base_ext_base. Qed. + + Section BaseVector. + Context {bv : BaseSystem.BaseVector base} + (limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). + + Instance ExtBaseVector : BaseSystem.BaseVector ext_base := + { base_positive := ext_base_positive; + b0_1 x := b0_1 x (BaseSystem.b0_1 _); + base_good := ext_base_good (two_sum_firstn_limb_widths_nonzero limb_widths_nonnegative _) BaseSystem.base_good limb_widths_match_modulus }. + End BaseVector. End ExtendedBaseVector. Hint Rewrite @extended_base_length : distr_length. |