diff options
author | Jade Philipoom <jadep@mit.edu> | 2016-03-11 16:32:48 -0500 |
---|---|---|
committer | Jade Philipoom <jadep@mit.edu> | 2016-03-11 16:32:48 -0500 |
commit | 724b7b2acb9b857d7c511a320973cead308117c6 (patch) | |
tree | ac7c7d1dcd6fea890c138c6ea9a7e1df65097f0b /src/ModularArithmetic/ModularBaseSystem.v | |
parent | b690b5180af6c8dadcf28dbe6661b43deff47331 (diff) |
Refactored BaseSystem and ModularBaseSystem.
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 696 |
1 files changed, 357 insertions, 339 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 8c22a8091..b0e493871 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -2,212 +2,225 @@ Require Import Zpower ZArith. Require Import List. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.BaseSystem. Require Import VerdiTactics. +Require Crypto.BaseSystem. Local Open Scope Z_scope. -Module Type PseudoMersenneBaseParams (Import B:BaseCoefs) (Import M:Modulus). - Parameter k : Z. - Parameter c : Z. - Axiom modulus_pseudomersenne : modulus = 2^k - c. - - Axiom base_matches_modulus : +Class PseudoMersenneBaseParams (modulus : Z) (base : list Z) (bv : BaseSystem.BaseVector base) := { + k : Z; + c : Z; + modulus_pseudomersenne : modulus = 2^k - c; + prime_modulus : Znumtheory.prime modulus; + base_matches_modulus : forall i j, (i < length base)%nat -> (j < length base)%nat -> (i+j >= length base)%nat-> let b := nth_default 0 base in let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in - b i * b j = r * (2^k * b (i+j-length base)%nat). - - Axiom base_succ : forall i, ((S i) < length base)%nat -> + b i * b j = r * (2^k * b (i+j-length base)%nat); + base_succ : forall i, ((S i) < length base)%nat -> let b := nth_default 0 base in - b (S i) mod b i = 0. - - Axiom base_tail_matches_modulus: - 2^k mod nth_default 0 base (pred (length base)) = 0. - - (* Probably implied by modulus_pseudomersenne. *) - Axiom k_nonneg : 0 <= k. - -End PseudoMersenneBaseParams. - -Module Type RepZMod (Import M:Modulus). - Parameter T : Type. - Parameter encode : F modulus -> T. - Parameter decode : T -> F modulus. - - Parameter rep : T -> F modulus -> Prop. - Local Notation "u '~=' x" := (rep u x) (at level 70). - Axiom encode_rep : forall x, encode x ~= x. - Axiom rep_decode : forall u x, u ~= x -> decode u = x. + b (S i) mod b i = 0; + base_tail_matches_modulus: + 2^k mod nth_default 0 base (pred (length base)) = 0; + k_nonneg : 0 <= k (* Probably implied by modulus_pseudomersenne. *) +}. +(* +Class RepZMod (modulus : Z) := { + T : Type; + encode : F modulus -> T; + decode : T -> F modulus; + + rep : T -> F modulus -> Prop; + encode_rep : forall x, rep (encode x) x; + rep_decode : forall u x, rep u x -> decode u = x; + + add : T -> T -> T; + add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F; + + sub : T -> T -> T; + sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F; + + mul : T -> T -> T; + mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F +}. +*) +Print PseudoMersenneBaseParams. +Section ExtendedBaseVector. + Context (base : list Z) {modulus : Z} `(params : PseudoMersenneBaseParams modulus base). + (* This section defines a new BaseVector that has double the length of the BaseVector + * used to construct [params]. The coefficients of the new vector are as follows: + * + * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i] + * + * The purpose of this construction is that it allows us to multiply numbers expressed + * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as + * vectors of digits; the value of a digit vector is obtained by doing a dot product with + * the base vector.) So if x, y are digit vectors: + * + * (x \dot base) * (y \dot base) = (z \dot ext_base) + * + * Then we can separate z into its first and second halves: + * + * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) + * + * Now, if we want to reduce the product modulo 2 ^ k - c: + * + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) + * + * This sum may be short enough to express using base; if not, we can reduce again. + *) + Definition ext_base := base ++ (map (Z.mul (2^k)) base). + + Lemma ext_base_positive : forall b, In b ext_base -> b > 0. + Proof. + unfold ext_base. intros b In_b_base. + rewrite in_app_iff in In_b_base. + destruct In_b_base as [? | In_b_extbase]; auto using BaseSystem.base_positive. + apply in_map_iff in In_b_extbase. + destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. + subst. + specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. + replace 0 with (2 ^ k * 0) by ring. + apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. + rewrite Z.gt_lt_iff. + apply Z.pow_pos_nonneg; intuition. + pose proof k_nonneg; omega. + Qed. + + Lemma base_length_nonzero : (0 < length base)%nat. + Proof. + assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). + unfold nth_default in H. + case_eq (nth_error base 0); intros; + try (rewrite H0 in H; omega). + apply (nth_error_value_length _ 0 base z); auto. + Qed. - Parameter add : T -> T -> T. - Axiom add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F. + Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. + Proof. + intros. unfold ext_base. + rewrite nth_default_app. + assert (0 < length base)%nat by (apply base_length_nonzero). + destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. + Qed. - Parameter sub : T -> T -> T. - Axiom sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F. + Lemma two_k_nonzero : 2^k <> 0. + Proof. + pose proof (Z.pow_eq_0 2 k k_nonneg). + intuition. + Qed. - Parameter mul : T -> T -> T. - Axiom mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. -End RepZMod. + Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> + nth_default 0 (map (Z.mul (2 ^ k)) base) n = + (2 ^ k) * (nth_default 0 base n). + Proof. + intros. + erewrite map_nth_default; auto. + Qed. -Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersenneBaseParams BC M) <: RepZMod M. - Module EC <: BaseCoefs. - Definition base := BC.base ++ (map (Z.mul (2^(P.k))) BC.base). - - Lemma base_positive : forall b, In b base -> b > 0. - Proof. - unfold base. intros. - rewrite in_app_iff in H. - destruct H. { - apply BC.base_positive; auto. - } { - specialize BC.base_positive. - induction BC.base; intros. { - simpl in H; intuition. - } { - simpl in H. - destruct H; subst. - replace 0 with (2 ^ P.k * 0) by auto. - apply (Zmult_gt_compat_l a 0 (2 ^ P.k)). - rewrite Z.gt_lt_iff. - apply Z.pow_pos_nonneg; intuition. - pose proof P.k_nonneg; omega. - apply H0; left; auto. - apply IHl; auto. - intros. apply H0; auto. right; auto. - } - } - Qed. - - Lemma base_length_nonzero : (0 < length BC.base)%nat. - Proof. - assert (nth_default 0 BC.base 0 = 1) by (apply BC.b0_1). - unfold nth_default in H. - case_eq (nth_error BC.base 0); intros; - try (rewrite H0 in H; omega). - apply (nth_error_value_length _ 0 BC.base z); auto. - Qed. - - Lemma b0_1 : forall x, nth_default x base 0 = 1. - Proof. - intros. unfold base. - rewrite nth_default_app. - assert (0 < length BC.base)%nat by (apply base_length_nonzero). - destruct (lt_dec 0 (length BC.base)); try apply BC.b0_1; try omega. - Qed. - - Lemma two_k_nonzero : 2^P.k <> 0. - pose proof (Z.pow_eq_0 2 P.k P.k_nonneg). - intuition. - Qed. - - Lemma map_nth_default_base_high : forall n, (n < (length BC.base))%nat -> - nth_default 0 (map (Z.mul (2 ^ P.k)) BC.base) n = - (2 ^ P.k) * (nth_default 0 BC.base n). - Proof. - intros. - erewrite map_nth_default; auto. - Qed. - - Lemma base_succ : forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in - b (S i) mod b i = 0. - Proof. - intros; subst b; unfold base. - repeat rewrite nth_default_app. - do 2 break_if; try apply P.base_succ; try omega. { - destruct (lt_eq_lt_dec (S i) (length BC.base)). { - destruct s; intuition. - rewrite map_nth_default_base_high by omega. - replace i with (pred(length BC.base)) by omega. - rewrite <- Zmult_mod_idemp_l. - rewrite P.base_tail_matches_modulus; simpl. - rewrite Zmod_0_l; auto. - } { - assert (length BC.base <= i)%nat by (apply lt_n_Sm_le; auto); omega. - } - } { - unfold base in H; rewrite app_length, map_length in H. - repeat rewrite map_nth_default_base_high by omega. - rewrite Zmult_mod_distr_l. - rewrite <- minus_Sn_m by omega. - rewrite P.base_succ by omega; auto. - } - Qed. - - Lemma base_good_over_boundary : forall - (i : nat) - (l : (i < length BC.base)%nat) - (j' : nat) - (Hj': (i + j' < length BC.base)%nat) - , - 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') = - 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') / - (2 ^ P.k * nth_default 0 BC.base (i + j')) * - (2 ^ P.k * nth_default 0 BC.base (i + j')) - . - Proof. intros. - remember (nth_default 0 BC.base) as b. - rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). - replace (b i * b j' / b (i + j')%nat * (2 ^ P.k * b (i + j')%nat)) - with ((2 ^ P.k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by (exact two_k_nonzero). - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (BC.base_good i j'); omega. - Qed. - - Lemma base_good : - forall i j, (i+j < length base)%nat -> - let b := nth_default 0 base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - intros. - subst b. subst r. - unfold base in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - destruct (lt_dec i (length BC.base)); - destruct (lt_dec j (length BC.base)); - destruct (lt_dec (i + j) (length BC.base)); - try omega. - { (* i < length BC.base, j < length BC.base, i + j < length BC.base *) - apply BC.base_good; auto. - } { (* i < length BC.base, j < length BC.base, i + j >= length BC.base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - apply P.base_matches_modulus; omega. - } { (* i < length BC.base, j >= length BC.base, i + j >= length BC.base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length BC.base)%nat as j'. - replace (i + j - length BC.base)%nat with (i + j')%nat by omega. - replace (nth_default 0 BC.base i * (2 ^ P.k * nth_default 0 BC.base j')) - with (2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length BC.base, j < length BC.base, i + j >= length BC.base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length BC.base)%nat as i'. - replace (i + j - length BC.base)%nat with (j + i')%nat by omega. - replace (2 ^ P.k * nth_default 0 BC.base i' * nth_default 0 BC.base j) - with (2 ^ P.k * (nth_default 0 BC.base j * nth_default 0 BC.base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. - End EC. + Lemma ext_base_succ : forall i, ((S i) < length ext_base)%nat -> + let b := nth_default 0 ext_base in + b (S i) mod b i = 0. + Proof. + intros; subst b; unfold ext_base. + repeat rewrite nth_default_app. + do 2 break_if; [apply base_succ; auto | omega | | ]. { + destruct (lt_eq_lt_dec (S i) (length base)); try omega. + destruct s; intuition. + rewrite map_nth_default_base_high by omega. + replace i with (pred(length base)) by omega. + rewrite <- Zmult_mod_idemp_l. + rewrite base_tail_matches_modulus. + rewrite Zmod_0_l; auto. + } { + unfold ext_base in *; rewrite app_length, map_length in *. + repeat rewrite map_nth_default_base_high by omega. + rewrite Zmult_mod_distr_l. + rewrite <- minus_Sn_m by omega. + rewrite base_succ by omega; ring. + } + Qed. + + Lemma base_good_over_boundary : forall + (i : nat) + (l : (i < length base)%nat) + (j' : nat) + (Hj': (i + j' < length base)%nat) + , + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / + (2 ^ k * nth_default 0 base (i + j')) * + (2 ^ k * nth_default 0 base (i + j')) + . + Proof. + intros. + remember (nth_default 0 base) as b. + rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). + replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) + with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. + rewrite Z.mul_cancel_l by (exact two_k_nonzero). + replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) + with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. + subst b. + apply (BaseSystem.base_good i j'); omega. + Qed. + + Lemma ext_base_good : + forall i j, (i+j < length ext_base)%nat -> + let b := nth_default 0 ext_base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + intros. + subst b. subst r. + unfold ext_base in *. + rewrite app_length in H; rewrite map_length in H. + repeat rewrite nth_default_app. + destruct (lt_dec i (length base)); + destruct (lt_dec j (length base)); + destruct (lt_dec (i + j) (length base)); + try omega. + { (* i < length base, j < length base, i + j < length base *) + apply BaseSystem.base_good; auto. + } { (* i < length base, j < length base, i + j >= length base *) + rewrite (map_nth_default _ _ _ _ 0) by omega. + apply base_matches_modulus; omega. + } { (* i < length base, j >= length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (j - length base)%nat as j'. + replace (i + j - length base)%nat with (i + j')%nat by omega. + replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) + with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } { (* i >= length base, j < length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (i - length base)%nat as i'. + replace (i + j - length base)%nat with (j + i')%nat by omega. + replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) + with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } + Qed. + Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { + base_positive := ext_base_positive; + b0_1 := b0_1; + base_good := ext_base_good + }. +End ExtendedBaseVector. - Module E := BaseSystem EC. - Module B := BaseSystem BC. +Print ExtBaseVector. +Section PseudoMersenneBase. + Context `(prm :PseudoMersenneBaseParams). - Definition T := B.digits. - Local Hint Unfold T. - Definition decode (us : T) : F modulus := ZToField (B.decode us). + Definition T := BaseSystem.digits. + Definition decode (us : T) : F modulus := ZToField (BaseSystem.decode base us). Local Hint Unfold decode. - Definition rep (us : T) (x : F modulus) := (length us <= length BC.base)%nat /\ decode us = x. + Definition rep (us : T) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. Local Notation "u '~=' x" := (rep u x) (at level 70). Local Hint Unfold rep. @@ -216,188 +229,187 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen autounfold; intuition. Qed. - Definition encode (x : F modulus) := B.encode x. + Definition encode (x : F modulus) := BaseSystem.encode x. Lemma encode_rep : forall x : F modulus, encode x ~= x. Proof. intros. unfold encode, rep. split. { - unfold B.encode; simpl. - apply EC.base_length_nonzero. + unfold encode; simpl. + apply base_length_nonzero. + assumption. } { unfold decode. - rewrite B.encode_rep. - apply ZToField_idempotent. (* TODO: rename this lemma *) + rewrite BaseSystem.encode_rep. + apply ZToField_FieldToZ. + assumption. } Qed. - Definition add (us vs : T) := B.add us vs. - Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F. + Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F. Proof. autounfold; intuition. { unfold add. - rewrite B.add_length_le_max. + rewrite BaseSystem.add_length_le_max. case_max; try rewrite Max.max_r; omega. } - unfold decode in *; unfold B.decode in *. - rewrite B.add_rep. + unfold decode in *; unfold BaseSystem.decode in *. + rewrite BaseSystem.add_rep. rewrite ZToField_add. subst; auto. Qed. - Definition sub (us vs : T) := B.sub us vs. - Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F. + Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F. Proof. autounfold; intuition. { - rewrite B.sub_length_le_max. + rewrite BaseSystem.sub_length_le_max. case_max; try rewrite Max.max_r; omega. } - unfold decode in *; unfold B.decode in *. - rewrite B.sub_rep. + unfold decode in *; unfold BaseSystem.decode in *. + rewrite BaseSystem.sub_rep. rewrite ZToField_sub. subst; auto. Qed. - Lemma decode_short : forall (us : B.digits), - (length us <= length BC.base)%nat -> B.decode us = E.decode us. + Lemma decode_short : forall (us : T), + (length us <= length base)%nat -> + BaseSystem.decode base us = BaseSystem.decode (ext_base base prm) us. Proof. intros. - unfold B.decode, B.decode', E.decode, E.decode'. + unfold BaseSystem.decode, BaseSystem.decode'. rewrite combine_truncate_r. - rewrite (combine_truncate_r us EC.base). + rewrite (combine_truncate_r us (ext_base base prm)). f_equal; f_equal. - unfold EC.base. + unfold ext_base. rewrite firstn_app_inleft; auto; omega. Qed. Lemma extended_base_length: - length EC.base = (length BC.base + length BC.base)%nat. + length (ext_base base prm) = (length base + length base)%nat. Proof. - unfold EC.base; rewrite app_length; rewrite map_length; auto. + unfold ext_base; rewrite app_length; rewrite map_length; auto. Qed. - Lemma mul_rep_extended : forall (us vs : B.digits), - (length us <= length BC.base)%nat -> - (length vs <= length BC.base)%nat -> - B.decode us * B.decode vs = E.decode (E.mul us vs). + Lemma mul_rep_extended : forall (us vs : T), + (length us <= length base)%nat -> + (length vs <= length base)%nat -> + (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base base prm) (BaseSystem.mul (ext_base base prm) us vs). Proof. intros. - rewrite E.mul_rep by (unfold EC.base; simpl_list; omega). + rewrite BaseSystem.mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). f_equal; rewrite decode_short; auto. Qed. - (* Converts from length of E.base to length of B.base by reduction modulo M.*) - Definition reduce (us : E.digits) : B.digits := - let high := skipn (length BC.base) us in - let low := firstn (length BC.base) us in - let wrap := map (Z.mul P.c) high in - B.add low wrap. - - Lemma two_k_nonzero : 2^P.k <> 0. - pose proof (Z.pow_eq_0 2 P.k P.k_nonneg). - intuition. - Qed. + (* Converts from length of extended base to length of base by reduction modulo M.*) + Definition reduce (us : T) : T := + let high := skipn (length base) us in + let low := firstn (length base) us in + let wrap := map (Z.mul c) high in + BaseSystem.add low wrap. Lemma modulus_nonzero : modulus <> 0. pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. Qed. (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) - Lemma pseudomersenne_add: forall x y, (x + ((2^P.k) * y)) mod modulus = (x + (P.c * y)) mod modulus. + Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. Proof. intros. - replace (2^P.k) with (((2^P.k) - P.c) + P.c) by auto. + replace (2^k) with ((2^k - c) + c) by ring. rewrite Z.mul_add_distr_r. rewrite Zplus_mod. - rewrite <- P.modulus_pseudomersenne. + rewrite <- modulus_pseudomersenne. rewrite Z.mul_comm. rewrite mod_mult_plus; auto using modulus_nonzero. rewrite <- Zplus_mod; auto. Qed. - Lemma extended_shiftadd: forall (us : E.digits), E.decode us = - B.decode (firstn (length BC.base) us) + - (2^P.k * B.decode (skipn (length BC.base) us)). + Lemma extended_shiftadd: forall (us : T), + BaseSystem.decode (ext_base base prm) us = + BaseSystem.decode base (firstn (length base) us) + + (2^k * BaseSystem.decode base (skipn (length base) us)). Proof. intros. - unfold B.decode, E.decode; rewrite <- B.mul_each_rep. - replace B.decode' with E.decode' by auto. - unfold EC.base. - replace (map (Z.mul (2 ^ P.k)) BC.base) with (E.mul_each (2 ^ P.k) BC.base) by auto. - rewrite E.base_mul_app. - rewrite <- E.mul_each_rep; auto. + unfold BaseSystem.decode; rewrite <- BaseSystem.mul_each_rep. + unfold ext_base. + replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. + rewrite BaseSystem.base_mul_app. + rewrite <- BaseSystem.mul_each_rep; auto. Qed. - Lemma reduce_rep : forall us, B.decode (reduce us) mod modulus = (E.decode us) mod modulus. + Lemma reduce_rep : forall us, + BaseSystem.decode base (reduce us) mod modulus = + BaseSystem.decode (ext_base base prm) us mod modulus. Proof. intros. rewrite extended_shiftadd. rewrite pseudomersenne_add. unfold reduce. - remember (firstn (length BC.base) us) as low. - remember (skipn (length BC.base) us) as high. - unfold B.decode. - rewrite B.add_rep. - replace (map (Z.mul P.c) high) with (B.mul_each P.c high) by auto. - rewrite B.mul_each_rep; auto. + remember (firstn (length base) us) as low. + remember (skipn (length base) us) as high. + unfold BaseSystem.decode. + rewrite BaseSystem.add_rep. + replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto. + rewrite BaseSystem.mul_each_rep; auto. Qed. Lemma reduce_length : forall us, - (length us <= length EC.base)%nat -> - (length (reduce us) <= length (BC.base))%nat. + (length us <= length (ext_base base prm))%nat -> + (length (reduce us) <= length (base))%nat. Proof. intros. unfold reduce. - remember (map (Z.mul P.c) (skipn (length BC.base) us)) as high. - remember (firstn (length BC.base) us) as low. + remember (map (Z.mul c) (skipn (length base) us)) as high. + remember (firstn (length base) us) as low. assert (length low >= length high)%nat. { subst. rewrite firstn_length. rewrite map_length. rewrite skipn_length. - destruct (le_dec (length BC.base) (length us)). { + destruct (le_dec (length base) (length us)). { rewrite Min.min_l by omega. rewrite extended_base_length in H. omega. } { - rewrite Min.min_r by omega. omega. + rewrite Min.min_r; omega. } } - assert ((length low <= length BC.base)%nat) + assert ((length low <= length base)%nat) by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). - assert (length high <= length BC.base)%nat + assert (length high <= length base)%nat by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; rewrite extended_base_length in H; omega). - rewrite B.add_trailing_zeros; auto. - rewrite (B.add_same_length _ _ (length low)); auto. + rewrite BaseSystem.add_trailing_zeros; auto. + rewrite (BaseSystem.add_same_length _ _ (length low)); auto. rewrite app_length. - rewrite B.length_zeros; intuition. + rewrite BaseSystem.length_zeros; intuition. Qed. - Definition mul (us vs : T) := reduce (E.mul us vs). + Definition mul (us vs : T) := reduce (BaseSystem.mul (ext_base base prm) us vs). Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. Proof. - autounfold; unfold mul; intuition. { - rewrite reduce_length; try omega. - rewrite E.mul_length. - rewrite extended_base_length. + autounfold; unfold mul; intuition. + { + apply reduce_length. + rewrite BaseSystem.mul_length, extended_base_length. omega. + } { + rewrite ZToField_mod, reduce_rep, <-ZToField_mod. + rewrite BaseSystem.mul_rep by + (apply ExtBaseVector || rewrite extended_base_length; omega). + subst. + do 2 rewrite decode_short by auto. + apply ZToField_mul. } - rewrite ZToField_mod, reduce_rep, <-ZToField_mod. - rewrite E.mul_rep; try (rewrite extended_base_length; omega). - subst; auto. - replace (E.decode u) with (B.decode u) by (apply decode_short; omega). - replace (E.decode v) with (B.decode v) by (apply decode_short; omega). - apply ZToField_mul. Qed. Definition add_to_nth n (x:Z) xs := set_nth n (x + nth_default 0 xs n) xs. Hint Unfold add_to_nth. - (* i must be in the domain of BC.base *) + (* i must be in the domain of base *) Definition cap i := - if eq_nat_dec i (pred (length BC.base)) - then (2^P.k) / nth_default 0 BC.base i - else nth_default 0 BC.base (S i) / nth_default 0 BC.base i. + if eq_nat_dec i (pred (length base)) + then (2^k) / nth_default 0 base i + else nth_default 0 base (S i) / nth_default 0 base i. Definition carry_simple i := fun us => let di := nth_default 0 us i in @@ -407,106 +419,109 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Definition carry_and_reduce i := fun us => let di := nth_default 0 us i in let us' := set_nth i (di mod cap i) us in - add_to_nth 0 (P.c * (di / cap i)) us'. + add_to_nth 0 (c * (di / cap i)) us'. - Definition carry i : B.digits -> B.digits := - if eq_nat_dec i (pred (length BC.base)) + Definition carry i : T -> T := + if eq_nat_dec i (pred (length base)) then carry_and_reduce i else carry_simple i. + (* TODO: move to BaseSystemProofs *) Lemma decode'_splice : forall xs ys bs, - B.decode' bs (xs ++ ys) = - B.decode' (firstn (length xs) bs) xs + - B.decode' (skipn (length xs) bs) ys. + BaseSystem.decode' bs (xs ++ ys) = + BaseSystem.decode' (firstn (length xs) bs) xs + + BaseSystem.decode' (skipn (length xs) bs) ys. Proof. + unfold BaseSystem.decode'. induction xs; destruct ys, bs; boring. - unfold B.decode'. - rewrite combine_truncate_r. - ring. + + rewrite combine_truncate_r. + do 2 rewrite Z.add_0_r; auto. + + unfold BaseSystem.accumulate. + apply Z.add_assoc. Qed. Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - B.decode (set_nth n x us) = - (x - nth_default 0 us n) * nth_default 0 BC.base n + B.decode us. + BaseSystem.decode base (set_nth n x us) = + (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. Proof. intros. - unfold B.decode. + unfold BaseSystem.decode. nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) unfold splice_nth. rewrite <- (firstn_skipn n us) at 4. do 2 rewrite decode'_splice. remember (length (firstn n us)) as n0. ring_simplify. - remember (B.decode' (firstn n0 BC.base) (firstn n us)). + remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). rewrite (skipn_nth_default n us 0) by omega. rewrite firstn_length in Heqn0. rewrite Min.min_l in Heqn0 by omega; subst n0. - destruct (le_lt_dec (length BC.base) n). { + destruct (le_lt_dec (length base) n). { rewrite nth_default_out_of_bounds by auto. rewrite skipn_all by omega. - do 2 rewrite B.decode_base_nil. + do 2 rewrite BaseSystem.decode_base_nil. ring_simplify; auto. } { - rewrite (skipn_nth_default n BC.base 0) by omega. - do 2 rewrite B.decode'_cons. + rewrite (skipn_nth_default n base 0) by omega. + do 2 rewrite BaseSystem.decode'_cons. ring_simplify; ring. } Qed. Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - B.decode (add_to_nth n x us) = - x * nth_default 0 BC.base n + B.decode us. + BaseSystem.decode base (add_to_nth n x us) = + x * nth_default 0 base n + BaseSystem.decode base us. Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed. - Lemma nth_default_base_positive : forall i, (i < length BC.base)%nat -> - nth_default 0 BC.base i > 0. + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> + nth_default 0 base i > 0. Proof. intros. pose proof (nth_error_length_exists_value _ _ H). destruct H0. pose proof (nth_error_value_In _ _ _ H0). - pose proof (BC.base_positive _ H1). + pose proof (BaseSystem.base_positive _ H1). unfold nth_default. rewrite H0; auto. Qed. - Lemma base_succ_div_mult : forall i, ((S i) < length BC.base)%nat -> - nth_default 0 BC.base (S i) = nth_default 0 BC.base i * - (nth_default 0 BC.base (S i) / nth_default 0 BC.base i). + Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) = nth_default 0 base i * + (nth_default 0 base (S i) / nth_default 0 base i). Proof. intros. apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply P.base_succ; auto. + apply base_succ; auto. Qed. - Lemma base_length_lt_pred : (pred (length BC.base) < length BC.base)%nat. + Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. - pose proof EC.base_length_nonzero; omega. + pose proof (base_length_nonzero base); omega. Qed. Hint Resolve base_length_lt_pred. - Lemma cap_positive: forall i, (i < length BC.base)%nat -> cap i > 0. + Lemma cap_positive: forall i, (i < length base)%nat -> cap i > 0. Proof. unfold cap; intros; break_if. { - apply div_positive_gt_0; try (subst; apply P.base_tail_matches_modulus). { + apply div_positive_gt_0; try (subst; apply base_tail_matches_modulus). { rewrite <- two_p_equiv. apply two_p_gt_ZERO. - apply P.k_nonneg. + apply k_nonneg. } { apply nth_default_base_positive; subst; auto. } } { - apply div_positive_gt_0; try (apply P.base_succ; omega); + apply div_positive_gt_0; try (apply base_succ; omega); try (apply nth_default_base_positive; omega). } Qed. - Lemma cap_div_mod : forall us i, (i < (pred (length BC.base)))%nat -> + Lemma cap_div_mod : forall us i, (i < (pred (length base)))%nat -> let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 BC.base i = - (di / cap i) * nth_default 0 BC.base (S i). + (di - (di mod cap i)) * nth_default 0 base i = + (di / cap i) * nth_default 0 base (S i). Proof. intros. rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega); @@ -516,9 +531,9 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Qed. Lemma carry_simple_decode_eq : forall i us, - (length us = length BC.base) -> - (i < (pred (length BC.base)))%nat -> - B.decode (carry_simple i us) = B.decode us. + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us. Proof. unfold carry_simple. intros. rewrite add_to_nth_sum by (rewrite length_set_nth; omega). @@ -527,49 +542,52 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Qed. Lemma two_k_div_mul_last : - 2 ^ P.k = nth_default 0 BC.base (pred (length BC.base)) * - (2 ^ P.k / nth_default 0 BC.base (pred (length BC.base))). + 2 ^ k = nth_default 0 base (pred (length base)) * + (2 ^ k / nth_default 0 base (pred (length base))). Proof. intros. - pose proof P.base_tail_matches_modulus. - rewrite (Z_div_mod_eq (2 ^ P.k) (nth_default 0 BC.base (pred (length BC.base)))) at 1 by + pose proof base_tail_matches_modulus. + rewrite (Z_div_mod_eq (2 ^ k) (nth_default 0 base (pred (length base)))) at 1 by (apply nth_default_base_positive; auto); omega. Qed. Lemma cap_div_mod_reduce : forall us, - let i := pred (length BC.base) in + let i := pred (length base) in let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 BC.base i = - (di / cap i) * 2 ^ P.k. + (di - (di mod cap i)) * nth_default 0 base i = + (di / cap i) * 2 ^ k. Proof. intros. rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; auto); ring_simplify. unfold cap; break_if; intuition. rewrite Z.mul_comm, Z.mul_assoc. - subst i; rewrite <- two_k_div_mul_last; auto. + subst i; rewrite <- two_k_div_mul_last; ring. Qed. Lemma carry_decode_eq_reduce : forall us, - (length us = length BC.base) -> - B.decode (carry_and_reduce (pred (length BC.base)) us) mod modulus - = B.decode us mod modulus. + (length us = length base) -> + BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus + = BaseSystem.decode base us mod modulus. Proof. unfold carry_and_reduce; intros. - pose proof EC.base_length_nonzero. + pose proof (base_length_nonzero base). rewrite add_to_nth_sum by (rewrite length_set_nth; omega). rewrite set_nth_sum by omega. - rewrite Zplus_comm; rewrite <- Z.mul_assoc. + rewrite Zplus_comm, <- Z.mul_assoc. rewrite <- pseudomersenne_add. - rewrite BC.b0_1. - rewrite (Z.mul_comm (2 ^ P.k)). + rewrite BaseSystem.b0_1. + rewrite (Z.mul_comm (2 ^ k)). rewrite <- Zred_factor0. - rewrite <- cap_div_mod_reduce by auto; auto. + rewrite <- cap_div_mod_reduce by auto. + do 2 rewrite Zmult_minus_distr_r. + f_equal. + ring. Qed. Lemma carry_length : forall i us, - (length us <= length BC.base)%nat -> - (length (carry i us) <= length BC.base)%nat. + (length us <= length base)%nat -> + (length (carry i us) <= length base)%nat. Proof. unfold carry, carry_simple, carry_and_reduce, add_to_nth. intros; break_if; subst; repeat (rewrite length_set_nth); auto. @@ -577,8 +595,8 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Hint Resolve carry_length. Lemma carry_rep : forall i us x, - (length us = length BC.base) -> - (i < length BC.base)%nat -> + (length us = length base) -> + (i < length base)%nat -> us ~= x -> carry i us ~= x. Proof. pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. @@ -591,24 +609,24 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Definition carry_sequence is us := fold_right carry us is. Lemma carry_sequence_length: forall is us, - (length us <= length BC.base)%nat -> - (length (carry_sequence is us) <= length BC.base)%nat. + (length us <= length base)%nat -> + (length (carry_sequence is us) <= length base)%nat. Proof. induction is; boring. Qed. Hint Resolve carry_sequence_length. Lemma carry_length_exact : forall i us, - (length us = length BC.base)%nat -> - (length (carry i us) = length BC.base)%nat. + (length us = length base)%nat -> + (length (carry i us) = length base)%nat. Proof. unfold carry, carry_simple, carry_and_reduce, add_to_nth. intros; break_if; subst; repeat (rewrite length_set_nth); auto. Qed. Lemma carry_sequence_length_exact: forall is us, - (length us = length BC.base)%nat -> - (length (carry_sequence is us) = length BC.base)%nat. + (length us = length base)%nat -> + (length (carry_sequence is us) = length base)%nat. Proof. induction is; boring. apply carry_length_exact; auto. @@ -616,10 +634,10 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen Hint Resolve carry_sequence_length_exact. Lemma carry_sequence_rep : forall is us x, - (forall i, In i is -> (i < length BC.base)%nat) -> - (length us = length BC.base) -> + (forall i, In i is -> (i < length base)%nat) -> + (length us = length base) -> us ~= x -> carry_sequence is us ~= x. Proof. induction is; boring. Qed. -End PseudoMersenneBase.
\ No newline at end of file +End PseudoMersenneBase. |