aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@mit.edu>2016-03-20 14:34:51 -0400
committerGravatar Jade Philipoom <jadep@mit.edu>2016-03-20 14:47:57 -0400
commit2f178e16ab2e44b6139ef01dca17f425f02bb319 (patch)
treef792f67fc911997dc8e9be0bb26c5980af88a898
parent724b7b2acb9b857d7c511a320973cead308117c6 (diff)
refactor of Basesystem and ModularBaseSystem; includes general code organization and changes to pseudomersenne base parameters that require bases to be expressed as powers of 2, which reduces the burden of proof on the caller and allows carry functions to use bitwise operations rather than mod and division
-rw-r--r--_CoqProject6
-rw-r--r--src/BaseSystem.v470
-rw-r--r--src/BaseSystemProofs.v490
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v163
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v631
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v449
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v277
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParams.v26
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseRep.v43
9 files changed, 1480 insertions, 1075 deletions
diff --git a/_CoqProject b/_CoqProject
index d7c71f7cb..cae81fdae 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -1,5 +1,6 @@
-R src Crypto
src/BaseSystem.v
+src/BaseSystemProofs.v
src/BoundedIterOp.v
src/EdDSAProofs.v
src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
@@ -7,12 +8,17 @@ src/CompleteEdwardsCurve/DoubleAndAdd.v
src/CompleteEdwardsCurve/ExtendedCoordinates.v
src/CompleteEdwardsCurve/Pre.v
src/Encoding/EncodingTheorems.v
+src/ModularArithmetic/ExtendedBaseVector.v
src/ModularArithmetic/FField.v
src/ModularArithmetic/FNsatz.v
src/ModularArithmetic/ModularArithmeticTheorems.v
src/ModularArithmetic/ModularBaseSystem.v
+src/ModularArithmetic/ModularBaseSystemProofs.v
src/ModularArithmetic/Pre.v
src/ModularArithmetic/PrimeFieldTheorems.v
+src/ModularArithmetic/PseudoMersenneBaseParams.v
+src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
+src/ModularArithmetic/PseudoMersenneBaseRep.v
src/ModularArithmetic/Tutorial.v
src/Spec/CompleteEdwardsCurve.v
src/Spec/Ed25519.v
diff --git a/src/BaseSystem.v b/src/BaseSystem.v
index e9e4bc9d4..f3a1a0fdb 100644
--- a/src/BaseSystem.v
+++ b/src/BaseSystem.v
@@ -16,7 +16,7 @@ Class BaseVector (base : list Z):= {
}.
Section BaseSystem.
- Context (base : list Z) (base_vector : BaseVector base).
+ Context (base : list Z).
(** [BaseSystem] implements an constrained positional number system.
A wide variety of bases are supported: the base coefficients are not
required to be powers of 2, and it is NOT necessarily the case that
@@ -32,7 +32,6 @@ Section BaseSystem.
Definition accumulate p acc := fst p * snd p + acc.
Definition decode' bs u := fold_right accumulate 0 (combine u bs).
Definition decode := decode' base.
- Hint Unfold accumulate.
(* Does not carry; z becomes the lowest and only digit. *)
Definition encode (z : Z) := z :: nil.
@@ -51,50 +50,7 @@ Section BaseSystem.
Hint Extern 1 (@eq Z _ _) => ring.
- Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
- Proof.
- unfold decode, decode'; induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma decode_nil : forall bs, decode' bs nil = 0.
- auto.
- Qed.
- Hint Rewrite decode_nil.
-
- Lemma decode_base_nil : forall us, decode' nil us = 0.
- Proof.
- intros; rewrite decode'_truncate; auto.
- Qed.
- Hint Rewrite decode_base_nil.
-
Definition mul_each u := map (Z.mul u).
- Lemma mul_each_rep : forall bs u vs,
- decode' bs (mul_each u vs) = u * decode' bs vs.
- Proof.
- unfold decode'; induction bs; destruct vs; boring.
- Qed.
-
- Lemma base_eq_1cons: base = 1 :: skipn 1 base.
- Proof.
- pose proof (b0_1 0) as H.
- destruct base; compute in H; try discriminate; boring.
- Qed.
-
- Lemma decode'_cons : forall x1 x2 xs1 xs2,
- decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
- Proof.
- unfold decode'; boring.
- Qed.
- Hint Rewrite decode'_cons.
-
- Lemma decode_cons : forall x us,
- decode (x :: us) = x + decode (0 :: us).
- Proof.
- unfold decode; intros.
- rewrite base_eq_1cons.
- autorewrite with core; ring_simplify; auto.
- Qed.
-
Fixpoint sub (us vs:digits) : digits :=
match us,vs with
| u::us', v::vs' => u-v :: sub us' vs'
@@ -102,76 +58,13 @@ Section BaseSystem.
| nil, v::vs' => (0-v)::sub nil vs'
end.
- Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
- Proof.
- induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma encode_rep : forall z, decode (encode z) = z.
- Proof.
- pose proof base_eq_1cons.
- unfold decode, encode; destruct z; boring.
- Qed.
-
- Lemma mul_each_base : forall us bs c,
- decode' bs (mul_each c us) = decode' (mul_each c bs) us.
- Proof.
- induction us; destruct bs; boring.
- Qed.
-
- Hint Rewrite (@nth_default_nil Z).
- Hint Rewrite (@firstn_nil Z).
- Hint Rewrite (@skipn_nil Z).
-
- Lemma base_app : forall us low high,
- decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
- Proof.
- induction us; destruct low; boring.
- Qed.
-
- Lemma base_mul_app : forall low c us,
- decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
- c * decode' low (skipn (length low) us).
- Proof.
- intros.
- rewrite base_app; f_equal.
- rewrite <- mul_each_rep.
- rewrite mul_each_base.
- reflexivity.
- Qed.
-
Definition crosscoef i j : Z :=
let b := nth_default 0 base in
(b(i) * b(j)) / b(i+j)%nat.
Hint Unfold crosscoef.
Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end.
- Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
- induction bs; destruct n; boring.
- Qed.
- Lemma length_zeros : forall n, length (zeros n) = n.
- induction n; boring.
- Qed.
- Hint Rewrite length_zeros.
-
- Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite app_zeros_zeros.
-
- Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
- Proof.
- induction m; boring.
- Qed.
- Hint Rewrite zeros_app0.
-
- Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite rev_zeros.
-
+
(* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *)
Fixpoint mul_bi' (i:nat) (vsr:digits) :=
match vsr with
@@ -180,317 +73,7 @@ Section BaseSystem.
end.
Definition mul_bi (i:nat) (vs:digits) : digits :=
zeros i ++ rev (mul_bi' i (rev vs)).
-
- Hint Unfold nth_default.
-
- Lemma decode_single : forall n bs x,
- decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
- Proof.
- induction n; destruct bs; boring.
- Qed.
- Hint Rewrite decode_single.
-
- Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
- Proof.
- boring.
- Qed.
- Hint Rewrite zeros_rep peel_decode.
-
- Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
- Proof.
- induction xs; destruct bs; boring.
- Qed.
-
- Lemma mul_bi'_zeros : forall n m, mul_bi' n (zeros m) = zeros m.
- induction m; boring.
- Qed.
- Hint Rewrite mul_bi'_zeros.
-
- Lemma nth_error_base_nonzero : forall n x,
- nth_error base n = Some x -> x <> 0.
- Proof.
- eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
- Qed.
-
- Hint Rewrite plus_0_r.
-
- Lemma mul_bi_single : forall m n x,
- (n + m < length base)%nat ->
- decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
- Proof.
- unfold mul_bi, decode.
- destruct m; simpl; simpl_list; simpl; intros. {
- pose proof nth_error_base_nonzero as nth_nonzero.
- case_eq base; [intros; boring | intros z l base_eq].
- specialize (b0_1 0); intro b0_1'.
- rewrite base_eq in *.
- rewrite nth_default_cons in b0_1'.
- rewrite b0_1' in *.
- boring; nth_tac.
- rewrite Z_div_mul'; eauto.
- destruct x; ring.
- } {
- ssimpl_list.
- autorewrite with core.
- rewrite app_assoc.
- autorewrite with core.
- unfold crosscoef; simpl; ring_simplify.
- rewrite Nat.add_1_r.
- rewrite base_good by auto.
- rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_comm.
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_assoc.
- destruct (Z.eq_dec x 0); subst; try ring.
- rewrite Z.mul_cancel_l by auto.
- rewrite <- base_good by auto.
- ring.
- }
- Qed.
-
- Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
- induction vs; boring; f_equal; ring.
- Qed.
-
- Lemma set_higher : forall bs vs x,
- decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
- Proof.
- intros.
- rewrite set_higher'.
- rewrite add_rep.
- f_equal.
- apply decode_single.
- Qed.
-
- Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
- induction n; auto.
- simpl; f_equal; auto.
- Qed.
-
- Lemma mul_bi'_n_nil : forall n, mul_bi' n nil = nil.
- Proof.
- unfold mul_bi; auto.
- Qed.
- Hint Rewrite mul_bi'_n_nil.
-
- Lemma add_nil_l : forall us, nil .+ us = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_l.
-
- Lemma add_nil_r : forall us, us .+ nil = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_r.
-
- Lemma add_first_terms : forall us vs a b,
- (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
- auto.
- Qed.
- Hint Rewrite add_first_terms.
-
- Lemma mul_bi'_cons : forall n x us,
- mul_bi' n (x :: us) = x * crosscoef n (length us) :: mul_bi' n us.
- Proof.
- unfold mul_bi'; auto.
- Qed.
-
- Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
- length (us .+ vs) = l.
- Proof.
- induction us, vs; boring.
- erewrite (IHus vs (pred l)); boring.
- Qed.
-
- Hint Rewrite app_nil_l.
- Hint Rewrite app_nil_r.
-
- Lemma add_snoc_same_length : forall l us vs a b,
- (length us = l) -> (length vs = l) ->
- (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
- Proof.
- induction l, us, vs; boring; discriminate.
- Qed.
-
- Lemma mul_bi'_add : forall us n vs l
- (Hlus: length us = l)
- (Hlvs: length vs = l),
- mul_bi' n (rev (us .+ vs)) =
- mul_bi' n (rev us) .+ mul_bi' n (rev vs).
- Proof.
- (* TODO(adamc): please help prettify this *)
- induction us using rev_ind;
- try solve [destruct vs; boring; congruence].
- destruct vs using rev_ind; boring; clear IHvs; simpl_list.
- erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
- repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
- rewrite (IHus n vs (pred l)).
- replace (length us) with (pred l).
- replace (length vs) with (pred l).
- rewrite (add_same_length us vs (pred l)).
- f_equal; ring.
-
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- Qed.
-
- Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
- auto.
- Qed.
-
- Lemma add_leading_zeros : forall n us vs,
- (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
- Proof.
- induction n; boring.
- Qed.
- Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
- (rev us) .+ (rev vs) = rev (us .+ vs).
- Proof.
- induction us, vs; boring; try solve [subst; discriminate].
- rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
- rewrite (IHus vs (pred l)) by omega; auto.
- Qed.
- Hint Rewrite rev_add_rev.
-
- Lemma mul_bi'_length : forall us n, length (mul_bi' n us) = length us.
- Proof.
- induction us, n; boring.
- Qed.
- Hint Rewrite mul_bi'_length.
-
- Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
- Proof.
- induction us, vs; boring; f_equal; auto.
- Qed.
-
- Hint Rewrite rev_length.
-
- Lemma mul_bi_add_same_length : forall n us vs l,
- (length us = l) -> (length vs = l) ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- unfold mul_bi; boring.
- rewrite add_leading_zeros.
- erewrite mul_bi'_add; boring.
- erewrite rev_add_rev; boring.
- Qed.
-
- Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
- Proof.
- induction us; boring; f_equal; omega.
- Qed.
-
- Hint Rewrite add_zeros_same_length.
- Hint Rewrite minus_diag.
-
- Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
- us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
- Proof.
- induction us, vs; boring; f_equal; boring.
- Qed.
-
- Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
- (length (us .+ vs) <= length us)%nat.
- Proof.
- intros.
- rewrite add_trailing_zeros by trivial.
- erewrite add_same_length by (pose proof app_length; boring); omega.
- Qed.
-
- Lemma add_length_le_max : forall us vs,
- (length (us .+ vs) <= max (length us) (length vs))%nat.
- Proof.
- intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
- (apply length_add_ge; omega) .
- Qed.
-
- Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
- Proof.
- induction us; boring.
- Qed.
-
- Lemma sub_length_le_max : forall us vs,
- (length (sub us vs) <= max (length us) (length vs))%nat.
- Proof.
- induction us, vs; boring.
- rewrite sub_nil_length; auto.
- Qed.
-
- Lemma mul_bi_length : forall us n, length (mul_bi n us) = (length us + n)%nat.
- Proof.
- pose proof mul_bi'_length; unfold mul_bi.
- destruct us; repeat progress (simpl_list; boring).
- Qed.
- Hint Rewrite mul_bi_length.
-
- Lemma mul_bi_trailing_zeros : forall m n us,
- mul_bi n us ++ zeros m = mul_bi n (us ++ zeros m).
- Proof.
- unfold mul_bi.
- induction m; intros; try solve [boring].
- rewrite <- zeros_app0.
- rewrite app_assoc.
- repeat progress (boring; rewrite rev_app_distr).
- Qed.
-
- Lemma mul_bi_add_longer : forall n us vs,
- (length us >= length vs)%nat ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- boring.
- rewrite add_trailing_zeros by auto.
- rewrite (add_trailing_zeros (mul_bi n us) (mul_bi n vs))
- by (repeat (rewrite mul_bi_length); omega).
- erewrite mul_bi_add_same_length by
- (eauto; simpl_list; rewrite length_zeros; omega).
- rewrite mul_bi_trailing_zeros.
- repeat (f_equal; boring).
- Qed.
-
- Lemma mul_bi_add : forall n us vs,
- mul_bi n (us .+ vs) = (mul_bi n us) .+ (mul_bi n vs).
- Proof.
- intros; pose proof mul_bi_add_longer.
- destruct (le_ge_dec (length us) (length vs)). {
- replace (mul_bi n us .+ mul_bi n vs)
- with (mul_bi n vs .+ mul_bi n us)
- by (apply add_comm).
- replace (us .+ vs)
- with (vs .+ us)
- by (apply add_comm).
- boring.
- } {
- boring.
- }
- Qed.
-
- Lemma mul_bi_rep : forall i vs,
- (i + length vs < length base)%nat ->
- decode (mul_bi i vs) = decode vs * nth_default 0 base i.
- Proof.
- unfold decode.
- induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
- assert (i + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- rewrite set_higher.
- ring_simplify.
- rewrite <- IHvs by auto; clear IHvs.
- rewrite <- mul_bi_single by auto.
- rewrite <- add_rep.
- rewrite <- mul_bi_add.
- rewrite set_higher'.
- auto.
- Qed.
-
(* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
Fixpoint mul' (usr vs:digits) : digits :=
match usr with
@@ -500,51 +83,6 @@ Section BaseSystem.
end.
Definition mul us := mul' (rev us).
- Lemma mul'_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul' (rev us) vs) = decode us * decode vs.
- Proof.
- unfold decode.
- induction us using rev_ind; boring.
-
- assert (length us + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- ssimpl_list.
- rewrite add_rep.
- boring.
- rewrite set_higher.
- rewrite mul_each_rep.
- rewrite mul_bi_rep by auto.
- unfold decode; ring.
- Qed.
-
- Lemma mul_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul us vs) = decode us * decode vs.
- Proof.
- exact mul'_rep.
- Qed.
-
- Lemma mul'_length: forall us vs,
- (length (mul' us vs) <= length us + length vs)%nat.
- Proof.
- pose proof add_length_le_max.
- induction us; boring.
- unfold mul_each.
- simpl_list; case_max; boring; omega.
- Qed.
-
- Lemma mul_length: forall us vs,
- (length (mul us vs) <= length us + length vs)%nat.
- Proof.
- intros; unfold mul.
- rewrite mul'_length.
- rewrite rev_length; omega.
- Qed.
-
-(* Print Assumptions mul_rep. *)
-
End BaseSystem.
Section PolynomialBaseCoefs.
@@ -615,7 +153,6 @@ Section PolynomialBaseCoefs.
split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H].
Qed.
- Print BaseVector.
Instance PolyBaseVector : BaseVector poly_base := {
base_positive := poly_base_positive;
b0_1 := poly_b0_1;
@@ -633,8 +170,7 @@ Section BaseSystemExample.
Qed.
Definition base2 := poly_base 2 baseLength.
- About mul.
- Example three_times_two : @mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
+ Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
Proof.
reflexivity.
Qed.
diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v
new file mode 100644
index 000000000..84374fe8f
--- /dev/null
+++ b/src/BaseSystemProofs.v
@@ -0,0 +1,490 @@
+Require Import List.
+Require Import Util.ListUtil Util.CaseUtil Util.ZUtil.
+Require Import ZArith.ZArith ZArith.Zdiv.
+Require Import Omega NPeano Arith.
+Require Import Crypto.BaseSystem.
+Local Open Scope Z.
+
+Local Infix ".+" := add (at level 50).
+
+Local Hint Extern 1 (@eq Z _ _) => ring.
+
+Section BaseSystemProofs.
+ Context `(base_vector : BaseVector).
+
+ Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us).
+ Proof.
+ unfold decode'; intros; f_equal; apply combine_truncate_l.
+ Qed.
+
+ Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct us, vs; boring; ring.
+ Qed.
+
+ Lemma decode_nil : forall bs, decode' bs nil = 0.
+ auto.
+ Qed.
+ Hint Rewrite decode_nil.
+
+ Lemma decode_base_nil : forall us, decode' nil us = 0.
+ Proof.
+ intros; rewrite decode'_truncate; auto.
+ Qed.
+ Hint Rewrite decode_base_nil.
+
+ Lemma mul_each_rep : forall bs u vs,
+ decode' bs (mul_each u vs) = u * decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct vs; boring; ring.
+ Qed.
+
+ Lemma base_eq_1cons: base = 1 :: skipn 1 base.
+ Proof.
+ pose proof (b0_1 0) as H.
+ destruct base; compute in H; try discriminate; boring.
+ Qed.
+
+ Lemma decode'_cons : forall x1 x2 xs1 xs2,
+ decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
+ Proof.
+ unfold decode', accumulate; boring; ring.
+ Qed.
+ Hint Rewrite decode'_cons.
+
+ Lemma decode_cons : forall x us,
+ decode base (x :: us) = x + decode base (0 :: us).
+ Proof.
+ unfold decode; intros.
+ rewrite base_eq_1cons.
+ autorewrite with core; ring_simplify; auto.
+ Qed.
+
+ Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
+ Proof.
+ induction bs; destruct us; destruct vs; boring; ring.
+ Qed.
+
+ Lemma encode_rep : forall z, decode base (encode z) = z.
+ Proof.
+ pose proof base_eq_1cons.
+ unfold decode, encode; destruct z; boring.
+ Qed.
+
+ Lemma mul_each_base : forall us bs c,
+ decode' bs (mul_each c us) = decode' (mul_each c bs) us.
+ Proof.
+ induction us; destruct bs; boring; ring.
+ Qed.
+
+ Hint Rewrite (@nth_default_nil Z).
+ Hint Rewrite (@firstn_nil Z).
+ Hint Rewrite (@skipn_nil Z).
+
+ Lemma base_app : forall us low high,
+ decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
+ Proof.
+ induction us; destruct low; boring.
+ Qed.
+
+ Lemma base_mul_app : forall low c us,
+ decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
+ c * decode' low (skipn (length low) us).
+ Proof.
+ intros.
+ rewrite base_app; f_equal.
+ rewrite <- mul_each_rep.
+ rewrite mul_each_base.
+ reflexivity.
+ Qed.
+
+ Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
+ induction bs; destruct n; boring.
+ Qed.
+ Lemma length_zeros : forall n, length (zeros n) = n.
+ induction n; boring.
+ Qed.
+ Hint Rewrite length_zeros.
+
+ Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite app_zeros_zeros.
+
+ Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
+ Proof.
+ induction m; boring.
+ Qed.
+ Hint Rewrite zeros_app0.
+
+ Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite rev_zeros.
+
+ Hint Unfold nth_default.
+
+ Lemma decode_single : forall n bs x,
+ decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
+ Proof.
+ induction n; destruct bs; boring.
+ Qed.
+ Hint Rewrite decode_single.
+
+ Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
+ Proof.
+ boring.
+ Qed.
+ Hint Rewrite zeros_rep peel_decode.
+
+ Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
+ Proof.
+ induction xs; destruct bs; boring.
+ Qed.
+
+ Lemma mul_bi'_zeros : forall n m, mul_bi' base n (zeros m) = zeros m.
+ induction m; boring.
+ Qed.
+ Hint Rewrite mul_bi'_zeros.
+
+ Lemma nth_error_base_nonzero : forall n x,
+ nth_error base n = Some x -> x <> 0.
+ Proof.
+ eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
+ Qed.
+
+ Hint Rewrite plus_0_r.
+
+ Lemma mul_bi_single : forall m n x,
+ (n + m < length base)%nat ->
+ decode base (mul_bi base n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
+ Proof.
+ unfold mul_bi, decode.
+ destruct m; simpl; simpl_list; simpl; intros. {
+ pose proof nth_error_base_nonzero as nth_nonzero.
+ case_eq base; [intros; boring | intros z l base_eq].
+ specialize (b0_1 0); intro b0_1'.
+ rewrite base_eq in *.
+ rewrite nth_default_cons in b0_1'.
+ rewrite b0_1' in *.
+ unfold crosscoef.
+ autounfold; autorewrite with core.
+ unfold nth_default.
+ nth_tac.
+ rewrite Z.mul_1_r.
+ rewrite Z_div_same_full.
+ destruct x; ring.
+ eapply nth_nonzero; eauto.
+ } {
+ ssimpl_list.
+ autorewrite with core.
+ rewrite app_assoc.
+ autorewrite with core.
+ unfold crosscoef; simpl; ring_simplify.
+ rewrite Nat.add_1_r.
+ rewrite base_good by auto.
+ rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_comm.
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_assoc.
+ destruct (Z.eq_dec x 0); subst; try ring.
+ rewrite Z.mul_cancel_l by auto.
+ rewrite <- base_good by auto.
+ ring.
+ }
+ Qed.
+
+ Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
+ induction vs; boring; f_equal; ring.
+ Qed.
+
+ Lemma set_higher : forall bs vs x,
+ decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
+ Proof.
+ intros.
+ rewrite set_higher'.
+ rewrite add_rep.
+ f_equal.
+ apply decode_single.
+ Qed.
+
+ Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
+ induction n; auto.
+ simpl; f_equal; auto.
+ Qed.
+
+ Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil.
+ Proof.
+ unfold mul_bi; auto.
+ Qed.
+ Hint Rewrite mul_bi'_n_nil.
+
+ Lemma add_nil_l : forall us, nil .+ us = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_l.
+
+ Lemma add_nil_r : forall us, us .+ nil = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_r.
+
+ Lemma add_first_terms : forall us vs a b,
+ (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
+ auto.
+ Qed.
+ Hint Rewrite add_first_terms.
+
+ Lemma mul_bi'_cons : forall n x us,
+ mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us.
+ Proof.
+ unfold mul_bi'; auto.
+ Qed.
+
+ Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
+ length (us .+ vs) = l.
+ Proof.
+ induction us, vs; boring.
+ erewrite (IHus vs (pred l)); boring.
+ Qed.
+
+ Hint Rewrite app_nil_l.
+ Hint Rewrite app_nil_r.
+
+ Lemma add_snoc_same_length : forall l us vs a b,
+ (length us = l) -> (length vs = l) ->
+ (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
+ Proof.
+ induction l, us, vs; boring; discriminate.
+ Qed.
+
+ Lemma mul_bi'_add : forall us n vs l
+ (Hlus: length us = l)
+ (Hlvs: length vs = l),
+ mul_bi' base n (rev (us .+ vs)) =
+ mul_bi' base n (rev us) .+ mul_bi' base n (rev vs).
+ Proof.
+ (* TODO(adamc): please help prettify this *)
+ induction us using rev_ind;
+ try solve [destruct vs; boring; congruence].
+ destruct vs using rev_ind; boring; clear IHvs; simpl_list.
+ erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
+ repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
+ rewrite (IHus n vs (pred l)).
+ replace (length us) with (pred l).
+ replace (length vs) with (pred l).
+ rewrite (add_same_length us vs (pred l)).
+ f_equal; ring.
+
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ Qed.
+
+ Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
+ auto.
+ Qed.
+
+ Lemma add_leading_zeros : forall n us vs,
+ (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
+ Proof.
+ induction n; boring.
+ Qed.
+
+ Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
+ (rev us) .+ (rev vs) = rev (us .+ vs).
+ Proof.
+ induction us, vs; boring; try solve [subst; discriminate].
+ rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
+ rewrite (IHus vs (pred l)) by omega; auto.
+ Qed.
+ Hint Rewrite rev_add_rev.
+
+ Lemma mul_bi'_length : forall us n, length (mul_bi' base n us) = length us.
+ Proof.
+ induction us, n; boring.
+ Qed.
+ Hint Rewrite mul_bi'_length.
+
+ Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
+ Proof.
+ induction us, vs; boring; f_equal; auto.
+ Qed.
+
+ Hint Rewrite rev_length.
+
+ Lemma mul_bi_add_same_length : forall n us vs l,
+ (length us = l) -> (length vs = l) ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ unfold mul_bi; boring.
+ rewrite add_leading_zeros.
+ erewrite mul_bi'_add; boring.
+ erewrite rev_add_rev; boring.
+ Qed.
+
+ Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
+ Proof.
+ induction us; boring; f_equal; omega.
+ Qed.
+
+ Hint Rewrite add_zeros_same_length.
+ Hint Rewrite minus_diag.
+
+ Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
+ us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
+ Proof.
+ induction us, vs; boring; f_equal; boring.
+ Qed.
+
+ Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
+ (length (us .+ vs) <= length us)%nat.
+ Proof.
+ intros.
+ rewrite add_trailing_zeros by trivial.
+ erewrite add_same_length by (pose proof app_length; boring); omega.
+ Qed.
+
+ Lemma add_length_le_max : forall us vs,
+ (length (us .+ vs) <= max (length us) (length vs))%nat.
+ Proof.
+ intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
+ (apply length_add_ge; omega) .
+ Qed.
+
+ Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
+ Proof.
+ induction us; boring.
+ Qed.
+
+ Lemma sub_length_le_max : forall us vs,
+ (length (sub us vs) <= max (length us) (length vs))%nat.
+ Proof.
+ induction us, vs; boring.
+ rewrite sub_nil_length; auto.
+ Qed.
+
+ Lemma mul_bi_length : forall us n, length (mul_bi base n us) = (length us + n)%nat.
+ Proof.
+ pose proof mul_bi'_length; unfold mul_bi.
+ destruct us; repeat progress (simpl_list; boring).
+ Qed.
+ Hint Rewrite mul_bi_length.
+
+ Lemma mul_bi_trailing_zeros : forall m n us,
+ mul_bi base n us ++ zeros m = mul_bi base n (us ++ zeros m).
+ Proof.
+ unfold mul_bi.
+ induction m; intros; try solve [boring].
+ rewrite <- zeros_app0.
+ rewrite app_assoc.
+ repeat progress (boring; rewrite rev_app_distr).
+ Qed.
+
+ Lemma mul_bi_add_longer : forall n us vs,
+ (length us >= length vs)%nat ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ boring.
+ rewrite add_trailing_zeros by auto.
+ rewrite (add_trailing_zeros (mul_bi base n us) (mul_bi base n vs))
+ by (repeat (rewrite mul_bi_length); omega).
+ erewrite mul_bi_add_same_length by
+ (eauto; simpl_list; rewrite length_zeros; omega).
+ rewrite mul_bi_trailing_zeros.
+ repeat (f_equal; boring).
+ Qed.
+
+ Lemma mul_bi_add : forall n us vs,
+ mul_bi base n (us .+ vs) = (mul_bi base n us) .+ (mul_bi base n vs).
+ Proof.
+ intros; pose proof mul_bi_add_longer.
+ destruct (le_ge_dec (length us) (length vs)). {
+ rewrite add_comm.
+ rewrite (add_comm (mul_bi base n us)).
+ boring.
+ } {
+ boring.
+ }
+ Qed.
+
+ Lemma mul_bi_rep : forall i vs,
+ (i + length vs < length base)%nat ->
+ decode base (mul_bi base i vs) = decode base vs * nth_default 0 base i.
+ Proof.
+ unfold decode.
+ induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
+ assert (i + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ rewrite set_higher.
+ ring_simplify.
+ rewrite <- IHvs by auto; clear IHvs.
+ rewrite <- mul_bi_single by auto.
+ rewrite <- add_rep.
+ rewrite <- mul_bi_add.
+ rewrite set_higher'.
+ auto.
+ Qed.
+
+ (* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
+ Fixpoint mul' (usr vs:digits) : digits :=
+ match usr with
+ | u::usr' =>
+ mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs
+ | _ => nil
+ end.
+ Definition mul us := mul' (rev us).
+
+ Lemma mul'_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul' (rev us) vs) = decode base us * decode base vs.
+ Proof.
+ unfold decode.
+ induction us using rev_ind; boring.
+
+ assert (length us + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ ssimpl_list.
+ rewrite add_rep.
+ boring.
+ rewrite set_higher.
+ rewrite mul_each_rep.
+ rewrite mul_bi_rep by auto.
+ unfold decode; ring.
+ Qed.
+
+ Lemma mul_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul us vs) = decode base us * decode base vs.
+ Proof.
+ exact mul'_rep.
+ Qed.
+
+ Lemma mul'_length: forall us vs,
+ (length (mul' us vs) <= length us + length vs)%nat.
+ Proof.
+ pose proof add_length_le_max.
+ induction us; boring.
+ unfold mul_each.
+ simpl_list; case_max; boring; omega.
+ Qed.
+
+ Lemma mul_length: forall us vs,
+ (length (mul us vs) <= length us + length vs)%nat.
+ Proof.
+ intros; unfold mul.
+ rewrite mul'_length.
+ rewrite rev_length; omega.
+ Qed.
+
+End BaseSystemProofs.
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
new file mode 100644
index 000000000..746791d8d
--- /dev/null
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -0,0 +1,163 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section ExtendedBaseVector.
+ Context `{prm : PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ (* This section defines a new BaseVector that has double the length of the BaseVector
+ * used to construct [params]. The coefficients of the new vector are as follows:
+ *
+ * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
+ *
+ * The purpose of this construction is that it allows us to multiply numbers expressed
+ * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
+ * vectors of digits; the value of a digit vector is obtained by doing a dot product with
+ * the base vector.) So if x, y are digit vectors:
+ *
+ * (x \dot base) * (y \dot base) = (z \dot ext_base)
+ *
+ * Then we can separate z into its first and second halves:
+ *
+ * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
+ *
+ * Now, if we want to reduce the product modulo 2 ^ k - c:
+ *
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
+ *
+ * This sum may be short enough to express using base; if not, we can reduce again.
+ *)
+ Definition ext_base := base ++ (map (Z.mul (2^k)) base).
+
+ Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
+ Proof.
+ unfold ext_base. intros b In_b_base.
+ rewrite in_app_iff in In_b_base.
+ destruct In_b_base as [In_b_base | In_b_extbase].
+ + eapply BaseSystem.base_positive.
+ eapply In_b_base.
+ + eapply in_map_iff in In_b_extbase.
+ destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
+ subst.
+ specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
+ replace 0 with (2 ^ k * 0) by ring.
+ apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
+ rewrite Z.gt_lt_iff.
+ apply Z.pow_pos_nonneg; intuition.
+ pose proof k_nonneg; omega.
+ Qed.
+
+ Lemma base_length_nonzero : (0 < length base)%nat.
+ Proof.
+ assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
+ unfold nth_default in H.
+ case_eq (nth_error base 0); intros;
+ try (rewrite H0 in H; omega).
+ apply (nth_error_value_length _ 0 base z); auto.
+ Qed.
+
+ Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
+ Proof.
+ intros. unfold ext_base.
+ rewrite nth_default_app.
+ assert (0 < length base)%nat by (apply base_length_nonzero).
+ destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
+ Qed.
+
+ Lemma two_k_nonzero : 2^k <> 0.
+ Proof.
+ pose proof (Z.pow_eq_0 2 k k_nonneg).
+ intuition.
+ Qed.
+
+ Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
+ nth_default 0 (map (Z.mul (2 ^ k)) base) n =
+ (2 ^ k) * (nth_default 0 base n).
+ Proof.
+ intros.
+ erewrite map_nth_default; auto.
+ Qed.
+
+ Lemma base_good_over_boundary : forall
+ (i : nat)
+ (l : (i < length base)%nat)
+ (j' : nat)
+ (Hj': (i + j' < length base)%nat)
+ ,
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
+ (2 ^ k * nth_default 0 base (i + j')) *
+ (2 ^ k * nth_default 0 base (i + j'))
+ .
+ Proof.
+ intros.
+ remember (nth_default 0 base) as b.
+ rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
+ replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
+ with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
+ rewrite Z.mul_cancel_l by (exact two_k_nonzero).
+ replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
+ with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
+ subst b.
+ apply (BaseSystem.base_good i j'); omega.
+ Qed.
+
+ Lemma ext_base_good :
+ forall i j, (i+j < length ext_base)%nat ->
+ let b := nth_default 0 ext_base in
+ let r := (b i * b j) / b (i+j)%nat in
+ b i * b j = r * b (i+j)%nat.
+ Proof.
+ intros.
+ subst b. subst r.
+ unfold ext_base in *.
+ rewrite app_length in H; rewrite map_length in H.
+ repeat rewrite nth_default_app.
+ destruct (lt_dec i (length base));
+ destruct (lt_dec j (length base));
+ destruct (lt_dec (i + j) (length base));
+ try omega.
+ { (* i < length base, j < length base, i + j < length base *)
+ apply BaseSystem.base_good; auto.
+ } { (* i < length base, j < length base, i + j >= length base *)
+ rewrite (map_nth_default _ _ _ _ 0) by omega.
+ apply base_matches_modulus; omega.
+ } { (* i < length base, j >= length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (j - length base)%nat as j'.
+ replace (i + j - length base)%nat with (i + j')%nat by omega.
+ replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
+ with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ } { (* i >= length base, j < length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (i - length base)%nat as i'.
+ replace (i + j - length base)%nat with (j + i')%nat by omega.
+ replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
+ with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ }
+ Qed.
+
+ Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
+ base_positive := ext_base_positive;
+ b0_1 := b0_1;
+ base_good := ext_base_good
+ }.
+
+ Lemma extended_base_length:
+ length ext_base = (length base + length base)%nat.
+ Proof.
+ unfold ext_base; rewrite app_length; rewrite map_length; auto.
+ Qed.
+End ExtendedBaseVector.
+
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index b0e493871..d121e9d5c 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -1,643 +1,58 @@
Require Import Zpower ZArith.
Require Import List.
-Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.ListUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
-Require Import VerdiTactics.
-Require Crypto.BaseSystem.
+Require Import Crypto.BaseSystem Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector.
Local Open Scope Z_scope.
-Class PseudoMersenneBaseParams (modulus : Z) (base : list Z) (bv : BaseSystem.BaseVector base) := {
- k : Z;
- c : Z;
- modulus_pseudomersenne : modulus = 2^k - c;
- prime_modulus : Znumtheory.prime modulus;
- base_matches_modulus :
- forall i j,
- (i < length base)%nat ->
- (j < length base)%nat ->
- (i+j >= length base)%nat->
- let b := nth_default 0 base in
- let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
- b i * b j = r * (2^k * b (i+j-length base)%nat);
- base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0;
- base_tail_matches_modulus:
- 2^k mod nth_default 0 base (pred (length base)) = 0;
- k_nonneg : 0 <= k (* Probably implied by modulus_pseudomersenne. *)
-}.
-(*
-Class RepZMod (modulus : Z) := {
- T : Type;
- encode : F modulus -> T;
- decode : T -> F modulus;
-
- rep : T -> F modulus -> Prop;
- encode_rep : forall x, rep (encode x) x;
- rep_decode : forall u x, rep u x -> decode u = x;
-
- add : T -> T -> T;
- add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
-
- sub : T -> T -> T;
- sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
-
- mul : T -> T -> T;
- mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
-}.
-*)
-Print PseudoMersenneBaseParams.
-Section ExtendedBaseVector.
- Context (base : list Z) {modulus : Z} `(params : PseudoMersenneBaseParams modulus base).
- (* This section defines a new BaseVector that has double the length of the BaseVector
- * used to construct [params]. The coefficients of the new vector are as follows:
- *
- * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
- *
- * The purpose of this construction is that it allows us to multiply numbers expressed
- * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
- * vectors of digits; the value of a digit vector is obtained by doing a dot product with
- * the base vector.) So if x, y are digit vectors:
- *
- * (x \dot base) * (y \dot base) = (z \dot ext_base)
- *
- * Then we can separate z into its first and second halves:
- *
- * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
- *
- * Now, if we want to reduce the product modulo 2 ^ k - c:
- *
- * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
- * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
- *
- * This sum may be short enough to express using base; if not, we can reduce again.
- *)
- Definition ext_base := base ++ (map (Z.mul (2^k)) base).
-
- Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
- Proof.
- unfold ext_base. intros b In_b_base.
- rewrite in_app_iff in In_b_base.
- destruct In_b_base as [? | In_b_extbase]; auto using BaseSystem.base_positive.
- apply in_map_iff in In_b_extbase.
- destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
- subst.
- specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
- replace 0 with (2 ^ k * 0) by ring.
- apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
- rewrite Z.gt_lt_iff.
- apply Z.pow_pos_nonneg; intuition.
- pose proof k_nonneg; omega.
- Qed.
-
- Lemma base_length_nonzero : (0 < length base)%nat.
- Proof.
- assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
- unfold nth_default in H.
- case_eq (nth_error base 0); intros;
- try (rewrite H0 in H; omega).
- apply (nth_error_value_length _ 0 base z); auto.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
- Proof.
- intros. unfold ext_base.
- rewrite nth_default_app.
- assert (0 < length base)%nat by (apply base_length_nonzero).
- destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
- Qed.
-
- Lemma two_k_nonzero : 2^k <> 0.
- Proof.
- pose proof (Z.pow_eq_0 2 k k_nonneg).
- intuition.
- Qed.
-
- Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
- nth_default 0 (map (Z.mul (2 ^ k)) base) n =
- (2 ^ k) * (nth_default 0 base n).
- Proof.
- intros.
- erewrite map_nth_default; auto.
- Qed.
-
- Lemma ext_base_succ : forall i, ((S i) < length ext_base)%nat ->
- let b := nth_default 0 ext_base in
- b (S i) mod b i = 0.
- Proof.
- intros; subst b; unfold ext_base.
- repeat rewrite nth_default_app.
- do 2 break_if; [apply base_succ; auto | omega | | ]. {
- destruct (lt_eq_lt_dec (S i) (length base)); try omega.
- destruct s; intuition.
- rewrite map_nth_default_base_high by omega.
- replace i with (pred(length base)) by omega.
- rewrite <- Zmult_mod_idemp_l.
- rewrite base_tail_matches_modulus.
- rewrite Zmod_0_l; auto.
- } {
- unfold ext_base in *; rewrite app_length, map_length in *.
- repeat rewrite map_nth_default_base_high by omega.
- rewrite Zmult_mod_distr_l.
- rewrite <- minus_Sn_m by omega.
- rewrite base_succ by omega; ring.
- }
- Qed.
-
- Lemma base_good_over_boundary : forall
- (i : nat)
- (l : (i < length base)%nat)
- (j' : nat)
- (Hj': (i + j' < length base)%nat)
- ,
- 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
- 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
- (2 ^ k * nth_default 0 base (i + j')) *
- (2 ^ k * nth_default 0 base (i + j'))
- .
- Proof.
- intros.
- remember (nth_default 0 base) as b.
- rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
- replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
- with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
- rewrite Z.mul_cancel_l by (exact two_k_nonzero).
- replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
- with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
- subst b.
- apply (BaseSystem.base_good i j'); omega.
- Qed.
-
- Lemma ext_base_good :
- forall i j, (i+j < length ext_base)%nat ->
- let b := nth_default 0 ext_base in
- let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
- Proof.
- intros.
- subst b. subst r.
- unfold ext_base in *.
- rewrite app_length in H; rewrite map_length in H.
- repeat rewrite nth_default_app.
- destruct (lt_dec i (length base));
- destruct (lt_dec j (length base));
- destruct (lt_dec (i + j) (length base));
- try omega.
- { (* i < length base, j < length base, i + j < length base *)
- apply BaseSystem.base_good; auto.
- } { (* i < length base, j < length base, i + j >= length base *)
- rewrite (map_nth_default _ _ _ _ 0) by omega.
- apply base_matches_modulus; omega.
- } { (* i < length base, j >= length base, i + j >= length base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (j - length base)%nat as j'.
- replace (i + j - length base)%nat with (i + j')%nat by omega.
- replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
- with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- } { (* i >= length base, j < length base, i + j >= length base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (i - length base)%nat as i'.
- replace (i + j - length base)%nat with (j + i')%nat by omega.
- replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
- with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- }
- Qed.
- Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
- base_positive := ext_base_positive;
- b0_1 := b0_1;
- base_good := ext_base_good
- }.
-End ExtendedBaseVector.
-
-Print ExtBaseVector.
Section PseudoMersenneBase.
- Context `(prm :PseudoMersenneBaseParams).
-
- Definition T := BaseSystem.digits.
- Definition decode (us : T) : F modulus := ZToField (BaseSystem.decode base us).
- Local Hint Unfold decode.
- Definition rep (us : T) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
+ Context `{prm :PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
+
+ Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
- Lemma rep_decode : forall us x, us ~= x -> decode us = x.
- Proof.
- autounfold; intuition.
- Qed.
-
- Definition encode (x : F modulus) := BaseSystem.encode x.
-
- Lemma encode_rep : forall x : F modulus, encode x ~= x.
- Proof.
- intros. unfold encode, rep.
- split. {
- unfold encode; simpl.
- apply base_length_nonzero.
- assumption.
- } {
- unfold decode.
- rewrite BaseSystem.encode_rep.
- apply ZToField_FieldToZ.
- assumption.
- }
- Qed.
-
- Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
- Proof.
- autounfold; intuition. {
- unfold add.
- rewrite BaseSystem.add_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold BaseSystem.decode in *.
- rewrite BaseSystem.add_rep.
- rewrite ZToField_add.
- subst; auto.
- Qed.
-
- Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F.
- Proof.
- autounfold; intuition. {
- rewrite BaseSystem.sub_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold BaseSystem.decode in *.
- rewrite BaseSystem.sub_rep.
- rewrite ZToField_sub.
- subst; auto.
- Qed.
-
- Lemma decode_short : forall (us : T),
- (length us <= length base)%nat ->
- BaseSystem.decode base us = BaseSystem.decode (ext_base base prm) us.
- Proof.
- intros.
- unfold BaseSystem.decode, BaseSystem.decode'.
- rewrite combine_truncate_r.
- rewrite (combine_truncate_r us (ext_base base prm)).
- f_equal; f_equal.
- unfold ext_base.
- rewrite firstn_app_inleft; auto; omega.
- Qed.
-
- Lemma extended_base_length:
- length (ext_base base prm) = (length base + length base)%nat.
- Proof.
- unfold ext_base; rewrite app_length; rewrite map_length; auto.
- Qed.
-
- Lemma mul_rep_extended : forall (us vs : T),
- (length us <= length base)%nat ->
- (length vs <= length base)%nat ->
- (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base base prm) (BaseSystem.mul (ext_base base prm) us vs).
- Proof.
- intros.
- rewrite BaseSystem.mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
- f_equal; rewrite decode_short; auto.
- Qed.
+ Definition encode (x : F modulus) := encode x.
(* Converts from length of extended base to length of base by reduction modulo M.*)
- Definition reduce (us : T) : T :=
+ Definition reduce (us : digits) : digits :=
let high := skipn (length base) us in
let low := firstn (length base) us in
let wrap := map (Z.mul c) high in
BaseSystem.add low wrap.
- Lemma modulus_nonzero : modulus <> 0.
- pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
- Qed.
-
- (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
- Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
- Proof.
- intros.
- replace (2^k) with ((2^k - c) + c) by ring.
- rewrite Z.mul_add_distr_r.
- rewrite Zplus_mod.
- rewrite <- modulus_pseudomersenne.
- rewrite Z.mul_comm.
- rewrite mod_mult_plus; auto using modulus_nonzero.
- rewrite <- Zplus_mod; auto.
- Qed.
+ Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs).
- Lemma extended_shiftadd: forall (us : T),
- BaseSystem.decode (ext_base base prm) us =
- BaseSystem.decode base (firstn (length base) us)
- + (2^k * BaseSystem.decode base (skipn (length base) us)).
- Proof.
- intros.
- unfold BaseSystem.decode; rewrite <- BaseSystem.mul_each_rep.
- unfold ext_base.
- replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
- rewrite BaseSystem.base_mul_app.
- rewrite <- BaseSystem.mul_each_rep; auto.
- Qed.
-
- Lemma reduce_rep : forall us,
- BaseSystem.decode base (reduce us) mod modulus =
- BaseSystem.decode (ext_base base prm) us mod modulus.
- Proof.
- intros.
- rewrite extended_shiftadd.
- rewrite pseudomersenne_add.
- unfold reduce.
- remember (firstn (length base) us) as low.
- remember (skipn (length base) us) as high.
- unfold BaseSystem.decode.
- rewrite BaseSystem.add_rep.
- replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
- rewrite BaseSystem.mul_each_rep; auto.
- Qed.
+End PseudoMersenneBase.
- Lemma reduce_length : forall us,
- (length us <= length (ext_base base prm))%nat ->
- (length (reduce us) <= length (base))%nat.
- Proof.
- intros.
- unfold reduce.
- remember (map (Z.mul c) (skipn (length base) us)) as high.
- remember (firstn (length base) us) as low.
- assert (length low >= length high)%nat. {
- subst. rewrite firstn_length.
- rewrite map_length.
- rewrite skipn_length.
- destruct (le_dec (length base) (length us)). {
- rewrite Min.min_l by omega.
- rewrite extended_base_length in H. omega.
- } {
- rewrite Min.min_r; omega.
- }
- }
- assert ((length low <= length base)%nat)
- by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
- assert (length high <= length base)%nat
- by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
- rewrite extended_base_length in H; omega).
- rewrite BaseSystem.add_trailing_zeros; auto.
- rewrite (BaseSystem.add_same_length _ _ (length low)); auto.
- rewrite app_length.
- rewrite BaseSystem.length_zeros; intuition.
- Qed.
+Section CarryBasePow2.
+ Context `{prm :PseudoMersenneBaseParams}.
- Definition mul (us vs : T) := reduce (BaseSystem.mul (ext_base base prm) us vs).
- Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
- Proof.
- autounfold; unfold mul; intuition.
- {
- apply reduce_length.
- rewrite BaseSystem.mul_length, extended_base_length.
- omega.
- } {
- rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
- rewrite BaseSystem.mul_rep by
- (apply ExtBaseVector || rewrite extended_base_length; omega).
- subst.
- do 2 rewrite decode_short by auto.
- apply ZToField_mul.
- }
- Qed.
+ Definition log_cap i := nth_default 0 limb_widths i.
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
- Hint Unfold add_to_nth.
-
- (* i must be in the domain of base *)
- Definition cap i :=
- if eq_nat_dec i (pred (length base))
- then (2^k) / nth_default 0 base i
- else nth_default 0 base (S i) / nth_default 0 base i.
+
+ Definition pow2_mod n i := Z.land n (Z.ones i).
Definition carry_simple i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth (S i) ( (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
Definition carry_and_reduce i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth 0 (c * (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
- Definition carry i : T -> T :=
+ Definition carry i : digits -> digits :=
if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
- (* TODO: move to BaseSystemProofs *)
- Lemma decode'_splice : forall xs ys bs,
- BaseSystem.decode' bs (xs ++ ys) =
- BaseSystem.decode' (firstn (length xs) bs) xs +
- BaseSystem.decode' (skipn (length xs) bs) ys.
- Proof.
- unfold BaseSystem.decode'.
- induction xs; destruct ys, bs; boring.
- + rewrite combine_truncate_r.
- do 2 rewrite Z.add_0_r; auto.
- + unfold BaseSystem.accumulate.
- apply Z.add_assoc.
- Qed.
-
- Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
- BaseSystem.decode base (set_nth n x us) =
- (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
- Proof.
- intros.
- unfold BaseSystem.decode.
- nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
- unfold splice_nth.
- rewrite <- (firstn_skipn n us) at 4.
- do 2 rewrite decode'_splice.
- remember (length (firstn n us)) as n0.
- ring_simplify.
- remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
- rewrite (skipn_nth_default n us 0) by omega.
- rewrite firstn_length in Heqn0.
- rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite nth_default_out_of_bounds by auto.
- rewrite skipn_all by omega.
- do 2 rewrite BaseSystem.decode_base_nil.
- ring_simplify; auto.
- } {
- rewrite (skipn_nth_default n base 0) by omega.
- do 2 rewrite BaseSystem.decode'_cons.
- ring_simplify; ring.
- }
- Qed.
-
- Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
- BaseSystem.decode base (add_to_nth n x us) =
- x * nth_default 0 base n + BaseSystem.decode base us.
- Proof.
- unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
- Qed.
-
- Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
- nth_default 0 base i > 0.
- Proof.
- intros.
- pose proof (nth_error_length_exists_value _ _ H).
- destruct H0.
- pose proof (nth_error_value_In _ _ _ H0).
- pose proof (BaseSystem.base_positive _ H1).
- unfold nth_default.
- rewrite H0; auto.
- Qed.
-
- Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
- nth_default 0 base (S i) = nth_default 0 base i *
- (nth_default 0 base (S i) / nth_default 0 base i).
- Proof.
- intros.
- apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply base_succ; auto.
- Qed.
-
- Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
- Proof.
- pose proof (base_length_nonzero base); omega.
- Qed.
- Hint Resolve base_length_lt_pred.
-
- Lemma cap_positive: forall i, (i < length base)%nat -> cap i > 0.
- Proof.
- unfold cap; intros; break_if. {
- apply div_positive_gt_0; try (subst; apply base_tail_matches_modulus). {
- rewrite <- two_p_equiv.
- apply two_p_gt_ZERO.
- apply k_nonneg.
- } {
- apply nth_default_base_positive; subst; auto.
- }
- } {
- apply div_positive_gt_0; try (apply base_succ; omega);
- try (apply nth_default_base_positive; omega).
- }
- Qed.
-
- Lemma cap_div_mod : forall us i, (i < (pred (length base)))%nat ->
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 base i =
- (di / cap i) * nth_default 0 base (S i).
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega);
- ring_simplify.
- unfold cap; break_if; intuition.
- rewrite base_succ_div_mult at 4 by omega; ring.
- Qed.
-
- Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
- BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
- Proof.
- unfold carry_simple. intros.
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite <- cap_div_mod by auto; ring_simplify; auto.
- Qed.
-
- Lemma two_k_div_mul_last :
- 2 ^ k = nth_default 0 base (pred (length base)) *
- (2 ^ k / nth_default 0 base (pred (length base))).
- Proof.
- intros.
- pose proof base_tail_matches_modulus.
- rewrite (Z_div_mod_eq (2 ^ k) (nth_default 0 base (pred (length base)))) at 1 by
- (apply nth_default_base_positive; auto); omega.
- Qed.
-
- Lemma cap_div_mod_reduce : forall us,
- let i := pred (length base) in
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 base i =
- (di / cap i) * 2 ^ k.
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by
- (apply cap_positive; auto); ring_simplify.
- unfold cap; break_if; intuition.
- rewrite Z.mul_comm, Z.mul_assoc.
- subst i; rewrite <- two_k_div_mul_last; ring.
- Qed.
-
- Lemma carry_decode_eq_reduce : forall us,
- (length us = length base) ->
- BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
- = BaseSystem.decode base us mod modulus.
- Proof.
- unfold carry_and_reduce; intros.
- pose proof (base_length_nonzero base).
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite Zplus_comm, <- Z.mul_assoc.
- rewrite <- pseudomersenne_add.
- rewrite BaseSystem.b0_1.
- rewrite (Z.mul_comm (2 ^ k)).
- rewrite <- Zred_factor0.
- rewrite <- cap_div_mod_reduce by auto.
- do 2 rewrite Zmult_minus_distr_r.
- f_equal.
- ring.
- Qed.
-
- Lemma carry_length : forall i us,
- (length us <= length base)%nat ->
- (length (carry i us) <= length base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
- Hint Resolve carry_length.
-
- Lemma carry_rep : forall i us x,
- (length us = length base) ->
- (i < length base)%nat ->
- us ~= x -> carry i us ~= x.
- Proof.
- pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
- unfold rep, decode, carry in *; intros.
- intuition; break_if; subst; eauto;
- apply F_eq; simpl; intuition.
- Qed.
- Hint Resolve carry_rep.
-
Definition carry_sequence is us := fold_right carry us is.
- Lemma carry_sequence_length: forall is us,
- (length us <= length base)%nat ->
- (length (carry_sequence is us) <= length base)%nat.
- Proof.
- induction is; boring.
- Qed.
- Hint Resolve carry_sequence_length.
-
- Lemma carry_length_exact : forall i us,
- (length us = length base)%nat ->
- (length (carry i us) = length base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
-
- Lemma carry_sequence_length_exact: forall is us,
- (length us = length base)%nat ->
- (length (carry_sequence is us) = length base)%nat.
- Proof.
- induction is; boring.
- apply carry_length_exact; auto.
- Qed.
- Hint Resolve carry_sequence_length_exact.
-
- Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length base)%nat) ->
- (length us = length base) ->
- us ~= x -> carry_sequence is us ~= x.
- Proof.
- induction is; boring.
- Qed.
-End PseudoMersenneBase.
+End CarryBasePow2.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
new file mode 100644
index 000000000..524c2da27
--- /dev/null
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -0,0 +1,449 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import VerdiTactics.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneProofs.
+ Context `{prm :PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ Local Hint Unfold decode.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Local Notation "u '.+' x" := (add u x) (at level 70).
+ Local Notation "u '.*' x" := (ModularBaseSystem.mul u x) (at level 70).
+ Local Hint Unfold rep.
+
+ Lemma rep_decode : forall us x, us ~= x -> decode us = x.
+ Proof.
+ autounfold; intuition.
+ Qed.
+
+ Lemma encode_rep : forall x : F modulus, encode x ~= x.
+ Proof.
+ intros. unfold encode, rep.
+ split. {
+ unfold encode; simpl.
+ apply base_length_nonzero.
+ } {
+ unfold decode.
+ rewrite encode_rep.
+ apply ZToField_FieldToZ.
+ apply bv.
+ }
+ Qed.
+
+ Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
+ Proof.
+ autounfold; intuition. {
+ unfold add.
+ rewrite add_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold decode in *.
+ rewrite add_rep.
+ rewrite ZToField_add.
+ subst; auto.
+ Qed.
+
+ Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F.
+ Proof.
+ autounfold; intuition. {
+ rewrite sub_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold BaseSystem.decode in *.
+ rewrite sub_rep.
+ rewrite ZToField_sub.
+ subst; auto.
+ Qed.
+
+ Lemma decode_short : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode, BaseSystem.decode'.
+ rewrite combine_truncate_r.
+ rewrite (combine_truncate_r us ext_base).
+ f_equal; f_equal.
+ unfold ext_base.
+ rewrite firstn_app_inleft; auto; omega.
+ Qed.
+
+ Lemma mul_rep_extended : forall (us vs : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ (length vs <= length base)%nat ->
+ (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs).
+ Proof.
+ intros.
+ rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
+ f_equal; rewrite decode_short; auto.
+ Qed.
+
+ Lemma modulus_nonzero : modulus <> 0.
+ pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
+ Qed.
+
+ (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
+ Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
+ Proof.
+ intros.
+ replace (2^k) with ((2^k - c) + c) by ring.
+ rewrite Z.mul_add_distr_r.
+ rewrite Zplus_mod.
+ rewrite <- modulus_pseudomersenne.
+ rewrite Z.mul_comm.
+ rewrite mod_mult_plus; auto using modulus_nonzero.
+ rewrite <- Zplus_mod; auto.
+ Qed.
+
+ Lemma extended_shiftadd: forall (us : BaseSystem.digits),
+ BaseSystem.decode ext_base us =
+ BaseSystem.decode base (firstn (length base) us)
+ + (2^k * BaseSystem.decode base (skipn (length base) us)).
+ Proof.
+ intros.
+ unfold BaseSystem.decode; rewrite <- mul_each_rep.
+ unfold ext_base.
+ replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
+ rewrite base_mul_app.
+ rewrite <- mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_rep : forall us,
+ BaseSystem.decode base (reduce us) mod modulus =
+ BaseSystem.decode ext_base us mod modulus.
+ Proof.
+ intros.
+ rewrite extended_shiftadd.
+ rewrite pseudomersenne_add.
+ unfold reduce.
+ remember (firstn (length base) us) as low.
+ remember (skipn (length base) us) as high.
+ unfold BaseSystem.decode.
+ rewrite BaseSystemProofs.add_rep.
+ replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
+ rewrite mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_length : forall us,
+ (length us <= length ext_base)%nat ->
+ (length (reduce us) <= length base)%nat.
+ Proof.
+ intros.
+ unfold reduce.
+ remember (map (Z.mul c) (skipn (length base) us)) as high.
+ remember (firstn (length base) us) as low.
+ assert (length low >= length high)%nat. {
+ subst. rewrite firstn_length.
+ rewrite map_length.
+ rewrite skipn_length.
+ destruct (le_dec (length base) (length us)). {
+ rewrite Min.min_l by omega.
+ rewrite extended_base_length in H. omega.
+ } {
+ rewrite Min.min_r; omega.
+ }
+ }
+ assert ((length low <= length base)%nat)
+ by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
+ assert (length high <= length base)%nat
+ by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
+ rewrite extended_base_length in H; omega).
+ rewrite add_trailing_zeros; auto.
+ rewrite (add_same_length _ _ (length low)); auto.
+ rewrite app_length.
+ rewrite length_zeros; intuition.
+ Qed.
+
+ Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F.
+ Proof.
+ autounfold; unfold ModularBaseSystem.mul; intuition.
+ {
+ apply reduce_length.
+ rewrite mul_length, extended_base_length.
+ omega.
+ } {
+ rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
+ rewrite mul_rep by
+ (apply ExtBaseVector || rewrite extended_base_length; omega).
+ subst.
+ do 2 rewrite decode_short by auto.
+ apply ZToField_mul.
+ }
+ Qed.
+
+ (* TODO: move to BaseSystemProofs *)
+ Lemma decode'_splice : forall xs ys bs,
+ BaseSystem.decode' bs (xs ++ ys) =
+ BaseSystem.decode' (firstn (length xs) bs) xs +
+ BaseSystem.decode' (skipn (length xs) bs) ys.
+ Proof.
+ unfold BaseSystem.decode'.
+ induction xs; destruct ys, bs; boring.
+ + rewrite combine_truncate_r.
+ do 2 rewrite Z.add_0_r; auto.
+ + unfold BaseSystem.accumulate.
+ apply Z.add_assoc.
+ Qed.
+
+ Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (set_nth n x us) =
+ (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode.
+ nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
+ unfold splice_nth.
+ rewrite <- (firstn_skipn n us) at 4.
+ do 2 rewrite decode'_splice.
+ remember (length (firstn n us)) as n0.
+ ring_simplify.
+ remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
+ rewrite (skipn_nth_default n us 0) by omega.
+ rewrite firstn_length in Heqn0.
+ rewrite Min.min_l in Heqn0 by omega; subst n0.
+ destruct (le_lt_dec (length base) n). {
+ rewrite nth_default_out_of_bounds by auto.
+ rewrite skipn_all by omega.
+ do 2 rewrite decode_base_nil.
+ ring_simplify; auto.
+ } {
+ rewrite (skipn_nth_default n base 0) by omega.
+ do 2 rewrite decode'_cons.
+ ring_simplify; ring.
+ }
+ Qed.
+
+ Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (add_to_nth n x us) =
+ x * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
+ Qed.
+
+ Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
+ nth_default 0 base i > 0.
+ Proof.
+ intros.
+ pose proof (nth_error_length_exists_value _ _ H).
+ destruct H0.
+ pose proof (nth_error_value_In _ _ _ H0).
+ pose proof (BaseSystem.base_positive _ H1).
+ unfold nth_default.
+ rewrite H0; auto.
+ Qed.
+
+ Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) = nth_default 0 base i *
+ (nth_default 0 base (S i) / nth_default 0 base i).
+ Proof.
+ intros.
+ apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
+ apply base_succ; auto.
+ Qed.
+
+End PseudoMersenneProofs.
+
+Section CarryProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+ Existing Instance bv.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Hint Unfold log_cap.
+
+ Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
+ Proof.
+ pose proof base_length_nonzero; omega.
+ Qed.
+ Hint Resolve base_length_lt_pred.
+
+ Lemma log_cap_nonneg : forall i, 0 <= log_cap i.
+ Proof.
+ unfold log_cap, nth_default; intros.
+ case_eq (nth_error limb_widths i); intros; try omega.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ (* TODO : move to ZUtil *)
+ Lemma div_pow2succ : forall n x, (0 <= x) ->
+ n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x).
+ Proof.
+ intros.
+ rewrite Z.pow_succ_r, Z.mul_comm by auto.
+ rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega).
+ rewrite Zdiv2_div.
+ reflexivity.
+ Qed.
+
+ (* TODO: move to ZUtil *)
+ Lemma shiftr_succ : forall n x,
+ Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1.
+ Proof.
+ intros.
+ rewrite Z.shiftr_shiftr by omega.
+ reflexivity.
+ Qed.
+
+ (* TODO : move to ZUtil *)
+ Lemma shiftr_div : forall n i, (0 <= i) -> Z.shiftr n i = n / (2 ^ i).
+ Proof.
+ intro.
+ apply natlike_ind; intros; [boring|].
+ rewrite div_pow2succ by auto.
+ rewrite shiftr_succ.
+ unfold Z.shiftr.
+ simpl; f_equal.
+ auto.
+ Qed.
+
+ (* TODO : move to ListUtil *)
+ Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat ->
+ nth_error l i = Some (nth_default x l i).
+ Proof.
+ intros ? ? ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [k nth_err_k].
+ unfold nth_default.
+ rewrite nth_err_k.
+ reflexivity.
+ Qed.
+
+ Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ destruct (NPeano.Nat.eq_dec i 0).
+ + subst; f_equal.
+ unfold sum_firstn, log_cap.
+ destruct limb_widths; auto.
+ + erewrite sum_firstn_succ; eauto.
+ unfold log_cap.
+ apply nth_error_Some_nth_default.
+ rewrite <- base_length; omega.
+ Qed.
+
+ Lemma carry_simple_decode_eq : forall i us,
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
+ Proof.
+ unfold carry_simple. intros.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite shiftr_div by apply log_cap_nonneg.
+ rewrite nth_default_base_succ by omega.
+ rewrite Z.mul_assoc.
+ rewrite (Z.mul_comm _ (2 ^ log_cap i)).
+ rewrite mul_div_eq; try ring.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg.
+ Qed.
+
+ Lemma carry_decode_eq_reduce : forall us,
+ (length us = length base) ->
+ BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ = BaseSystem.decode base us mod modulus.
+ Proof.
+ unfold carry_and_reduce; intros ? length_eq.
+ pose proof base_length_nonzero.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1.
+ rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0.
+ f_equal.
+ rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r.
+ f_equal.
+ destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero].
+ + apply length0_nil in length_zero.
+ pose proof (base_length) as limbs_length.
+ rewrite length_zero in length_eq, limbs_length.
+ apply length0_nil in length_eq.
+ symmetry in limbs_length.
+ apply length0_nil in limbs_length.
+ unfold log_cap.
+ subst; rewrite length_zero, limbs_length, nth_default_nil.
+ reflexivity.
+ + rewrite nth_default_base by omega.
+ rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg).
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ rewrite Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite Z.mul_assoc.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- k_matches_limb_widths.
+ replace (length limb_widths) with (S (pred (length base))) by
+ (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega).
+ rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by
+ (unfold log_cap; apply nth_error_Some_nth_default; rewrite <- base_length; omega).
+ rewrite <- Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite (Z.add_comm (log_cap (pred (length base)))).
+ ring.
+ Qed.
+
+ Lemma carry_length : forall i us,
+ (length us <= length base)%nat ->
+ (length (carry i us) <= length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+ Hint Resolve carry_length.
+
+ Lemma carry_rep : forall i us x,
+ (length us = length base) ->
+ (i < length base)%nat ->
+ us ~= x -> carry i us ~= x.
+ Proof.
+ pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
+ unfold rep, decode, carry in *; intros.
+ intuition; break_if; subst; eauto;
+ apply F_eq; simpl; intuition.
+ Qed.
+ Hint Resolve carry_rep.
+
+ Lemma carry_sequence_length: forall is us,
+ (length us <= length base)%nat ->
+ (length (carry_sequence is us) <= length base)%nat.
+ Proof.
+ induction is; boring.
+ Qed.
+ Hint Resolve carry_sequence_length.
+
+ Lemma carry_length_exact : forall i us,
+ (length us = length base)%nat ->
+ (length (carry i us) = length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+
+ Lemma carry_sequence_length_exact: forall is us,
+ (length us = length base)%nat ->
+ (length (carry_sequence is us) = length base)%nat.
+ Proof.
+ induction is; boring.
+ apply carry_length_exact; auto.
+ Qed.
+ Hint Resolve carry_sequence_length_exact.
+
+ Lemma carry_sequence_rep : forall is us x,
+ (forall i, In i is -> (i < length base)%nat) ->
+ (length us = length base) ->
+ us ~= x -> carry_sequence is us ~= x.
+ Proof.
+ induction is; boring.
+ Qed.
+
+End CarryProofs.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
new file mode 100644
index 000000000..847b8e85f
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -0,0 +1,277 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneBaseParamProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+
+ Fixpoint base_from_limb_widths limb_widths :=
+ match limb_widths with
+ | nil => nil
+ | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw)
+ end.
+
+ Definition base := base_from_limb_widths limb_widths.
+
+ Lemma base_length : length base = length limb_widths.
+ Proof.
+ unfold base.
+ induction limb_widths; try reflexivity.
+ simpl; rewrite map_length; auto.
+ Qed.
+
+ Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b ->
+ a = b.
+ Proof.
+ intros; simpl in *.
+ unfold value in *.
+ congruence.
+ Qed.
+
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ nth_error base i = Some b ->
+ nth_error limb_widths i = Some w ->
+ nth_error base (S i) = Some (two_p w * b).
+ Proof.
+ unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
+ unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
+ [rewrite (@nil_length0 Z) in *; omega | ].
+ simpl in *; rewrite map_length in *.
+ case_eq i; intros; subst.
+ + subst; apply nth_error_first in nth_err_w.
+ apply nth_error_first in nth_err_b; subst.
+ apply map_nth_error.
+ case_eq l; intros; subst; [simpl in *; omega | ].
+ unfold base_from_limb_widths; fold base_from_limb_widths.
+ reflexivity.
+ + simpl in nth_err_w.
+ apply nth_error_map in nth_err_w.
+ destruct nth_err_w as [x [A B]].
+ subst.
+ replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring.
+ apply map_nth_error.
+ apply IHl; auto; omega.
+ Qed.
+
+ Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x),
+ exists l', l = x :: l'.
+ Proof.
+ induction l; try discriminate; eexists.
+ apply nth_error_first in H.
+ subst; eauto.
+ Qed.
+
+ Lemma sum_firstn_succ : forall l i x,
+ nth_error l i = Some x ->
+ sum_firstn l (S i) = x + sum_firstn l i.
+ Proof.
+ unfold sum_firstn; induction l;
+ [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ].
+ intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0).
+ + subst; simpl in *; unfold value in *.
+ congruence.
+ + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto.
+ rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto.
+ simpl. simpl in nth_err_x.
+ specialize (IHl (pred i) x).
+ rewrite NPeano.Nat.succ_pred in IHl by auto.
+ destruct (NPeano.Nat.eq_dec (pred i) 0).
+ - replace i with 1%nat in * by omega.
+ simpl. replace (pred 1) with 0%nat in * by auto.
+ apply nth_error_exists_first in nth_err_x.
+ destruct nth_err_x as [l' ?].
+ subst; simpl; ring.
+ - rewrite IHl by auto; ring.
+ Qed.
+
+ (* TODO : move to LsitUtil *)
+ Lemma fold_right_invariant : forall {A} P (f: A -> A -> A) l x,
+ P x -> (forall y, In y l -> forall z, P z -> P (f y z)) ->
+ P (fold_right f x l).
+ Proof.
+ induction l; intros ? ? step; auto.
+ simpl.
+ apply step; try apply in_eq.
+ apply IHl; auto.
+ intros y in_y_l.
+ apply (in_cons a) in in_y_l.
+ auto.
+ Qed.
+
+ (* TODO : move to ListUtil *)
+ Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l.
+ Proof.
+ induction n; destruct l; boring.
+ Qed.
+
+ Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n.
+ Proof.
+ unfold sum_firstn; intros.
+ apply fold_right_invariant; try omega.
+ intros y In_y_lw ? ?.
+ apply Z.add_nonneg_nonneg; try assumption.
+ apply limb_widths_nonneg.
+ eapply In_firstn; eauto.
+ Qed.
+
+ Lemma k_nonneg : 0 <= k.
+ Proof.
+ rewrite <- k_matches_limb_widths.
+ apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma nth_error_base : forall i, (i < length base)%nat ->
+ nth_error base i = Some (two_p (sum_firstn limb_widths i)).
+ Proof.
+ induction i; intros.
+ + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
+ intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
+ +
+ assert (i < length base)%nat as lt_i_length by omega.
+ specialize (IHi lt_i_length).
+ rewrite base_length in lt_i_length.
+ destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
+ erewrite base_from_limb_widths_step; eauto.
+ f_equal.
+ simpl.
+ destruct (NPeano.Nat.eq_dec i 0).
+ - subst; unfold sum_firstn; simpl.
+ apply nth_error_exists_first in nth_err_w.
+ destruct nth_err_w as [l' lw_destruct]; subst.
+ rewrite lw_destruct.
+ ring_simplify.
+ f_equal; simpl; ring.
+ - erewrite sum_firstn_succ; eauto.
+ symmetry.
+ apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ nth_default d base i = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
+ unfold nth_default.
+ rewrite nth_err_x.
+ rewrite nth_error_base in nth_err_x by assumption.
+ rewrite two_p_correct in nth_err_x.
+ congruence.
+ Qed.
+
+
+ (* TODO : move to ZUtil *)
+ Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0.
+ Proof.
+ intros.
+ replace b with (b - c + c) by ring.
+ rewrite Z.pow_add_r by omega.
+ apply Z_mod_mult.
+ Qed.
+
+ Lemma base_matches_modulus: forall i j,
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
+ let b := nth_default 0 base in
+ let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
+ b i * b j = r * (2^k * b (i+j-length base)%nat).
+ Proof.
+ intros.
+ rewrite (Z.mul_comm r).
+ subst r.
+ assert (i + j - length base < length base)%nat by omega.
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos;
+ [ | subst b; rewrite nth_default_base; try assumption ];
+ apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
+ f_equal.
+ subst b.
+ repeat rewrite nth_default_base by assumption.
+ do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg).
+ symmetry.
+ apply mod_same_pow.
+ split.
+ + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg.
+ + rewrite base_length in *; apply limb_widths_match_modulus; assumption.
+ Qed.
+
+ Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) mod nth_default 0 base i = 0.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ apply mod_same_pow.
+ split; [apply sum_firstn_limb_widths_nonneg | ].
+ destruct (NPeano.Nat.eq_dec i 0); subst.
+ + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq.
+ apply Z.add_nonneg_nonneg; try omega.
+ apply limb_widths_nonneg.
+ rewrite lw_eq.
+ apply in_eq.
+ + assert (i < length base)%nat as i_lt_length by omega.
+ rewrite base_length in *.
+ apply nth_error_length_exists_value in i_lt_length.
+ destruct i_lt_length as [x nth_err_x].
+ erewrite sum_firstn_succ; eauto.
+ apply nth_error_value_In in nth_err_x.
+ apply limb_widths_nonneg in nth_err_x.
+ omega.
+ Qed.
+
+ Lemma nth_error_subst : forall i b, nth_error base i = Some b ->
+ b = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros i b nth_err_b.
+ pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite nth_error_base in nth_err_b by assumption.
+ rewrite two_p_correct in nth_err_b.
+ congruence.
+ Qed.
+
+ Lemma base_positive : forall b : Z, In b base -> b > 0.
+ Proof.
+ intros b In_b_base.
+ apply In_nth_error_value in In_b_base.
+ destruct In_b_base as [i nth_err_b].
+ apply nth_error_subst in nth_err_b.
+ rewrite nth_err_b.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma b0_1 : forall x : Z, nth_default x base 0 = 1.
+ Proof.
+ unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity].
+ Qed.
+
+ Lemma base_good : forall i j : nat,
+ (i + j < length base)%nat ->
+ let b := nth_default 0 base in
+ let r := b i * b j / b (i + j)%nat in
+ b i * b j = r * b (i + j)%nat.
+ Proof.
+ intros; subst b r.
+ repeat rewrite nth_default_base by omega.
+ rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg.
+ rewrite mod_same_pow; try ring.
+ split; [ apply sum_firstn_limb_widths_nonneg | ].
+ apply limb_widths_good.
+ rewrite <- base_length; assumption.
+ Qed.
+
+ Instance bv : BaseSystem.BaseVector base := {
+ base_positive := base_positive;
+ b0_1 := b0_1;
+ base_good := base_good
+ }.
+
+End PseudoMersenneBaseParamProofs.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v
new file mode 100644
index 000000000..122cac0ab
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v
@@ -0,0 +1,26 @@
+Require Import ZArith.
+Require Import List.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Definition sum_firstn l n := fold_right Z.add 0 (firstn n l).
+
+Class PseudoMersenneBaseParams (modulus : Z) := {
+ limb_widths : list Z;
+ limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w;
+ limb_widths_nonnil : limb_widths <> nil;
+ limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
+ sum_firstn limb_widths (i + j) <=
+ sum_firstn limb_widths i + sum_firstn limb_widths j;
+ k : Z;
+ c : Z;
+ modulus_pseudomersenne : modulus = 2^k - c;
+ prime_modulus : Znumtheory.prime modulus;
+ limb_widths_match_modulus : forall i j,
+ (i < length limb_widths)%nat ->
+ (j < length limb_widths)%nat ->
+ (i + j >= length limb_widths)%nat ->
+ let w_sum := sum_firstn limb_widths in
+ k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j;
+ k_matches_limb_widths : sum_firstn limb_widths (length limb_widths) = k
+}.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v
new file mode 100644
index 000000000..2cc12b933
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v
@@ -0,0 +1,43 @@
+Require Import ZArith.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Local Open Scope Z_scope.
+
+Class RepZMod (modulus : Z) := {
+ T : Type;
+ encode : F modulus -> T;
+ decode : T -> F modulus;
+
+ rep : T -> F modulus -> Prop;
+ encode_rep : forall x, rep (encode x) x;
+ rep_decode : forall u x, rep u x -> decode u = x;
+
+ add : T -> T -> T;
+ add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
+
+ sub : T -> T -> T;
+ sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
+
+ mul : T -> T -> T;
+ mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
+}.
+
+Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) : RepZMod m := {
+ T := list Z;
+ encode := ModularBaseSystem.encode;
+ decode := ModularBaseSystem.decode;
+
+ rep := ModularBaseSystem.rep;
+ encode_rep := ModularBaseSystemProofs.encode_rep;
+ rep_decode := ModularBaseSystemProofs.rep_decode;
+
+ add := BaseSystem.add;
+ add_rep := ModularBaseSystemProofs.add_rep;
+
+ sub := BaseSystem.sub;
+ sub_rep := ModularBaseSystemProofs.sub_rep;
+
+ mul := ModularBaseSystem.mul;
+ mul_rep := ModularBaseSystemProofs.mul_rep
+}.