diff options
Diffstat (limited to 'src/ModularArithmetic/ExtendedBaseVector.v')
-rw-r--r-- | src/ModularArithmetic/ExtendedBaseVector.v | 36 |
1 files changed, 30 insertions, 6 deletions
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v index d4df6040f..0afd6b484 100644 --- a/src/ModularArithmetic/ExtendedBaseVector.v +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -7,12 +7,13 @@ Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.BaseSystemProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. Section ExtendedBaseVector. Context `{prm : PseudoMersenneBaseParams}. - Local Notation base := (Pow2Base.base_from_limb_widths limb_widths). + Local Notation base := (base_from_limb_widths limb_widths). (* This section defines a new BaseVector that has double the length of the BaseVector * used to construct [params]. The coefficients of the new vector are as follows: @@ -37,11 +38,19 @@ Section ExtendedBaseVector. * * This sum may be short enough to express using base; if not, we can reduce again. *) - Definition ext_base := base ++ (map (Z.mul (2^k)) base). + Definition ext_limb_widths := limb_widths ++ limb_widths. + Definition ext_base := base_from_limb_widths ext_limb_widths. + Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base). + Proof. + unfold ext_base, ext_limb_widths. + rewrite base_from_limb_widths_app by auto using limb_widths_pos, Z.lt_le_incl. + rewrite two_p_equiv. + reflexivity. + Qed. Lemma ext_base_positive : forall b, In b ext_base -> b > 0. Proof. - unfold ext_base. intros b In_b_base. + rewrite ext_base_alt. intros b In_b_base. rewrite in_app_iff in In_b_base. destruct In_b_base as [In_b_base | In_b_extbase]. + eapply BaseSystem.base_positive. @@ -68,7 +77,7 @@ Section ExtendedBaseVector. Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. Proof. - intros. unfold ext_base. + intros. rewrite ext_base_alt. rewrite nth_default_app. assert (0 < length base)%nat by (apply base_length_nonzero). destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. @@ -120,7 +129,7 @@ Section ExtendedBaseVector. Proof. intros. subst b. subst r. - unfold ext_base in *. + rewrite ext_base_alt in *. rewrite app_length in H; rewrite map_length in H. repeat rewrite nth_default_app. repeat break_if; try omega. @@ -157,6 +166,21 @@ Section ExtendedBaseVector. Lemma extended_base_length: length ext_base = (length base + length base)%nat. Proof. - unfold ext_base; rewrite app_length; rewrite map_length; auto. + rewrite ext_base_alt, app_length, map_length; auto. Qed. + + Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits), + (length us <= length base)%nat + -> firstn (length us) base = firstn (length us) ext_base. + Proof. + rewrite ext_base_alt; intros. + rewrite firstn_app_inleft; auto; omega. + Qed. + + Lemma decode_short : forall (us : BaseSystem.digits), + (length us <= length base)%nat -> + BaseSystem.decode base us = BaseSystem.decode ext_base us. + Proof. auto using decode_short_initial, firstn_us_base_ext_base. Qed. End ExtendedBaseVector. + +Hint Rewrite @extended_base_length : distr_length. |