aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystem.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@mit.edu>2016-03-11 16:32:48 -0500
committerGravatar Jade Philipoom <jadep@mit.edu>2016-03-11 16:32:48 -0500
commit724b7b2acb9b857d7c511a320973cead308117c6 (patch)
treeac7c7d1dcd6fea890c138c6ea9a7e1df65097f0b /src/ModularArithmetic/ModularBaseSystem.v
parentb690b5180af6c8dadcf28dbe6661b43deff47331 (diff)
Refactored BaseSystem and ModularBaseSystem.
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystem.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v696
1 files changed, 357 insertions, 339 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 8c22a8091..b0e493871 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -2,212 +2,225 @@ Require Import Zpower ZArith.
Require Import List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
-Require Import Crypto.BaseSystem.
Require Import VerdiTactics.
+Require Crypto.BaseSystem.
Local Open Scope Z_scope.
-Module Type PseudoMersenneBaseParams (Import B:BaseCoefs) (Import M:Modulus).
- Parameter k : Z.
- Parameter c : Z.
- Axiom modulus_pseudomersenne : modulus = 2^k - c.
-
- Axiom base_matches_modulus :
+Class PseudoMersenneBaseParams (modulus : Z) (base : list Z) (bv : BaseSystem.BaseVector base) := {
+ k : Z;
+ c : Z;
+ modulus_pseudomersenne : modulus = 2^k - c;
+ prime_modulus : Znumtheory.prime modulus;
+ base_matches_modulus :
forall i j,
(i < length base)%nat ->
(j < length base)%nat ->
(i+j >= length base)%nat->
let b := nth_default 0 base in
let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
- b i * b j = r * (2^k * b (i+j-length base)%nat).
-
- Axiom base_succ : forall i, ((S i) < length base)%nat ->
+ b i * b j = r * (2^k * b (i+j-length base)%nat);
+ base_succ : forall i, ((S i) < length base)%nat ->
let b := nth_default 0 base in
- b (S i) mod b i = 0.
-
- Axiom base_tail_matches_modulus:
- 2^k mod nth_default 0 base (pred (length base)) = 0.
-
- (* Probably implied by modulus_pseudomersenne. *)
- Axiom k_nonneg : 0 <= k.
-
-End PseudoMersenneBaseParams.
-
-Module Type RepZMod (Import M:Modulus).
- Parameter T : Type.
- Parameter encode : F modulus -> T.
- Parameter decode : T -> F modulus.
-
- Parameter rep : T -> F modulus -> Prop.
- Local Notation "u '~=' x" := (rep u x) (at level 70).
- Axiom encode_rep : forall x, encode x ~= x.
- Axiom rep_decode : forall u x, u ~= x -> decode u = x.
+ b (S i) mod b i = 0;
+ base_tail_matches_modulus:
+ 2^k mod nth_default 0 base (pred (length base)) = 0;
+ k_nonneg : 0 <= k (* Probably implied by modulus_pseudomersenne. *)
+}.
+(*
+Class RepZMod (modulus : Z) := {
+ T : Type;
+ encode : F modulus -> T;
+ decode : T -> F modulus;
+
+ rep : T -> F modulus -> Prop;
+ encode_rep : forall x, rep (encode x) x;
+ rep_decode : forall u x, rep u x -> decode u = x;
+
+ add : T -> T -> T;
+ add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
+
+ sub : T -> T -> T;
+ sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
+
+ mul : T -> T -> T;
+ mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
+}.
+*)
+Print PseudoMersenneBaseParams.
+Section ExtendedBaseVector.
+ Context (base : list Z) {modulus : Z} `(params : PseudoMersenneBaseParams modulus base).
+ (* This section defines a new BaseVector that has double the length of the BaseVector
+ * used to construct [params]. The coefficients of the new vector are as follows:
+ *
+ * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
+ *
+ * The purpose of this construction is that it allows us to multiply numbers expressed
+ * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
+ * vectors of digits; the value of a digit vector is obtained by doing a dot product with
+ * the base vector.) So if x, y are digit vectors:
+ *
+ * (x \dot base) * (y \dot base) = (z \dot ext_base)
+ *
+ * Then we can separate z into its first and second halves:
+ *
+ * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
+ *
+ * Now, if we want to reduce the product modulo 2 ^ k - c:
+ *
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
+ *
+ * This sum may be short enough to express using base; if not, we can reduce again.
+ *)
+ Definition ext_base := base ++ (map (Z.mul (2^k)) base).
+
+ Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
+ Proof.
+ unfold ext_base. intros b In_b_base.
+ rewrite in_app_iff in In_b_base.
+ destruct In_b_base as [? | In_b_extbase]; auto using BaseSystem.base_positive.
+ apply in_map_iff in In_b_extbase.
+ destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
+ subst.
+ specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
+ replace 0 with (2 ^ k * 0) by ring.
+ apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
+ rewrite Z.gt_lt_iff.
+ apply Z.pow_pos_nonneg; intuition.
+ pose proof k_nonneg; omega.
+ Qed.
+
+ Lemma base_length_nonzero : (0 < length base)%nat.
+ Proof.
+ assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
+ unfold nth_default in H.
+ case_eq (nth_error base 0); intros;
+ try (rewrite H0 in H; omega).
+ apply (nth_error_value_length _ 0 base z); auto.
+ Qed.
- Parameter add : T -> T -> T.
- Axiom add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F.
+ Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
+ Proof.
+ intros. unfold ext_base.
+ rewrite nth_default_app.
+ assert (0 < length base)%nat by (apply base_length_nonzero).
+ destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
+ Qed.
- Parameter sub : T -> T -> T.
- Axiom sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F.
+ Lemma two_k_nonzero : 2^k <> 0.
+ Proof.
+ pose proof (Z.pow_eq_0 2 k k_nonneg).
+ intuition.
+ Qed.
- Parameter mul : T -> T -> T.
- Axiom mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
-End RepZMod.
+ Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
+ nth_default 0 (map (Z.mul (2 ^ k)) base) n =
+ (2 ^ k) * (nth_default 0 base n).
+ Proof.
+ intros.
+ erewrite map_nth_default; auto.
+ Qed.
-Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersenneBaseParams BC M) <: RepZMod M.
- Module EC <: BaseCoefs.
- Definition base := BC.base ++ (map (Z.mul (2^(P.k))) BC.base).
-
- Lemma base_positive : forall b, In b base -> b > 0.
- Proof.
- unfold base. intros.
- rewrite in_app_iff in H.
- destruct H. {
- apply BC.base_positive; auto.
- } {
- specialize BC.base_positive.
- induction BC.base; intros. {
- simpl in H; intuition.
- } {
- simpl in H.
- destruct H; subst.
- replace 0 with (2 ^ P.k * 0) by auto.
- apply (Zmult_gt_compat_l a 0 (2 ^ P.k)).
- rewrite Z.gt_lt_iff.
- apply Z.pow_pos_nonneg; intuition.
- pose proof P.k_nonneg; omega.
- apply H0; left; auto.
- apply IHl; auto.
- intros. apply H0; auto. right; auto.
- }
- }
- Qed.
-
- Lemma base_length_nonzero : (0 < length BC.base)%nat.
- Proof.
- assert (nth_default 0 BC.base 0 = 1) by (apply BC.b0_1).
- unfold nth_default in H.
- case_eq (nth_error BC.base 0); intros;
- try (rewrite H0 in H; omega).
- apply (nth_error_value_length _ 0 BC.base z); auto.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x base 0 = 1.
- Proof.
- intros. unfold base.
- rewrite nth_default_app.
- assert (0 < length BC.base)%nat by (apply base_length_nonzero).
- destruct (lt_dec 0 (length BC.base)); try apply BC.b0_1; try omega.
- Qed.
-
- Lemma two_k_nonzero : 2^P.k <> 0.
- pose proof (Z.pow_eq_0 2 P.k P.k_nonneg).
- intuition.
- Qed.
-
- Lemma map_nth_default_base_high : forall n, (n < (length BC.base))%nat ->
- nth_default 0 (map (Z.mul (2 ^ P.k)) BC.base) n =
- (2 ^ P.k) * (nth_default 0 BC.base n).
- Proof.
- intros.
- erewrite map_nth_default; auto.
- Qed.
-
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0.
- Proof.
- intros; subst b; unfold base.
- repeat rewrite nth_default_app.
- do 2 break_if; try apply P.base_succ; try omega. {
- destruct (lt_eq_lt_dec (S i) (length BC.base)). {
- destruct s; intuition.
- rewrite map_nth_default_base_high by omega.
- replace i with (pred(length BC.base)) by omega.
- rewrite <- Zmult_mod_idemp_l.
- rewrite P.base_tail_matches_modulus; simpl.
- rewrite Zmod_0_l; auto.
- } {
- assert (length BC.base <= i)%nat by (apply lt_n_Sm_le; auto); omega.
- }
- } {
- unfold base in H; rewrite app_length, map_length in H.
- repeat rewrite map_nth_default_base_high by omega.
- rewrite Zmult_mod_distr_l.
- rewrite <- minus_Sn_m by omega.
- rewrite P.base_succ by omega; auto.
- }
- Qed.
-
- Lemma base_good_over_boundary : forall
- (i : nat)
- (l : (i < length BC.base)%nat)
- (j' : nat)
- (Hj': (i + j' < length BC.base)%nat)
- ,
- 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') =
- 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') /
- (2 ^ P.k * nth_default 0 BC.base (i + j')) *
- (2 ^ P.k * nth_default 0 BC.base (i + j'))
- .
- Proof. intros.
- remember (nth_default 0 BC.base) as b.
- rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
- replace (b i * b j' / b (i + j')%nat * (2 ^ P.k * b (i + j')%nat))
- with ((2 ^ P.k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
- rewrite Z.mul_cancel_l by (exact two_k_nonzero).
- replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
- with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
- subst b.
- apply (BC.base_good i j'); omega.
- Qed.
-
- Lemma base_good :
- forall i j, (i+j < length base)%nat ->
- let b := nth_default 0 base in
- let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
- Proof.
- intros.
- subst b. subst r.
- unfold base in *.
- rewrite app_length in H; rewrite map_length in H.
- repeat rewrite nth_default_app.
- destruct (lt_dec i (length BC.base));
- destruct (lt_dec j (length BC.base));
- destruct (lt_dec (i + j) (length BC.base));
- try omega.
- { (* i < length BC.base, j < length BC.base, i + j < length BC.base *)
- apply BC.base_good; auto.
- } { (* i < length BC.base, j < length BC.base, i + j >= length BC.base *)
- rewrite (map_nth_default _ _ _ _ 0) by omega.
- apply P.base_matches_modulus; omega.
- } { (* i < length BC.base, j >= length BC.base, i + j >= length BC.base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (j - length BC.base)%nat as j'.
- replace (i + j - length BC.base)%nat with (i + j')%nat by omega.
- replace (nth_default 0 BC.base i * (2 ^ P.k * nth_default 0 BC.base j'))
- with (2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- } { (* i >= length BC.base, j < length BC.base, i + j >= length BC.base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (i - length BC.base)%nat as i'.
- replace (i + j - length BC.base)%nat with (j + i')%nat by omega.
- replace (2 ^ P.k * nth_default 0 BC.base i' * nth_default 0 BC.base j)
- with (2 ^ P.k * (nth_default 0 BC.base j * nth_default 0 BC.base i'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- }
- Qed.
- End EC.
+ Lemma ext_base_succ : forall i, ((S i) < length ext_base)%nat ->
+ let b := nth_default 0 ext_base in
+ b (S i) mod b i = 0.
+ Proof.
+ intros; subst b; unfold ext_base.
+ repeat rewrite nth_default_app.
+ do 2 break_if; [apply base_succ; auto | omega | | ]. {
+ destruct (lt_eq_lt_dec (S i) (length base)); try omega.
+ destruct s; intuition.
+ rewrite map_nth_default_base_high by omega.
+ replace i with (pred(length base)) by omega.
+ rewrite <- Zmult_mod_idemp_l.
+ rewrite base_tail_matches_modulus.
+ rewrite Zmod_0_l; auto.
+ } {
+ unfold ext_base in *; rewrite app_length, map_length in *.
+ repeat rewrite map_nth_default_base_high by omega.
+ rewrite Zmult_mod_distr_l.
+ rewrite <- minus_Sn_m by omega.
+ rewrite base_succ by omega; ring.
+ }
+ Qed.
+
+ Lemma base_good_over_boundary : forall
+ (i : nat)
+ (l : (i < length base)%nat)
+ (j' : nat)
+ (Hj': (i + j' < length base)%nat)
+ ,
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
+ (2 ^ k * nth_default 0 base (i + j')) *
+ (2 ^ k * nth_default 0 base (i + j'))
+ .
+ Proof.
+ intros.
+ remember (nth_default 0 base) as b.
+ rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
+ replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
+ with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
+ rewrite Z.mul_cancel_l by (exact two_k_nonzero).
+ replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
+ with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
+ subst b.
+ apply (BaseSystem.base_good i j'); omega.
+ Qed.
+
+ Lemma ext_base_good :
+ forall i j, (i+j < length ext_base)%nat ->
+ let b := nth_default 0 ext_base in
+ let r := (b i * b j) / b (i+j)%nat in
+ b i * b j = r * b (i+j)%nat.
+ Proof.
+ intros.
+ subst b. subst r.
+ unfold ext_base in *.
+ rewrite app_length in H; rewrite map_length in H.
+ repeat rewrite nth_default_app.
+ destruct (lt_dec i (length base));
+ destruct (lt_dec j (length base));
+ destruct (lt_dec (i + j) (length base));
+ try omega.
+ { (* i < length base, j < length base, i + j < length base *)
+ apply BaseSystem.base_good; auto.
+ } { (* i < length base, j < length base, i + j >= length base *)
+ rewrite (map_nth_default _ _ _ _ 0) by omega.
+ apply base_matches_modulus; omega.
+ } { (* i < length base, j >= length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (j - length base)%nat as j'.
+ replace (i + j - length base)%nat with (i + j')%nat by omega.
+ replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
+ with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ } { (* i >= length base, j < length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (i - length base)%nat as i'.
+ replace (i + j - length base)%nat with (j + i')%nat by omega.
+ replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
+ with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ }
+ Qed.
+ Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
+ base_positive := ext_base_positive;
+ b0_1 := b0_1;
+ base_good := ext_base_good
+ }.
+End ExtendedBaseVector.
- Module E := BaseSystem EC.
- Module B := BaseSystem BC.
+Print ExtBaseVector.
+Section PseudoMersenneBase.
+ Context `(prm :PseudoMersenneBaseParams).
- Definition T := B.digits.
- Local Hint Unfold T.
- Definition decode (us : T) : F modulus := ZToField (B.decode us).
+ Definition T := BaseSystem.digits.
+ Definition decode (us : T) : F modulus := ZToField (BaseSystem.decode base us).
Local Hint Unfold decode.
- Definition rep (us : T) (x : F modulus) := (length us <= length BC.base)%nat /\ decode us = x.
+ Definition rep (us : T) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
@@ -216,188 +229,187 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
autounfold; intuition.
Qed.
- Definition encode (x : F modulus) := B.encode x.
+ Definition encode (x : F modulus) := BaseSystem.encode x.
Lemma encode_rep : forall x : F modulus, encode x ~= x.
Proof.
intros. unfold encode, rep.
split. {
- unfold B.encode; simpl.
- apply EC.base_length_nonzero.
+ unfold encode; simpl.
+ apply base_length_nonzero.
+ assumption.
} {
unfold decode.
- rewrite B.encode_rep.
- apply ZToField_idempotent. (* TODO: rename this lemma *)
+ rewrite BaseSystem.encode_rep.
+ apply ZToField_FieldToZ.
+ assumption.
}
Qed.
- Definition add (us vs : T) := B.add us vs.
- Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F.
+ Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
Proof.
autounfold; intuition. {
unfold add.
- rewrite B.add_length_le_max.
+ rewrite BaseSystem.add_length_le_max.
case_max; try rewrite Max.max_r; omega.
}
- unfold decode in *; unfold B.decode in *.
- rewrite B.add_rep.
+ unfold decode in *; unfold BaseSystem.decode in *.
+ rewrite BaseSystem.add_rep.
rewrite ZToField_add.
subst; auto.
Qed.
- Definition sub (us vs : T) := B.sub us vs.
- Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F.
+ Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F.
Proof.
autounfold; intuition. {
- rewrite B.sub_length_le_max.
+ rewrite BaseSystem.sub_length_le_max.
case_max; try rewrite Max.max_r; omega.
}
- unfold decode in *; unfold B.decode in *.
- rewrite B.sub_rep.
+ unfold decode in *; unfold BaseSystem.decode in *.
+ rewrite BaseSystem.sub_rep.
rewrite ZToField_sub.
subst; auto.
Qed.
- Lemma decode_short : forall (us : B.digits),
- (length us <= length BC.base)%nat -> B.decode us = E.decode us.
+ Lemma decode_short : forall (us : T),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode (ext_base base prm) us.
Proof.
intros.
- unfold B.decode, B.decode', E.decode, E.decode'.
+ unfold BaseSystem.decode, BaseSystem.decode'.
rewrite combine_truncate_r.
- rewrite (combine_truncate_r us EC.base).
+ rewrite (combine_truncate_r us (ext_base base prm)).
f_equal; f_equal.
- unfold EC.base.
+ unfold ext_base.
rewrite firstn_app_inleft; auto; omega.
Qed.
Lemma extended_base_length:
- length EC.base = (length BC.base + length BC.base)%nat.
+ length (ext_base base prm) = (length base + length base)%nat.
Proof.
- unfold EC.base; rewrite app_length; rewrite map_length; auto.
+ unfold ext_base; rewrite app_length; rewrite map_length; auto.
Qed.
- Lemma mul_rep_extended : forall (us vs : B.digits),
- (length us <= length BC.base)%nat ->
- (length vs <= length BC.base)%nat ->
- B.decode us * B.decode vs = E.decode (E.mul us vs).
+ Lemma mul_rep_extended : forall (us vs : T),
+ (length us <= length base)%nat ->
+ (length vs <= length base)%nat ->
+ (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base base prm) (BaseSystem.mul (ext_base base prm) us vs).
Proof.
intros.
- rewrite E.mul_rep by (unfold EC.base; simpl_list; omega).
+ rewrite BaseSystem.mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
f_equal; rewrite decode_short; auto.
Qed.
- (* Converts from length of E.base to length of B.base by reduction modulo M.*)
- Definition reduce (us : E.digits) : B.digits :=
- let high := skipn (length BC.base) us in
- let low := firstn (length BC.base) us in
- let wrap := map (Z.mul P.c) high in
- B.add low wrap.
-
- Lemma two_k_nonzero : 2^P.k <> 0.
- pose proof (Z.pow_eq_0 2 P.k P.k_nonneg).
- intuition.
- Qed.
+ (* Converts from length of extended base to length of base by reduction modulo M.*)
+ Definition reduce (us : T) : T :=
+ let high := skipn (length base) us in
+ let low := firstn (length base) us in
+ let wrap := map (Z.mul c) high in
+ BaseSystem.add low wrap.
Lemma modulus_nonzero : modulus <> 0.
pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
Qed.
(* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
- Lemma pseudomersenne_add: forall x y, (x + ((2^P.k) * y)) mod modulus = (x + (P.c * y)) mod modulus.
+ Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
Proof.
intros.
- replace (2^P.k) with (((2^P.k) - P.c) + P.c) by auto.
+ replace (2^k) with ((2^k - c) + c) by ring.
rewrite Z.mul_add_distr_r.
rewrite Zplus_mod.
- rewrite <- P.modulus_pseudomersenne.
+ rewrite <- modulus_pseudomersenne.
rewrite Z.mul_comm.
rewrite mod_mult_plus; auto using modulus_nonzero.
rewrite <- Zplus_mod; auto.
Qed.
- Lemma extended_shiftadd: forall (us : E.digits), E.decode us =
- B.decode (firstn (length BC.base) us) +
- (2^P.k * B.decode (skipn (length BC.base) us)).
+ Lemma extended_shiftadd: forall (us : T),
+ BaseSystem.decode (ext_base base prm) us =
+ BaseSystem.decode base (firstn (length base) us)
+ + (2^k * BaseSystem.decode base (skipn (length base) us)).
Proof.
intros.
- unfold B.decode, E.decode; rewrite <- B.mul_each_rep.
- replace B.decode' with E.decode' by auto.
- unfold EC.base.
- replace (map (Z.mul (2 ^ P.k)) BC.base) with (E.mul_each (2 ^ P.k) BC.base) by auto.
- rewrite E.base_mul_app.
- rewrite <- E.mul_each_rep; auto.
+ unfold BaseSystem.decode; rewrite <- BaseSystem.mul_each_rep.
+ unfold ext_base.
+ replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
+ rewrite BaseSystem.base_mul_app.
+ rewrite <- BaseSystem.mul_each_rep; auto.
Qed.
- Lemma reduce_rep : forall us, B.decode (reduce us) mod modulus = (E.decode us) mod modulus.
+ Lemma reduce_rep : forall us,
+ BaseSystem.decode base (reduce us) mod modulus =
+ BaseSystem.decode (ext_base base prm) us mod modulus.
Proof.
intros.
rewrite extended_shiftadd.
rewrite pseudomersenne_add.
unfold reduce.
- remember (firstn (length BC.base) us) as low.
- remember (skipn (length BC.base) us) as high.
- unfold B.decode.
- rewrite B.add_rep.
- replace (map (Z.mul P.c) high) with (B.mul_each P.c high) by auto.
- rewrite B.mul_each_rep; auto.
+ remember (firstn (length base) us) as low.
+ remember (skipn (length base) us) as high.
+ unfold BaseSystem.decode.
+ rewrite BaseSystem.add_rep.
+ replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
+ rewrite BaseSystem.mul_each_rep; auto.
Qed.
Lemma reduce_length : forall us,
- (length us <= length EC.base)%nat ->
- (length (reduce us) <= length (BC.base))%nat.
+ (length us <= length (ext_base base prm))%nat ->
+ (length (reduce us) <= length (base))%nat.
Proof.
intros.
unfold reduce.
- remember (map (Z.mul P.c) (skipn (length BC.base) us)) as high.
- remember (firstn (length BC.base) us) as low.
+ remember (map (Z.mul c) (skipn (length base) us)) as high.
+ remember (firstn (length base) us) as low.
assert (length low >= length high)%nat. {
subst. rewrite firstn_length.
rewrite map_length.
rewrite skipn_length.
- destruct (le_dec (length BC.base) (length us)). {
+ destruct (le_dec (length base) (length us)). {
rewrite Min.min_l by omega.
rewrite extended_base_length in H. omega.
} {
- rewrite Min.min_r by omega. omega.
+ rewrite Min.min_r; omega.
}
}
- assert ((length low <= length BC.base)%nat)
+ assert ((length low <= length base)%nat)
by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
- assert (length high <= length BC.base)%nat
+ assert (length high <= length base)%nat
by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
rewrite extended_base_length in H; omega).
- rewrite B.add_trailing_zeros; auto.
- rewrite (B.add_same_length _ _ (length low)); auto.
+ rewrite BaseSystem.add_trailing_zeros; auto.
+ rewrite (BaseSystem.add_same_length _ _ (length low)); auto.
rewrite app_length.
- rewrite B.length_zeros; intuition.
+ rewrite BaseSystem.length_zeros; intuition.
Qed.
- Definition mul (us vs : T) := reduce (E.mul us vs).
+ Definition mul (us vs : T) := reduce (BaseSystem.mul (ext_base base prm) us vs).
Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
Proof.
- autounfold; unfold mul; intuition. {
- rewrite reduce_length; try omega.
- rewrite E.mul_length.
- rewrite extended_base_length.
+ autounfold; unfold mul; intuition.
+ {
+ apply reduce_length.
+ rewrite BaseSystem.mul_length, extended_base_length.
omega.
+ } {
+ rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
+ rewrite BaseSystem.mul_rep by
+ (apply ExtBaseVector || rewrite extended_base_length; omega).
+ subst.
+ do 2 rewrite decode_short by auto.
+ apply ZToField_mul.
}
- rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
- rewrite E.mul_rep; try (rewrite extended_base_length; omega).
- subst; auto.
- replace (E.decode u) with (B.decode u) by (apply decode_short; omega).
- replace (E.decode v) with (B.decode v) by (apply decode_short; omega).
- apply ZToField_mul.
Qed.
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
Hint Unfold add_to_nth.
- (* i must be in the domain of BC.base *)
+ (* i must be in the domain of base *)
Definition cap i :=
- if eq_nat_dec i (pred (length BC.base))
- then (2^P.k) / nth_default 0 BC.base i
- else nth_default 0 BC.base (S i) / nth_default 0 BC.base i.
+ if eq_nat_dec i (pred (length base))
+ then (2^k) / nth_default 0 base i
+ else nth_default 0 base (S i) / nth_default 0 base i.
Definition carry_simple i := fun us =>
let di := nth_default 0 us i in
@@ -407,106 +419,109 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Definition carry_and_reduce i := fun us =>
let di := nth_default 0 us i in
let us' := set_nth i (di mod cap i) us in
- add_to_nth 0 (P.c * (di / cap i)) us'.
+ add_to_nth 0 (c * (di / cap i)) us'.
- Definition carry i : B.digits -> B.digits :=
- if eq_nat_dec i (pred (length BC.base))
+ Definition carry i : T -> T :=
+ if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
+ (* TODO: move to BaseSystemProofs *)
Lemma decode'_splice : forall xs ys bs,
- B.decode' bs (xs ++ ys) =
- B.decode' (firstn (length xs) bs) xs +
- B.decode' (skipn (length xs) bs) ys.
+ BaseSystem.decode' bs (xs ++ ys) =
+ BaseSystem.decode' (firstn (length xs) bs) xs +
+ BaseSystem.decode' (skipn (length xs) bs) ys.
Proof.
+ unfold BaseSystem.decode'.
induction xs; destruct ys, bs; boring.
- unfold B.decode'.
- rewrite combine_truncate_r.
- ring.
+ + rewrite combine_truncate_r.
+ do 2 rewrite Z.add_0_r; auto.
+ + unfold BaseSystem.accumulate.
+ apply Z.add_assoc.
Qed.
Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
- B.decode (set_nth n x us) =
- (x - nth_default 0 us n) * nth_default 0 BC.base n + B.decode us.
+ BaseSystem.decode base (set_nth n x us) =
+ (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
Proof.
intros.
- unfold B.decode.
+ unfold BaseSystem.decode.
nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
unfold splice_nth.
rewrite <- (firstn_skipn n us) at 4.
do 2 rewrite decode'_splice.
remember (length (firstn n us)) as n0.
ring_simplify.
- remember (B.decode' (firstn n0 BC.base) (firstn n us)).
+ remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
rewrite (skipn_nth_default n us 0) by omega.
rewrite firstn_length in Heqn0.
rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length BC.base) n). {
+ destruct (le_lt_dec (length base) n). {
rewrite nth_default_out_of_bounds by auto.
rewrite skipn_all by omega.
- do 2 rewrite B.decode_base_nil.
+ do 2 rewrite BaseSystem.decode_base_nil.
ring_simplify; auto.
} {
- rewrite (skipn_nth_default n BC.base 0) by omega.
- do 2 rewrite B.decode'_cons.
+ rewrite (skipn_nth_default n base 0) by omega.
+ do 2 rewrite BaseSystem.decode'_cons.
ring_simplify; ring.
}
Qed.
Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
- B.decode (add_to_nth n x us) =
- x * nth_default 0 BC.base n + B.decode us.
+ BaseSystem.decode base (add_to_nth n x us) =
+ x * nth_default 0 base n + BaseSystem.decode base us.
Proof.
unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
Qed.
- Lemma nth_default_base_positive : forall i, (i < length BC.base)%nat ->
- nth_default 0 BC.base i > 0.
+ Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
+ nth_default 0 base i > 0.
Proof.
intros.
pose proof (nth_error_length_exists_value _ _ H).
destruct H0.
pose proof (nth_error_value_In _ _ _ H0).
- pose proof (BC.base_positive _ H1).
+ pose proof (BaseSystem.base_positive _ H1).
unfold nth_default.
rewrite H0; auto.
Qed.
- Lemma base_succ_div_mult : forall i, ((S i) < length BC.base)%nat ->
- nth_default 0 BC.base (S i) = nth_default 0 BC.base i *
- (nth_default 0 BC.base (S i) / nth_default 0 BC.base i).
+ Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) = nth_default 0 base i *
+ (nth_default 0 base (S i) / nth_default 0 base i).
Proof.
intros.
apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply P.base_succ; auto.
+ apply base_succ; auto.
Qed.
- Lemma base_length_lt_pred : (pred (length BC.base) < length BC.base)%nat.
+ Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
Proof.
- pose proof EC.base_length_nonzero; omega.
+ pose proof (base_length_nonzero base); omega.
Qed.
Hint Resolve base_length_lt_pred.
- Lemma cap_positive: forall i, (i < length BC.base)%nat -> cap i > 0.
+ Lemma cap_positive: forall i, (i < length base)%nat -> cap i > 0.
Proof.
unfold cap; intros; break_if. {
- apply div_positive_gt_0; try (subst; apply P.base_tail_matches_modulus). {
+ apply div_positive_gt_0; try (subst; apply base_tail_matches_modulus). {
rewrite <- two_p_equiv.
apply two_p_gt_ZERO.
- apply P.k_nonneg.
+ apply k_nonneg.
} {
apply nth_default_base_positive; subst; auto.
}
} {
- apply div_positive_gt_0; try (apply P.base_succ; omega);
+ apply div_positive_gt_0; try (apply base_succ; omega);
try (apply nth_default_base_positive; omega).
}
Qed.
- Lemma cap_div_mod : forall us i, (i < (pred (length BC.base)))%nat ->
+ Lemma cap_div_mod : forall us i, (i < (pred (length base)))%nat ->
let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 BC.base i =
- (di / cap i) * nth_default 0 BC.base (S i).
+ (di - (di mod cap i)) * nth_default 0 base i =
+ (di / cap i) * nth_default 0 base (S i).
Proof.
intros.
rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega);
@@ -516,9 +531,9 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Qed.
Lemma carry_simple_decode_eq : forall i us,
- (length us = length BC.base) ->
- (i < (pred (length BC.base)))%nat ->
- B.decode (carry_simple i us) = B.decode us.
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
Proof.
unfold carry_simple. intros.
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
@@ -527,49 +542,52 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Qed.
Lemma two_k_div_mul_last :
- 2 ^ P.k = nth_default 0 BC.base (pred (length BC.base)) *
- (2 ^ P.k / nth_default 0 BC.base (pred (length BC.base))).
+ 2 ^ k = nth_default 0 base (pred (length base)) *
+ (2 ^ k / nth_default 0 base (pred (length base))).
Proof.
intros.
- pose proof P.base_tail_matches_modulus.
- rewrite (Z_div_mod_eq (2 ^ P.k) (nth_default 0 BC.base (pred (length BC.base)))) at 1 by
+ pose proof base_tail_matches_modulus.
+ rewrite (Z_div_mod_eq (2 ^ k) (nth_default 0 base (pred (length base)))) at 1 by
(apply nth_default_base_positive; auto); omega.
Qed.
Lemma cap_div_mod_reduce : forall us,
- let i := pred (length BC.base) in
+ let i := pred (length base) in
let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 BC.base i =
- (di / cap i) * 2 ^ P.k.
+ (di - (di mod cap i)) * nth_default 0 base i =
+ (di / cap i) * 2 ^ k.
Proof.
intros.
rewrite (Z_div_mod_eq di (cap i)) at 1 by
(apply cap_positive; auto); ring_simplify.
unfold cap; break_if; intuition.
rewrite Z.mul_comm, Z.mul_assoc.
- subst i; rewrite <- two_k_div_mul_last; auto.
+ subst i; rewrite <- two_k_div_mul_last; ring.
Qed.
Lemma carry_decode_eq_reduce : forall us,
- (length us = length BC.base) ->
- B.decode (carry_and_reduce (pred (length BC.base)) us) mod modulus
- = B.decode us mod modulus.
+ (length us = length base) ->
+ BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ = BaseSystem.decode base us mod modulus.
Proof.
unfold carry_and_reduce; intros.
- pose proof EC.base_length_nonzero.
+ pose proof (base_length_nonzero base).
rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
rewrite set_nth_sum by omega.
- rewrite Zplus_comm; rewrite <- Z.mul_assoc.
+ rewrite Zplus_comm, <- Z.mul_assoc.
rewrite <- pseudomersenne_add.
- rewrite BC.b0_1.
- rewrite (Z.mul_comm (2 ^ P.k)).
+ rewrite BaseSystem.b0_1.
+ rewrite (Z.mul_comm (2 ^ k)).
rewrite <- Zred_factor0.
- rewrite <- cap_div_mod_reduce by auto; auto.
+ rewrite <- cap_div_mod_reduce by auto.
+ do 2 rewrite Zmult_minus_distr_r.
+ f_equal.
+ ring.
Qed.
Lemma carry_length : forall i us,
- (length us <= length BC.base)%nat ->
- (length (carry i us) <= length BC.base)%nat.
+ (length us <= length base)%nat ->
+ (length (carry i us) <= length base)%nat.
Proof.
unfold carry, carry_simple, carry_and_reduce, add_to_nth.
intros; break_if; subst; repeat (rewrite length_set_nth); auto.
@@ -577,8 +595,8 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Hint Resolve carry_length.
Lemma carry_rep : forall i us x,
- (length us = length BC.base) ->
- (i < length BC.base)%nat ->
+ (length us = length base) ->
+ (i < length base)%nat ->
us ~= x -> carry i us ~= x.
Proof.
pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
@@ -591,24 +609,24 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Definition carry_sequence is us := fold_right carry us is.
Lemma carry_sequence_length: forall is us,
- (length us <= length BC.base)%nat ->
- (length (carry_sequence is us) <= length BC.base)%nat.
+ (length us <= length base)%nat ->
+ (length (carry_sequence is us) <= length base)%nat.
Proof.
induction is; boring.
Qed.
Hint Resolve carry_sequence_length.
Lemma carry_length_exact : forall i us,
- (length us = length BC.base)%nat ->
- (length (carry i us) = length BC.base)%nat.
+ (length us = length base)%nat ->
+ (length (carry i us) = length base)%nat.
Proof.
unfold carry, carry_simple, carry_and_reduce, add_to_nth.
intros; break_if; subst; repeat (rewrite length_set_nth); auto.
Qed.
Lemma carry_sequence_length_exact: forall is us,
- (length us = length BC.base)%nat ->
- (length (carry_sequence is us) = length BC.base)%nat.
+ (length us = length base)%nat ->
+ (length (carry_sequence is us) = length base)%nat.
Proof.
induction is; boring.
apply carry_length_exact; auto.
@@ -616,10 +634,10 @@ Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersen
Hint Resolve carry_sequence_length_exact.
Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length BC.base)%nat) ->
- (length us = length BC.base) ->
+ (forall i, In i is -> (i < length base)%nat) ->
+ (length us = length base) ->
us ~= x -> carry_sequence is us ~= x.
Proof.
induction is; boring.
Qed.
-End PseudoMersenneBase. \ No newline at end of file
+End PseudoMersenneBase.