From 79e6c4ea6cd0ed52fda2168cda78c52e4bc4896a Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 20 Jul 2016 10:13:48 -0700 Subject: Remove dependency of ext_base on pseudomersenne --- src/ModularArithmetic/ExtendedBaseVector.v | 193 ++++++++++++------------ src/ModularArithmetic/ModularBaseSystem.v | 2 +- src/ModularArithmetic/ModularBaseSystemOpt.v | 4 +- src/ModularArithmetic/ModularBaseSystemProofs.v | 47 +++--- 4 files changed, 120 insertions(+), 126 deletions(-) diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v index 580156bca..ed21d10f9 100644 --- a/src/ModularArithmetic/ExtendedBaseVector.v +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -5,14 +5,14 @@ Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import VerdiTactics. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.BaseSystemProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. Section ExtendedBaseVector. - Context `{prm : PseudoMersenneBaseParams}. + Context (limb_widths : list Z) + (limb_widths_nonnegative : forall x, In x limb_widths -> 0 <= x). + Local Notation k := (sum_firstn limb_widths (length limb_widths)). Local Notation base := (base_from_limb_widths limb_widths). (* This section defines a new BaseVector that has double the length of the BaseVector @@ -43,44 +43,21 @@ Section ExtendedBaseVector. Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base). Proof. unfold ext_base, ext_limb_widths. - rewrite base_from_limb_widths_app by auto using limb_widths_pos, Z.lt_le_incl. + rewrite base_from_limb_widths_app by auto. rewrite two_p_equiv. reflexivity. Qed. Lemma ext_base_positive : forall b, In b ext_base -> b > 0. Proof. - rewrite ext_base_alt. intros b In_b_base. - rewrite in_app_iff in In_b_base. - destruct In_b_base as [In_b_base | In_b_extbase]. - + eapply BaseSystem.base_positive. - eapply In_b_base. - + eapply in_map_iff in In_b_extbase. - destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. - subst. - specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. - replace 0 with (2 ^ k * 0) by ring. - apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. - rewrite Z.gt_lt_iff. - apply Z.pow_pos_nonneg; intuition. - apply sum_firstn_limb_widths_nonneg; auto using limb_widths_nonneg. + apply base_positive; unfold ext_limb_widths. + intros ? H. apply in_app_or in H; destruct H; auto. Qed. - Lemma base_length_nonzero : (0 < length base)%nat. + Lemma b0_1 : forall x, nth_default x base 0 = 1 -> nth_default x ext_base 0 = 1. Proof. - assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). - unfold nth_default in H. - case_eq (nth_error base 0); intros; - try (rewrite H0 in H; omega). - apply (nth_error_value_length _ 0 base z); auto. - Qed. - - Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. - Proof. - intros. rewrite ext_base_alt. - rewrite nth_default_app. - assert (0 < length base)%nat by (apply base_length_nonzero). - destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. + intros. rewrite ext_base_alt, nth_default_app. + destruct base; assumption. Qed. Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> @@ -91,77 +68,84 @@ Section ExtendedBaseVector. erewrite map_nth_default; auto. Qed. - Lemma base_good_over_boundary : forall - (i : nat) - (l : (i < length base)%nat) - (j' : nat) - (Hj': (i + j' < length base)%nat) - , - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / - (2 ^ k * nth_default 0 base (i + j')) * - (2 ^ k * nth_default 0 base (i + j')) - . - Proof. - intros. - remember (nth_default 0 base) as b. - rewrite Zdiv_mult_cancel_l by apply two_sum_firstn_limb_widths_nonzero, limb_widths_nonneg. - replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) - with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by apply two_sum_firstn_limb_widths_nonzero, limb_widths_nonneg. - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (BaseSystem.base_good i j'); omega. - Qed. + Section base_good. + Context (two_k_nonzero : 2^k <> 0) + (base_good : forall i j, (i+j < length base)%nat -> + let b := nth_default 0 base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat) + (limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). - Lemma ext_base_good : - forall i j, (i+j < length ext_base)%nat -> - let b := nth_default 0 ext_base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - intros. - subst b. subst r. - rewrite ext_base_alt in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - repeat break_if; try omega. - { (* i < length base, j < length base, i + j < length base *) - auto using BaseSystem.base_good. - } { (* i < length base, j < length base, i + j >= length base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - autorewrite with distr_length in * |- . - apply base_matches_modulus; auto using limb_widths_nonneg, limb_widths_match_modulus. - omega. - } { (* i < length base, j >= length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length base)%nat as j'. - replace (i + j - length base)%nat with (i + j')%nat by omega. - replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) - with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length base, j < length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length base)%nat as i'. - replace (i + j - length base)%nat with (j + i')%nat by omega. - replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) - with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. + Lemma base_good_over_boundary + : forall (i : nat) + (l : (i < length base)%nat) + (j' : nat) + (Hj': (i + j' < length base)%nat), + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = + (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + / (2 ^ k * nth_default 0 base (i + j')) * + (2 ^ k * nth_default 0 base (i + j')). + Proof. + clear limb_widths_match_modulus. + intros. + remember (nth_default 0 base) as b. + rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). + replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) + with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. + rewrite Z.mul_cancel_l by (exact two_k_nonzero). + replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) + with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. + subst b. + apply (base_good i j'); omega. + Qed. - Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { - base_positive := ext_base_positive; - b0_1 := b0_1; - base_good := ext_base_good - }. + Lemma ext_base_good : + forall i j, (i+j < length ext_base)%nat -> + let b := nth_default 0 ext_base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + intros. + subst b. subst r. + rewrite ext_base_alt in *. + rewrite app_length in H; rewrite map_length in H. + repeat rewrite nth_default_app. + repeat break_if; try omega. + { (* i < length base, j < length base, i + j < length base *) + auto using BaseSystem.base_good. + } { (* i < length base, j < length base, i + j >= length base *) + rewrite (map_nth_default _ _ _ _ 0) by omega. + apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus; + distr_length. + } { (* i < length base, j >= length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (j - length base)%nat as j'. + replace (i + j - length base)%nat with (i + j')%nat by omega. + replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) + with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } { (* i >= length base, j < length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (i - length base)%nat as i'. + replace (i + j - length base)%nat with (j + i')%nat by omega. + replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) + with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } + Qed. + End base_good. Lemma extended_base_length: length ext_base = (length base + length base)%nat. Proof. + clear limb_widths_nonnegative. unfold ext_base, ext_limb_widths; autorewrite with distr_length; reflexivity. Qed. @@ -177,6 +161,21 @@ Section ExtendedBaseVector. (length us <= length base)%nat -> BaseSystem.decode base us = BaseSystem.decode ext_base us. Proof. auto using decode_short_initial, firstn_us_base_ext_base. Qed. + + Section BaseVector. + Context {bv : BaseSystem.BaseVector base} + (limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). + + Instance ExtBaseVector : BaseSystem.BaseVector ext_base := + { base_positive := ext_base_positive; + b0_1 x := b0_1 x (BaseSystem.b0_1 _); + base_good := ext_base_good (two_sum_firstn_limb_widths_nonzero limb_widths_nonnegative _) BaseSystem.base_good limb_widths_match_modulus }. + End BaseVector. End ExtendedBaseVector. Hint Rewrite @extended_base_length : distr_length. diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 23b0c2ef6..8c850c941 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -30,7 +30,7 @@ Section PseudoMersenneBase. let wrap := map (Z.mul c) high in BaseSystem.add low wrap. - Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs). + Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs). (* In order to subtract without underflowing, we add a multiple of the modulus first. *) Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 696f10438..1e748892d 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -478,12 +478,12 @@ Section Multiplication. Proof. eexists. cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros reduce]. - rewrite ext_base_alt. + rewrite ext_base_alt by auto using limb_widths_pos with zarith. rewrite <- mul'_opt_correct. change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. rewrite Z.map_shiftl by apply k_nonneg. rewrite c_subst. - rewrite k_subst. + fold k; rewrite k_subst. change @map with @map_opt. change @Z.shiftl_by with @Z_shiftl_by_opt. reflexivity. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index c06dcdf98..73146fe75 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -66,6 +66,13 @@ Section PseudoMersenneProofs. Lemma base_length : length base = length limb_widths. Proof. distr_length. Qed. + Lemma base_length_nonzero : length base <> 0%nat. + Proof. + distr_length. + pose proof limb_widths_nonnil. + destruct limb_widths; simpl in *; congruence. + Qed. + Lemma encode'_eq : forall (x : F modulus) i, (i <= length limb_widths)%nat -> encode' limb_widths x i = BaseSystem.encode' base x (2 ^ k) i. Proof. @@ -137,29 +144,15 @@ Section PseudoMersenneProofs. subst; auto. Qed. - Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits), - (length us <= length base)%nat - -> firstn (length us) base = firstn (length us) ext_base. - Proof. - rewrite ext_base_alt; intros. - rewrite firstn_app_inleft; auto; omega. - Qed. - Local Hint Immediate firstn_us_base_ext_base. - - Lemma decode_short : forall (us : BaseSystem.digits), - (length us <= length base)%nat -> - BaseSystem.decode base us = BaseSystem.decode ext_base us. - Proof. auto using decode_short_initial. Qed. - - Local Hint Immediate ExtBaseVector. + Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus. + Local Hint Extern 1 => apply limb_widths_match_modulus. Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), (length us <= length base)%nat -> (length vs <= length base)%nat -> - (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs). + (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base limb_widths) (BaseSystem.mul (ext_base limb_widths) us vs). Proof. - intros; apply mul_rep_two_base; auto; - distr_length. + intros; apply mul_rep_two_base; auto with arith distr_length. Qed. Lemma modulus_nonzero : modulus <> 0. @@ -187,13 +180,14 @@ Section PseudoMersenneProofs. Qed. Lemma extended_shiftadd: forall (us : BaseSystem.digits), - BaseSystem.decode ext_base us = + BaseSystem.decode (ext_base limb_widths) us = BaseSystem.decode base (firstn (length base) us) + (2^k * BaseSystem.decode base (skipn (length base) us)). Proof. intros. unfold BaseSystem.decode; rewrite <- mul_each_rep. - rewrite ext_base_alt. + rewrite ext_base_alt by auto. + fold k. replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. rewrite base_mul_app. rewrite <- mul_each_rep; auto. @@ -201,7 +195,7 @@ Section PseudoMersenneProofs. Lemma reduce_rep : forall us, BaseSystem.decode base (reduce us) mod modulus = - BaseSystem.decode ext_base us mod modulus. + BaseSystem.decode (ext_base limb_widths) us mod modulus. Proof. intros. rewrite extended_shiftadd. @@ -216,7 +210,7 @@ Section PseudoMersenneProofs. Qed. Lemma reduce_length : forall us, - (length base <= length us <= length ext_base)%nat -> + (length base <= length us <= length (ext_base limb_widths))%nat -> (length (reduce us) = length base)%nat. Proof. rewrite extended_base_length. @@ -236,8 +230,9 @@ Section PseudoMersenneProofs. apply reduce_length. rewrite mul_length_exact, extended_base_length; try omega. destruct u; try congruence. - rewrite @nil_length0 in *. - pose proof base_length_nonzero; omega. + pose proof limb_widths_nonnil. + autorewrite with distr_length in *. + destruct limb_widths; simpl in *; congruence. Qed. Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F. @@ -246,8 +241,8 @@ Section PseudoMersenneProofs. { apply length_mul; intuition auto. } { intuition idtac; subst. rewrite ZToField_mod, reduce_rep, <-ZToField_mod. - rewrite mul_rep by (apply ExtBaseVector || rewrite extended_base_length; omega). - rewrite 2decode_short by omega. + rewrite mul_rep by (auto using ExtBaseVector || rewrite extended_base_length; omega). + rewrite 2decode_short by auto with omega. apply ZToField_mul. } Qed. -- cgit v1.2.3