aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--_CoqProject6
-rw-r--r--src/BaseSystem.v470
-rw-r--r--src/BaseSystemProofs.v490
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v163
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v631
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v449
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v277
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParams.v26
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseRep.v43
9 files changed, 1480 insertions, 1075 deletions
diff --git a/_CoqProject b/_CoqProject
index d7c71f7cb..cae81fdae 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -1,5 +1,6 @@
-R src Crypto
src/BaseSystem.v
+src/BaseSystemProofs.v
src/BoundedIterOp.v
src/EdDSAProofs.v
src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
@@ -7,12 +8,17 @@ src/CompleteEdwardsCurve/DoubleAndAdd.v
src/CompleteEdwardsCurve/ExtendedCoordinates.v
src/CompleteEdwardsCurve/Pre.v
src/Encoding/EncodingTheorems.v
+src/ModularArithmetic/ExtendedBaseVector.v
src/ModularArithmetic/FField.v
src/ModularArithmetic/FNsatz.v
src/ModularArithmetic/ModularArithmeticTheorems.v
src/ModularArithmetic/ModularBaseSystem.v
+src/ModularArithmetic/ModularBaseSystemProofs.v
src/ModularArithmetic/Pre.v
src/ModularArithmetic/PrimeFieldTheorems.v
+src/ModularArithmetic/PseudoMersenneBaseParams.v
+src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
+src/ModularArithmetic/PseudoMersenneBaseRep.v
src/ModularArithmetic/Tutorial.v
src/Spec/CompleteEdwardsCurve.v
src/Spec/Ed25519.v
diff --git a/src/BaseSystem.v b/src/BaseSystem.v
index e9e4bc9d4..f3a1a0fdb 100644
--- a/src/BaseSystem.v
+++ b/src/BaseSystem.v
@@ -16,7 +16,7 @@ Class BaseVector (base : list Z):= {
}.
Section BaseSystem.
- Context (base : list Z) (base_vector : BaseVector base).
+ Context (base : list Z).
(** [BaseSystem] implements an constrained positional number system.
A wide variety of bases are supported: the base coefficients are not
required to be powers of 2, and it is NOT necessarily the case that
@@ -32,7 +32,6 @@ Section BaseSystem.
Definition accumulate p acc := fst p * snd p + acc.
Definition decode' bs u := fold_right accumulate 0 (combine u bs).
Definition decode := decode' base.
- Hint Unfold accumulate.
(* Does not carry; z becomes the lowest and only digit. *)
Definition encode (z : Z) := z :: nil.
@@ -51,50 +50,7 @@ Section BaseSystem.
Hint Extern 1 (@eq Z _ _) => ring.
- Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
- Proof.
- unfold decode, decode'; induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma decode_nil : forall bs, decode' bs nil = 0.
- auto.
- Qed.
- Hint Rewrite decode_nil.
-
- Lemma decode_base_nil : forall us, decode' nil us = 0.
- Proof.
- intros; rewrite decode'_truncate; auto.
- Qed.
- Hint Rewrite decode_base_nil.
-
Definition mul_each u := map (Z.mul u).
- Lemma mul_each_rep : forall bs u vs,
- decode' bs (mul_each u vs) = u * decode' bs vs.
- Proof.
- unfold decode'; induction bs; destruct vs; boring.
- Qed.
-
- Lemma base_eq_1cons: base = 1 :: skipn 1 base.
- Proof.
- pose proof (b0_1 0) as H.
- destruct base; compute in H; try discriminate; boring.
- Qed.
-
- Lemma decode'_cons : forall x1 x2 xs1 xs2,
- decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
- Proof.
- unfold decode'; boring.
- Qed.
- Hint Rewrite decode'_cons.
-
- Lemma decode_cons : forall x us,
- decode (x :: us) = x + decode (0 :: us).
- Proof.
- unfold decode; intros.
- rewrite base_eq_1cons.
- autorewrite with core; ring_simplify; auto.
- Qed.
-
Fixpoint sub (us vs:digits) : digits :=
match us,vs with
| u::us', v::vs' => u-v :: sub us' vs'
@@ -102,76 +58,13 @@ Section BaseSystem.
| nil, v::vs' => (0-v)::sub nil vs'
end.
- Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
- Proof.
- induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma encode_rep : forall z, decode (encode z) = z.
- Proof.
- pose proof base_eq_1cons.
- unfold decode, encode; destruct z; boring.
- Qed.
-
- Lemma mul_each_base : forall us bs c,
- decode' bs (mul_each c us) = decode' (mul_each c bs) us.
- Proof.
- induction us; destruct bs; boring.
- Qed.
-
- Hint Rewrite (@nth_default_nil Z).
- Hint Rewrite (@firstn_nil Z).
- Hint Rewrite (@skipn_nil Z).
-
- Lemma base_app : forall us low high,
- decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
- Proof.
- induction us; destruct low; boring.
- Qed.
-
- Lemma base_mul_app : forall low c us,
- decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
- c * decode' low (skipn (length low) us).
- Proof.
- intros.
- rewrite base_app; f_equal.
- rewrite <- mul_each_rep.
- rewrite mul_each_base.
- reflexivity.
- Qed.
-
Definition crosscoef i j : Z :=
let b := nth_default 0 base in
(b(i) * b(j)) / b(i+j)%nat.
Hint Unfold crosscoef.
Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end.
- Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
- induction bs; destruct n; boring.
- Qed.
- Lemma length_zeros : forall n, length (zeros n) = n.
- induction n; boring.
- Qed.
- Hint Rewrite length_zeros.
-
- Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite app_zeros_zeros.
-
- Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
- Proof.
- induction m; boring.
- Qed.
- Hint Rewrite zeros_app0.
-
- Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite rev_zeros.
-
+
(* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *)
Fixpoint mul_bi' (i:nat) (vsr:digits) :=
match vsr with
@@ -180,317 +73,7 @@ Section BaseSystem.
end.
Definition mul_bi (i:nat) (vs:digits) : digits :=
zeros i ++ rev (mul_bi' i (rev vs)).
-
- Hint Unfold nth_default.
-
- Lemma decode_single : forall n bs x,
- decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
- Proof.
- induction n; destruct bs; boring.
- Qed.
- Hint Rewrite decode_single.
-
- Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
- Proof.
- boring.
- Qed.
- Hint Rewrite zeros_rep peel_decode.
-
- Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
- Proof.
- induction xs; destruct bs; boring.
- Qed.
-
- Lemma mul_bi'_zeros : forall n m, mul_bi' n (zeros m) = zeros m.
- induction m; boring.
- Qed.
- Hint Rewrite mul_bi'_zeros.
-
- Lemma nth_error_base_nonzero : forall n x,
- nth_error base n = Some x -> x <> 0.
- Proof.
- eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
- Qed.
-
- Hint Rewrite plus_0_r.
-
- Lemma mul_bi_single : forall m n x,
- (n + m < length base)%nat ->
- decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
- Proof.
- unfold mul_bi, decode.
- destruct m; simpl; simpl_list; simpl; intros. {
- pose proof nth_error_base_nonzero as nth_nonzero.
- case_eq base; [intros; boring | intros z l base_eq].
- specialize (b0_1 0); intro b0_1'.
- rewrite base_eq in *.
- rewrite nth_default_cons in b0_1'.
- rewrite b0_1' in *.
- boring; nth_tac.
- rewrite Z_div_mul'; eauto.
- destruct x; ring.
- } {
- ssimpl_list.
- autorewrite with core.
- rewrite app_assoc.
- autorewrite with core.
- unfold crosscoef; simpl; ring_simplify.
- rewrite Nat.add_1_r.
- rewrite base_good by auto.
- rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_comm.
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_assoc.
- destruct (Z.eq_dec x 0); subst; try ring.
- rewrite Z.mul_cancel_l by auto.
- rewrite <- base_good by auto.
- ring.
- }
- Qed.
-
- Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
- induction vs; boring; f_equal; ring.
- Qed.
-
- Lemma set_higher : forall bs vs x,
- decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
- Proof.
- intros.
- rewrite set_higher'.
- rewrite add_rep.
- f_equal.
- apply decode_single.
- Qed.
-
- Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
- induction n; auto.
- simpl; f_equal; auto.
- Qed.
-
- Lemma mul_bi'_n_nil : forall n, mul_bi' n nil = nil.
- Proof.
- unfold mul_bi; auto.
- Qed.
- Hint Rewrite mul_bi'_n_nil.
-
- Lemma add_nil_l : forall us, nil .+ us = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_l.
-
- Lemma add_nil_r : forall us, us .+ nil = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_r.
-
- Lemma add_first_terms : forall us vs a b,
- (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
- auto.
- Qed.
- Hint Rewrite add_first_terms.
-
- Lemma mul_bi'_cons : forall n x us,
- mul_bi' n (x :: us) = x * crosscoef n (length us) :: mul_bi' n us.
- Proof.
- unfold mul_bi'; auto.
- Qed.
-
- Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
- length (us .+ vs) = l.
- Proof.
- induction us, vs; boring.
- erewrite (IHus vs (pred l)); boring.
- Qed.
-
- Hint Rewrite app_nil_l.
- Hint Rewrite app_nil_r.
-
- Lemma add_snoc_same_length : forall l us vs a b,
- (length us = l) -> (length vs = l) ->
- (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
- Proof.
- induction l, us, vs; boring; discriminate.
- Qed.
-
- Lemma mul_bi'_add : forall us n vs l
- (Hlus: length us = l)
- (Hlvs: length vs = l),
- mul_bi' n (rev (us .+ vs)) =
- mul_bi' n (rev us) .+ mul_bi' n (rev vs).
- Proof.
- (* TODO(adamc): please help prettify this *)
- induction us using rev_ind;
- try solve [destruct vs; boring; congruence].
- destruct vs using rev_ind; boring; clear IHvs; simpl_list.
- erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
- repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
- rewrite (IHus n vs (pred l)).
- replace (length us) with (pred l).
- replace (length vs) with (pred l).
- rewrite (add_same_length us vs (pred l)).
- f_equal; ring.
-
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- Qed.
-
- Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
- auto.
- Qed.
-
- Lemma add_leading_zeros : forall n us vs,
- (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
- Proof.
- induction n; boring.
- Qed.
- Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
- (rev us) .+ (rev vs) = rev (us .+ vs).
- Proof.
- induction us, vs; boring; try solve [subst; discriminate].
- rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
- rewrite (IHus vs (pred l)) by omega; auto.
- Qed.
- Hint Rewrite rev_add_rev.
-
- Lemma mul_bi'_length : forall us n, length (mul_bi' n us) = length us.
- Proof.
- induction us, n; boring.
- Qed.
- Hint Rewrite mul_bi'_length.
-
- Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
- Proof.
- induction us, vs; boring; f_equal; auto.
- Qed.
-
- Hint Rewrite rev_length.
-
- Lemma mul_bi_add_same_length : forall n us vs l,
- (length us = l) -> (length vs = l) ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- unfold mul_bi; boring.
- rewrite add_leading_zeros.
- erewrite mul_bi'_add; boring.
- erewrite rev_add_rev; boring.
- Qed.
-
- Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
- Proof.
- induction us; boring; f_equal; omega.
- Qed.
-
- Hint Rewrite add_zeros_same_length.
- Hint Rewrite minus_diag.
-
- Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
- us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
- Proof.
- induction us, vs; boring; f_equal; boring.
- Qed.
-
- Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
- (length (us .+ vs) <= length us)%nat.
- Proof.
- intros.
- rewrite add_trailing_zeros by trivial.
- erewrite add_same_length by (pose proof app_length; boring); omega.
- Qed.
-
- Lemma add_length_le_max : forall us vs,
- (length (us .+ vs) <= max (length us) (length vs))%nat.
- Proof.
- intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
- (apply length_add_ge; omega) .
- Qed.
-
- Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
- Proof.
- induction us; boring.
- Qed.
-
- Lemma sub_length_le_max : forall us vs,
- (length (sub us vs) <= max (length us) (length vs))%nat.
- Proof.
- induction us, vs; boring.
- rewrite sub_nil_length; auto.
- Qed.
-
- Lemma mul_bi_length : forall us n, length (mul_bi n us) = (length us + n)%nat.
- Proof.
- pose proof mul_bi'_length; unfold mul_bi.
- destruct us; repeat progress (simpl_list; boring).
- Qed.
- Hint Rewrite mul_bi_length.
-
- Lemma mul_bi_trailing_zeros : forall m n us,
- mul_bi n us ++ zeros m = mul_bi n (us ++ zeros m).
- Proof.
- unfold mul_bi.
- induction m; intros; try solve [boring].
- rewrite <- zeros_app0.
- rewrite app_assoc.
- repeat progress (boring; rewrite rev_app_distr).
- Qed.
-
- Lemma mul_bi_add_longer : forall n us vs,
- (length us >= length vs)%nat ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- boring.
- rewrite add_trailing_zeros by auto.
- rewrite (add_trailing_zeros (mul_bi n us) (mul_bi n vs))
- by (repeat (rewrite mul_bi_length); omega).
- erewrite mul_bi_add_same_length by
- (eauto; simpl_list; rewrite length_zeros; omega).
- rewrite mul_bi_trailing_zeros.
- repeat (f_equal; boring).
- Qed.
-
- Lemma mul_bi_add : forall n us vs,
- mul_bi n (us .+ vs) = (mul_bi n us) .+ (mul_bi n vs).
- Proof.
- intros; pose proof mul_bi_add_longer.
- destruct (le_ge_dec (length us) (length vs)). {
- replace (mul_bi n us .+ mul_bi n vs)
- with (mul_bi n vs .+ mul_bi n us)
- by (apply add_comm).
- replace (us .+ vs)
- with (vs .+ us)
- by (apply add_comm).
- boring.
- } {
- boring.
- }
- Qed.
-
- Lemma mul_bi_rep : forall i vs,
- (i + length vs < length base)%nat ->
- decode (mul_bi i vs) = decode vs * nth_default 0 base i.
- Proof.
- unfold decode.
- induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
- assert (i + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- rewrite set_higher.
- ring_simplify.
- rewrite <- IHvs by auto; clear IHvs.
- rewrite <- mul_bi_single by auto.
- rewrite <- add_rep.
- rewrite <- mul_bi_add.
- rewrite set_higher'.
- auto.
- Qed.
-
(* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
Fixpoint mul' (usr vs:digits) : digits :=
match usr with
@@ -500,51 +83,6 @@ Section BaseSystem.
end.
Definition mul us := mul' (rev us).
- Lemma mul'_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul' (rev us) vs) = decode us * decode vs.
- Proof.
- unfold decode.
- induction us using rev_ind; boring.
-
- assert (length us + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- ssimpl_list.
- rewrite add_rep.
- boring.
- rewrite set_higher.
- rewrite mul_each_rep.
- rewrite mul_bi_rep by auto.
- unfold decode; ring.
- Qed.
-
- Lemma mul_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul us vs) = decode us * decode vs.
- Proof.
- exact mul'_rep.
- Qed.
-
- Lemma mul'_length: forall us vs,
- (length (mul' us vs) <= length us + length vs)%nat.
- Proof.
- pose proof add_length_le_max.
- induction us; boring.
- unfold mul_each.
- simpl_list; case_max; boring; omega.
- Qed.
-
- Lemma mul_length: forall us vs,
- (length (mul us vs) <= length us + length vs)%nat.
- Proof.
- intros; unfold mul.
- rewrite mul'_length.
- rewrite rev_length; omega.
- Qed.
-
-(* Print Assumptions mul_rep. *)
-
End BaseSystem.
Section PolynomialBaseCoefs.
@@ -615,7 +153,6 @@ Section PolynomialBaseCoefs.
split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H].
Qed.
- Print BaseVector.
Instance PolyBaseVector : BaseVector poly_base := {
base_positive := poly_base_positive;
b0_1 := poly_b0_1;
@@ -633,8 +170,7 @@ Section BaseSystemExample.
Qed.
Definition base2 := poly_base 2 baseLength.
- About mul.
- Example three_times_two : @mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
+ Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
Proof.
reflexivity.
Qed.
diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v
new file mode 100644
index 000000000..84374fe8f
--- /dev/null
+++ b/src/BaseSystemProofs.v
@@ -0,0 +1,490 @@
+Require Import List.
+Require Import Util.ListUtil Util.CaseUtil Util.ZUtil.
+Require Import ZArith.ZArith ZArith.Zdiv.
+Require Import Omega NPeano Arith.
+Require Import Crypto.BaseSystem.
+Local Open Scope Z.
+
+Local Infix ".+" := add (at level 50).
+
+Local Hint Extern 1 (@eq Z _ _) => ring.
+
+Section BaseSystemProofs.
+ Context `(base_vector : BaseVector).
+
+ Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us).
+ Proof.
+ unfold decode'; intros; f_equal; apply combine_truncate_l.
+ Qed.
+
+ Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct us, vs; boring; ring.
+ Qed.
+
+ Lemma decode_nil : forall bs, decode' bs nil = 0.
+ auto.
+ Qed.
+ Hint Rewrite decode_nil.
+
+ Lemma decode_base_nil : forall us, decode' nil us = 0.
+ Proof.
+ intros; rewrite decode'_truncate; auto.
+ Qed.
+ Hint Rewrite decode_base_nil.
+
+ Lemma mul_each_rep : forall bs u vs,
+ decode' bs (mul_each u vs) = u * decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct vs; boring; ring.
+ Qed.
+
+ Lemma base_eq_1cons: base = 1 :: skipn 1 base.
+ Proof.
+ pose proof (b0_1 0) as H.
+ destruct base; compute in H; try discriminate; boring.
+ Qed.
+
+ Lemma decode'_cons : forall x1 x2 xs1 xs2,
+ decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
+ Proof.
+ unfold decode', accumulate; boring; ring.
+ Qed.
+ Hint Rewrite decode'_cons.
+
+ Lemma decode_cons : forall x us,
+ decode base (x :: us) = x + decode base (0 :: us).
+ Proof.
+ unfold decode; intros.
+ rewrite base_eq_1cons.
+ autorewrite with core; ring_simplify; auto.
+ Qed.
+
+ Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
+ Proof.
+ induction bs; destruct us; destruct vs; boring; ring.
+ Qed.
+
+ Lemma encode_rep : forall z, decode base (encode z) = z.
+ Proof.
+ pose proof base_eq_1cons.
+ unfold decode, encode; destruct z; boring.
+ Qed.
+
+ Lemma mul_each_base : forall us bs c,
+ decode' bs (mul_each c us) = decode' (mul_each c bs) us.
+ Proof.
+ induction us; destruct bs; boring; ring.
+ Qed.
+
+ Hint Rewrite (@nth_default_nil Z).
+ Hint Rewrite (@firstn_nil Z).
+ Hint Rewrite (@skipn_nil Z).
+
+ Lemma base_app : forall us low high,
+ decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
+ Proof.
+ induction us; destruct low; boring.
+ Qed.
+
+ Lemma base_mul_app : forall low c us,
+ decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
+ c * decode' low (skipn (length low) us).
+ Proof.
+ intros.
+ rewrite base_app; f_equal.
+ rewrite <- mul_each_rep.
+ rewrite mul_each_base.
+ reflexivity.
+ Qed.
+
+ Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
+ induction bs; destruct n; boring.
+ Qed.
+ Lemma length_zeros : forall n, length (zeros n) = n.
+ induction n; boring.
+ Qed.
+ Hint Rewrite length_zeros.
+
+ Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite app_zeros_zeros.
+
+ Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
+ Proof.
+ induction m; boring.
+ Qed.
+ Hint Rewrite zeros_app0.
+
+ Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite rev_zeros.
+
+ Hint Unfold nth_default.
+
+ Lemma decode_single : forall n bs x,
+ decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
+ Proof.
+ induction n; destruct bs; boring.
+ Qed.
+ Hint Rewrite decode_single.
+
+ Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
+ Proof.
+ boring.
+ Qed.
+ Hint Rewrite zeros_rep peel_decode.
+
+ Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
+ Proof.
+ induction xs; destruct bs; boring.
+ Qed.
+
+ Lemma mul_bi'_zeros : forall n m, mul_bi' base n (zeros m) = zeros m.
+ induction m; boring.
+ Qed.
+ Hint Rewrite mul_bi'_zeros.
+
+ Lemma nth_error_base_nonzero : forall n x,
+ nth_error base n = Some x -> x <> 0.
+ Proof.
+ eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
+ Qed.
+
+ Hint Rewrite plus_0_r.
+
+ Lemma mul_bi_single : forall m n x,
+ (n + m < length base)%nat ->
+ decode base (mul_bi base n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
+ Proof.
+ unfold mul_bi, decode.
+ destruct m; simpl; simpl_list; simpl; intros. {
+ pose proof nth_error_base_nonzero as nth_nonzero.
+ case_eq base; [intros; boring | intros z l base_eq].
+ specialize (b0_1 0); intro b0_1'.
+ rewrite base_eq in *.
+ rewrite nth_default_cons in b0_1'.
+ rewrite b0_1' in *.
+ unfold crosscoef.
+ autounfold; autorewrite with core.
+ unfold nth_default.
+ nth_tac.
+ rewrite Z.mul_1_r.
+ rewrite Z_div_same_full.
+ destruct x; ring.
+ eapply nth_nonzero; eauto.
+ } {
+ ssimpl_list.
+ autorewrite with core.
+ rewrite app_assoc.
+ autorewrite with core.
+ unfold crosscoef; simpl; ring_simplify.
+ rewrite Nat.add_1_r.
+ rewrite base_good by auto.
+ rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_comm.
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_assoc.
+ destruct (Z.eq_dec x 0); subst; try ring.
+ rewrite Z.mul_cancel_l by auto.
+ rewrite <- base_good by auto.
+ ring.
+ }
+ Qed.
+
+ Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
+ induction vs; boring; f_equal; ring.
+ Qed.
+
+ Lemma set_higher : forall bs vs x,
+ decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
+ Proof.
+ intros.
+ rewrite set_higher'.
+ rewrite add_rep.
+ f_equal.
+ apply decode_single.
+ Qed.
+
+ Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
+ induction n; auto.
+ simpl; f_equal; auto.
+ Qed.
+
+ Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil.
+ Proof.
+ unfold mul_bi; auto.
+ Qed.
+ Hint Rewrite mul_bi'_n_nil.
+
+ Lemma add_nil_l : forall us, nil .+ us = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_l.
+
+ Lemma add_nil_r : forall us, us .+ nil = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_r.
+
+ Lemma add_first_terms : forall us vs a b,
+ (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
+ auto.
+ Qed.
+ Hint Rewrite add_first_terms.
+
+ Lemma mul_bi'_cons : forall n x us,
+ mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us.
+ Proof.
+ unfold mul_bi'; auto.
+ Qed.
+
+ Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
+ length (us .+ vs) = l.
+ Proof.
+ induction us, vs; boring.
+ erewrite (IHus vs (pred l)); boring.
+ Qed.
+
+ Hint Rewrite app_nil_l.
+ Hint Rewrite app_nil_r.
+
+ Lemma add_snoc_same_length : forall l us vs a b,
+ (length us = l) -> (length vs = l) ->
+ (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
+ Proof.
+ induction l, us, vs; boring; discriminate.
+ Qed.
+
+ Lemma mul_bi'_add : forall us n vs l
+ (Hlus: length us = l)
+ (Hlvs: length vs = l),
+ mul_bi' base n (rev (us .+ vs)) =
+ mul_bi' base n (rev us) .+ mul_bi' base n (rev vs).
+ Proof.
+ (* TODO(adamc): please help prettify this *)
+ induction us using rev_ind;
+ try solve [destruct vs; boring; congruence].
+ destruct vs using rev_ind; boring; clear IHvs; simpl_list.
+ erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
+ repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
+ rewrite (IHus n vs (pred l)).
+ replace (length us) with (pred l).
+ replace (length vs) with (pred l).
+ rewrite (add_same_length us vs (pred l)).
+ f_equal; ring.
+
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ Qed.
+
+ Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
+ auto.
+ Qed.
+
+ Lemma add_leading_zeros : forall n us vs,
+ (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
+ Proof.
+ induction n; boring.
+ Qed.
+
+ Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
+ (rev us) .+ (rev vs) = rev (us .+ vs).
+ Proof.
+ induction us, vs; boring; try solve [subst; discriminate].
+ rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
+ rewrite (IHus vs (pred l)) by omega; auto.
+ Qed.
+ Hint Rewrite rev_add_rev.
+
+ Lemma mul_bi'_length : forall us n, length (mul_bi' base n us) = length us.
+ Proof.
+ induction us, n; boring.
+ Qed.
+ Hint Rewrite mul_bi'_length.
+
+ Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
+ Proof.
+ induction us, vs; boring; f_equal; auto.
+ Qed.
+
+ Hint Rewrite rev_length.
+
+ Lemma mul_bi_add_same_length : forall n us vs l,
+ (length us = l) -> (length vs = l) ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ unfold mul_bi; boring.
+ rewrite add_leading_zeros.
+ erewrite mul_bi'_add; boring.
+ erewrite rev_add_rev; boring.
+ Qed.
+
+ Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
+ Proof.
+ induction us; boring; f_equal; omega.
+ Qed.
+
+ Hint Rewrite add_zeros_same_length.
+ Hint Rewrite minus_diag.
+
+ Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
+ us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
+ Proof.
+ induction us, vs; boring; f_equal; boring.
+ Qed.
+
+ Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
+ (length (us .+ vs) <= length us)%nat.
+ Proof.
+ intros.
+ rewrite add_trailing_zeros by trivial.
+ erewrite add_same_length by (pose proof app_length; boring); omega.
+ Qed.
+
+ Lemma add_length_le_max : forall us vs,
+ (length (us .+ vs) <= max (length us) (length vs))%nat.
+ Proof.
+ intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
+ (apply length_add_ge; omega) .
+ Qed.
+
+ Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
+ Proof.
+ induction us; boring.
+ Qed.
+
+ Lemma sub_length_le_max : forall us vs,
+ (length (sub us vs) <= max (length us) (length vs))%nat.
+ Proof.
+ induction us, vs; boring.
+ rewrite sub_nil_length; auto.
+ Qed.
+
+ Lemma mul_bi_length : forall us n, length (mul_bi base n us) = (length us + n)%nat.
+ Proof.
+ pose proof mul_bi'_length; unfold mul_bi.
+ destruct us; repeat progress (simpl_list; boring).
+ Qed.
+ Hint Rewrite mul_bi_length.
+
+ Lemma mul_bi_trailing_zeros : forall m n us,
+ mul_bi base n us ++ zeros m = mul_bi base n (us ++ zeros m).
+ Proof.
+ unfold mul_bi.
+ induction m; intros; try solve [boring].
+ rewrite <- zeros_app0.
+ rewrite app_assoc.
+ repeat progress (boring; rewrite rev_app_distr).
+ Qed.
+
+ Lemma mul_bi_add_longer : forall n us vs,
+ (length us >= length vs)%nat ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ boring.
+ rewrite add_trailing_zeros by auto.
+ rewrite (add_trailing_zeros (mul_bi base n us) (mul_bi base n vs))
+ by (repeat (rewrite mul_bi_length); omega).
+ erewrite mul_bi_add_same_length by
+ (eauto; simpl_list; rewrite length_zeros; omega).
+ rewrite mul_bi_trailing_zeros.
+ repeat (f_equal; boring).
+ Qed.
+
+ Lemma mul_bi_add : forall n us vs,
+ mul_bi base n (us .+ vs) = (mul_bi base n us) .+ (mul_bi base n vs).
+ Proof.
+ intros; pose proof mul_bi_add_longer.
+ destruct (le_ge_dec (length us) (length vs)). {
+ rewrite add_comm.
+ rewrite (add_comm (mul_bi base n us)).
+ boring.
+ } {
+ boring.
+ }
+ Qed.
+
+ Lemma mul_bi_rep : forall i vs,
+ (i + length vs < length base)%nat ->
+ decode base (mul_bi base i vs) = decode base vs * nth_default 0 base i.
+ Proof.
+ unfold decode.
+ induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
+ assert (i + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ rewrite set_higher.
+ ring_simplify.
+ rewrite <- IHvs by auto; clear IHvs.
+ rewrite <- mul_bi_single by auto.
+ rewrite <- add_rep.
+ rewrite <- mul_bi_add.
+ rewrite set_higher'.
+ auto.
+ Qed.
+
+ (* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
+ Fixpoint mul' (usr vs:digits) : digits :=
+ match usr with
+ | u::usr' =>
+ mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs
+ | _ => nil
+ end.
+ Definition mul us := mul' (rev us).
+
+ Lemma mul'_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul' (rev us) vs) = decode base us * decode base vs.
+ Proof.
+ unfold decode.
+ induction us using rev_ind; boring.
+
+ assert (length us + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ ssimpl_list.
+ rewrite add_rep.
+ boring.
+ rewrite set_higher.
+ rewrite mul_each_rep.
+ rewrite mul_bi_rep by auto.
+ unfold decode; ring.
+ Qed.
+
+ Lemma mul_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul us vs) = decode base us * decode base vs.
+ Proof.
+ exact mul'_rep.
+ Qed.
+
+ Lemma mul'_length: forall us vs,
+ (length (mul' us vs) <= length us + length vs)%nat.
+ Proof.
+ pose proof add_length_le_max.
+ induction us; boring.
+ unfold mul_each.
+ simpl_list; case_max; boring; omega.
+ Qed.
+
+ Lemma mul_length: forall us vs,
+ (length (mul us vs) <= length us + length vs)%nat.
+ Proof.
+ intros; unfold mul.
+ rewrite mul'_length.
+ rewrite rev_length; omega.
+ Qed.
+
+End BaseSystemProofs.
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
new file mode 100644
index 000000000..746791d8d
--- /dev/null
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -0,0 +1,163 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section ExtendedBaseVector.
+ Context `{prm : PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ (* This section defines a new BaseVector that has double the length of the BaseVector
+ * used to construct [params]. The coefficients of the new vector are as follows:
+ *
+ * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
+ *
+ * The purpose of this construction is that it allows us to multiply numbers expressed
+ * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
+ * vectors of digits; the value of a digit vector is obtained by doing a dot product with
+ * the base vector.) So if x, y are digit vectors:
+ *
+ * (x \dot base) * (y \dot base) = (z \dot ext_base)
+ *
+ * Then we can separate z into its first and second halves:
+ *
+ * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
+ *
+ * Now, if we want to reduce the product modulo 2 ^ k - c:
+ *
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
+ *
+ * This sum may be short enough to express using base; if not, we can reduce again.
+ *)
+ Definition ext_base := base ++ (map (Z.mul (2^k)) base).
+
+ Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
+ Proof.
+ unfold ext_base. intros b In_b_base.
+ rewrite in_app_iff in In_b_base.
+ destruct In_b_base as [In_b_base | In_b_extbase].
+ + eapply BaseSystem.base_positive.
+ eapply In_b_base.
+ + eapply in_map_iff in In_b_extbase.
+ destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
+ subst.
+ specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
+ replace 0 with (2 ^ k * 0) by ring.
+ apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
+ rewrite Z.gt_lt_iff.
+ apply Z.pow_pos_nonneg; intuition.
+ pose proof k_nonneg; omega.
+ Qed.
+
+ Lemma base_length_nonzero : (0 < length base)%nat.
+ Proof.
+ assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
+ unfold nth_default in H.
+ case_eq (nth_error base 0); intros;
+ try (rewrite H0 in H; omega).
+ apply (nth_error_value_length _ 0 base z); auto.
+ Qed.
+
+ Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
+ Proof.
+ intros. unfold ext_base.
+ rewrite nth_default_app.
+ assert (0 < length base)%nat by (apply base_length_nonzero).
+ destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
+ Qed.
+
+ Lemma two_k_nonzero : 2^k <> 0.
+ Proof.
+ pose proof (Z.pow_eq_0 2 k k_nonneg).
+ intuition.
+ Qed.
+
+ Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
+ nth_default 0 (map (Z.mul (2 ^ k)) base) n =
+ (2 ^ k) * (nth_default 0 base n).
+ Proof.
+ intros.
+ erewrite map_nth_default; auto.
+ Qed.
+
+ Lemma base_good_over_boundary : forall
+ (i : nat)
+ (l : (i < length base)%nat)
+ (j' : nat)
+ (Hj': (i + j' < length base)%nat)
+ ,
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
+ (2 ^ k * nth_default 0 base (i + j')) *
+ (2 ^ k * nth_default 0 base (i + j'))
+ .
+ Proof.
+ intros.
+ remember (nth_default 0 base) as b.
+ rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
+ replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
+ with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
+ rewrite Z.mul_cancel_l by (exact two_k_nonzero).
+ replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
+ with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
+ subst b.
+ apply (BaseSystem.base_good i j'); omega.
+ Qed.
+
+ Lemma ext_base_good :
+ forall i j, (i+j < length ext_base)%nat ->
+ let b := nth_default 0 ext_base in
+ let r := (b i * b j) / b (i+j)%nat in
+ b i * b j = r * b (i+j)%nat.
+ Proof.
+ intros.
+ subst b. subst r.
+ unfold ext_base in *.
+ rewrite app_length in H; rewrite map_length in H.
+ repeat rewrite nth_default_app.
+ destruct (lt_dec i (length base));
+ destruct (lt_dec j (length base));
+ destruct (lt_dec (i + j) (length base));
+ try omega.
+ { (* i < length base, j < length base, i + j < length base *)
+ apply BaseSystem.base_good; auto.
+ } { (* i < length base, j < length base, i + j >= length base *)
+ rewrite (map_nth_default _ _ _ _ 0) by omega.
+ apply base_matches_modulus; omega.
+ } { (* i < length base, j >= length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (j - length base)%nat as j'.
+ replace (i + j - length base)%nat with (i + j')%nat by omega.
+ replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
+ with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ } { (* i >= length base, j < length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (i - length base)%nat as i'.
+ replace (i + j - length base)%nat with (j + i')%nat by omega.
+ replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
+ with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ }
+ Qed.
+
+ Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
+ base_positive := ext_base_positive;
+ b0_1 := b0_1;
+ base_good := ext_base_good
+ }.
+
+ Lemma extended_base_length:
+ length ext_base = (length base + length base)%nat.
+ Proof.
+ unfold ext_base; rewrite app_length; rewrite map_length; auto.
+ Qed.
+End ExtendedBaseVector.
+
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index b0e493871..d121e9d5c 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -1,643 +1,58 @@
Require Import Zpower ZArith.
Require Import List.
-Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.ListUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
-Require Import VerdiTactics.
-Require Crypto.BaseSystem.
+Require Import Crypto.BaseSystem Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector.
Local Open Scope Z_scope.
-Class PseudoMersenneBaseParams (modulus : Z) (base : list Z) (bv : BaseSystem.BaseVector base) := {
- k : Z;
- c : Z;
- modulus_pseudomersenne : modulus = 2^k - c;
- prime_modulus : Znumtheory.prime modulus;
- base_matches_modulus :
- forall i j,
- (i < length base)%nat ->
- (j < length base)%nat ->
- (i+j >= length base)%nat->
- let b := nth_default 0 base in
- let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
- b i * b j = r * (2^k * b (i+j-length base)%nat);
- base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0;
- base_tail_matches_modulus:
- 2^k mod nth_default 0 base (pred (length base)) = 0;
- k_nonneg : 0 <= k (* Probably implied by modulus_pseudomersenne. *)
-}.
-(*
-Class RepZMod (modulus : Z) := {
- T : Type;
- encode : F modulus -> T;
- decode : T -> F modulus;
-
- rep : T -> F modulus -> Prop;
- encode_rep : forall x, rep (encode x) x;
- rep_decode : forall u x, rep u x -> decode u = x;
-
- add : T -> T -> T;
- add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
-
- sub : T -> T -> T;
- sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
-
- mul : T -> T -> T;
- mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
-}.
-*)
-Print PseudoMersenneBaseParams.
-Section ExtendedBaseVector.
- Context (base : list Z) {modulus : Z} `(params : PseudoMersenneBaseParams modulus base).
- (* This section defines a new BaseVector that has double the length of the BaseVector
- * used to construct [params]. The coefficients of the new vector are as follows:
- *
- * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
- *
- * The purpose of this construction is that it allows us to multiply numbers expressed
- * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
- * vectors of digits; the value of a digit vector is obtained by doing a dot product with
- * the base vector.) So if x, y are digit vectors:
- *
- * (x \dot base) * (y \dot base) = (z \dot ext_base)
- *
- * Then we can separate z into its first and second halves:
- *
- * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
- *
- * Now, if we want to reduce the product modulo 2 ^ k - c:
- *
- * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
- * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
- *
- * This sum may be short enough to express using base; if not, we can reduce again.
- *)
- Definition ext_base := base ++ (map (Z.mul (2^k)) base).
-
- Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
- Proof.
- unfold ext_base. intros b In_b_base.
- rewrite in_app_iff in In_b_base.
- destruct In_b_base as [? | In_b_extbase]; auto using BaseSystem.base_positive.
- apply in_map_iff in In_b_extbase.
- destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
- subst.
- specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
- replace 0 with (2 ^ k * 0) by ring.
- apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
- rewrite Z.gt_lt_iff.
- apply Z.pow_pos_nonneg; intuition.
- pose proof k_nonneg; omega.
- Qed.
-
- Lemma base_length_nonzero : (0 < length base)%nat.
- Proof.
- assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
- unfold nth_default in H.
- case_eq (nth_error base 0); intros;
- try (rewrite H0 in H; omega).
- apply (nth_error_value_length _ 0 base z); auto.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
- Proof.
- intros. unfold ext_base.
- rewrite nth_default_app.
- assert (0 < length base)%nat by (apply base_length_nonzero).
- destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
- Qed.
-
- Lemma two_k_nonzero : 2^k <> 0.
- Proof.
- pose proof (Z.pow_eq_0 2 k k_nonneg).
- intuition.
- Qed.
-
- Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
- nth_default 0 (map (Z.mul (2 ^ k)) base) n =
- (2 ^ k) * (nth_default 0 base n).
- Proof.
- intros.
- erewrite map_nth_default; auto.
- Qed.
-
- Lemma ext_base_succ : forall i, ((S i) < length ext_base)%nat ->
- let b := nth_default 0 ext_base in
- b (S i) mod b i = 0.
- Proof.
- intros; subst b; unfold ext_base.
- repeat rewrite nth_default_app.
- do 2 break_if; [apply base_succ; auto | omega | | ]. {
- destruct (lt_eq_lt_dec (S i) (length base)); try omega.
- destruct s; intuition.
- rewrite map_nth_default_base_high by omega.
- replace i with (pred(length base)) by omega.
- rewrite <- Zmult_mod_idemp_l.
- rewrite base_tail_matches_modulus.
- rewrite Zmod_0_l; auto.
- } {
- unfold ext_base in *; rewrite app_length, map_length in *.
- repeat rewrite map_nth_default_base_high by omega.
- rewrite Zmult_mod_distr_l.
- rewrite <- minus_Sn_m by omega.
- rewrite base_succ by omega; ring.
- }
- Qed.
-
- Lemma base_good_over_boundary : forall
- (i : nat)
- (l : (i < length base)%nat)
- (j' : nat)
- (Hj': (i + j' < length base)%nat)
- ,
- 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
- 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
- (2 ^ k * nth_default 0 base (i + j')) *
- (2 ^ k * nth_default 0 base (i + j'))
- .
- Proof.
- intros.
- remember (nth_default 0 base) as b.
- rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
- replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
- with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
- rewrite Z.mul_cancel_l by (exact two_k_nonzero).
- replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
- with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
- subst b.
- apply (BaseSystem.base_good i j'); omega.
- Qed.
-
- Lemma ext_base_good :
- forall i j, (i+j < length ext_base)%nat ->
- let b := nth_default 0 ext_base in
- let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
- Proof.
- intros.
- subst b. subst r.
- unfold ext_base in *.
- rewrite app_length in H; rewrite map_length in H.
- repeat rewrite nth_default_app.
- destruct (lt_dec i (length base));
- destruct (lt_dec j (length base));
- destruct (lt_dec (i + j) (length base));
- try omega.
- { (* i < length base, j < length base, i + j < length base *)
- apply BaseSystem.base_good; auto.
- } { (* i < length base, j < length base, i + j >= length base *)
- rewrite (map_nth_default _ _ _ _ 0) by omega.
- apply base_matches_modulus; omega.
- } { (* i < length base, j >= length base, i + j >= length base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (j - length base)%nat as j'.
- replace (i + j - length base)%nat with (i + j')%nat by omega.
- replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
- with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- } { (* i >= length base, j < length base, i + j >= length base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (i - length base)%nat as i'.
- replace (i + j - length base)%nat with (j + i')%nat by omega.
- replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
- with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- }
- Qed.
- Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
- base_positive := ext_base_positive;
- b0_1 := b0_1;
- base_good := ext_base_good
- }.
-End ExtendedBaseVector.
-
-Print ExtBaseVector.
Section PseudoMersenneBase.
- Context `(prm :PseudoMersenneBaseParams).
-
- Definition T := BaseSystem.digits.
- Definition decode (us : T) : F modulus := ZToField (BaseSystem.decode base us).
- Local Hint Unfold decode.
- Definition rep (us : T) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
+ Context `{prm :PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
+
+ Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
- Lemma rep_decode : forall us x, us ~= x -> decode us = x.
- Proof.
- autounfold; intuition.
- Qed.
-
- Definition encode (x : F modulus) := BaseSystem.encode x.
-
- Lemma encode_rep : forall x : F modulus, encode x ~= x.
- Proof.
- intros. unfold encode, rep.
- split. {
- unfold encode; simpl.
- apply base_length_nonzero.
- assumption.
- } {
- unfold decode.
- rewrite BaseSystem.encode_rep.
- apply ZToField_FieldToZ.
- assumption.
- }
- Qed.
-
- Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
- Proof.
- autounfold; intuition. {
- unfold add.
- rewrite BaseSystem.add_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold BaseSystem.decode in *.
- rewrite BaseSystem.add_rep.
- rewrite ZToField_add.
- subst; auto.
- Qed.
-
- Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F.
- Proof.
- autounfold; intuition. {
- rewrite BaseSystem.sub_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold BaseSystem.decode in *.
- rewrite BaseSystem.sub_rep.
- rewrite ZToField_sub.
- subst; auto.
- Qed.
-
- Lemma decode_short : forall (us : T),
- (length us <= length base)%nat ->
- BaseSystem.decode base us = BaseSystem.decode (ext_base base prm) us.
- Proof.
- intros.
- unfold BaseSystem.decode, BaseSystem.decode'.
- rewrite combine_truncate_r.
- rewrite (combine_truncate_r us (ext_base base prm)).
- f_equal; f_equal.
- unfold ext_base.
- rewrite firstn_app_inleft; auto; omega.
- Qed.
-
- Lemma extended_base_length:
- length (ext_base base prm) = (length base + length base)%nat.
- Proof.
- unfold ext_base; rewrite app_length; rewrite map_length; auto.
- Qed.
-
- Lemma mul_rep_extended : forall (us vs : T),
- (length us <= length base)%nat ->
- (length vs <= length base)%nat ->
- (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base base prm) (BaseSystem.mul (ext_base base prm) us vs).
- Proof.
- intros.
- rewrite BaseSystem.mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
- f_equal; rewrite decode_short; auto.
- Qed.
+ Definition encode (x : F modulus) := encode x.
(* Converts from length of extended base to length of base by reduction modulo M.*)
- Definition reduce (us : T) : T :=
+ Definition reduce (us : digits) : digits :=
let high := skipn (length base) us in
let low := firstn (length base) us in
let wrap := map (Z.mul c) high in
BaseSystem.add low wrap.
- Lemma modulus_nonzero : modulus <> 0.
- pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
- Qed.
-
- (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
- Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
- Proof.
- intros.
- replace (2^k) with ((2^k - c) + c) by ring.
- rewrite Z.mul_add_distr_r.
- rewrite Zplus_mod.
- rewrite <- modulus_pseudomersenne.
- rewrite Z.mul_comm.
- rewrite mod_mult_plus; auto using modulus_nonzero.
- rewrite <- Zplus_mod; auto.
- Qed.
+ Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs).
- Lemma extended_shiftadd: forall (us : T),
- BaseSystem.decode (ext_base base prm) us =
- BaseSystem.decode base (firstn (length base) us)
- + (2^k * BaseSystem.decode base (skipn (length base) us)).
- Proof.
- intros.
- unfold BaseSystem.decode; rewrite <- BaseSystem.mul_each_rep.
- unfold ext_base.
- replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
- rewrite BaseSystem.base_mul_app.
- rewrite <- BaseSystem.mul_each_rep; auto.
- Qed.
-
- Lemma reduce_rep : forall us,
- BaseSystem.decode base (reduce us) mod modulus =
- BaseSystem.decode (ext_base base prm) us mod modulus.
- Proof.
- intros.
- rewrite extended_shiftadd.
- rewrite pseudomersenne_add.
- unfold reduce.
- remember (firstn (length base) us) as low.
- remember (skipn (length base) us) as high.
- unfold BaseSystem.decode.
- rewrite BaseSystem.add_rep.
- replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
- rewrite BaseSystem.mul_each_rep; auto.
- Qed.
+End PseudoMersenneBase.
- Lemma reduce_length : forall us,
- (length us <= length (ext_base base prm))%nat ->
- (length (reduce us) <= length (base))%nat.
- Proof.
- intros.
- unfold reduce.
- remember (map (Z.mul c) (skipn (length base) us)) as high.
- remember (firstn (length base) us) as low.
- assert (length low >= length high)%nat. {
- subst. rewrite firstn_length.
- rewrite map_length.
- rewrite skipn_length.
- destruct (le_dec (length base) (length us)). {
- rewrite Min.min_l by omega.
- rewrite extended_base_length in H. omega.
- } {
- rewrite Min.min_r; omega.
- }
- }
- assert ((length low <= length base)%nat)
- by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
- assert (length high <= length base)%nat
- by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
- rewrite extended_base_length in H; omega).
- rewrite BaseSystem.add_trailing_zeros; auto.
- rewrite (BaseSystem.add_same_length _ _ (length low)); auto.
- rewrite app_length.
- rewrite BaseSystem.length_zeros; intuition.
- Qed.
+Section CarryBasePow2.
+ Context `{prm :PseudoMersenneBaseParams}.
- Definition mul (us vs : T) := reduce (BaseSystem.mul (ext_base base prm) us vs).
- Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
- Proof.
- autounfold; unfold mul; intuition.
- {
- apply reduce_length.
- rewrite BaseSystem.mul_length, extended_base_length.
- omega.
- } {
- rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
- rewrite BaseSystem.mul_rep by
- (apply ExtBaseVector || rewrite extended_base_length; omega).
- subst.
- do 2 rewrite decode_short by auto.
- apply ZToField_mul.
- }
- Qed.
+ Definition log_cap i := nth_default 0 limb_widths i.
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
- Hint Unfold add_to_nth.
-
- (* i must be in the domain of base *)
- Definition cap i :=
- if eq_nat_dec i (pred (length base))
- then (2^k) / nth_default 0 base i
- else nth_default 0 base (S i) / nth_default 0 base i.
+
+ Definition pow2_mod n i := Z.land n (Z.ones i).
Definition carry_simple i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth (S i) ( (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
Definition carry_and_reduce i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth 0 (c * (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
- Definition carry i : T -> T :=
+ Definition carry i : digits -> digits :=
if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
- (* TODO: move to BaseSystemProofs *)
- Lemma decode'_splice : forall xs ys bs,
- BaseSystem.decode' bs (xs ++ ys) =
- BaseSystem.decode' (firstn (length xs) bs) xs +
- BaseSystem.decode' (skipn (length xs) bs) ys.
- Proof.
- unfold BaseSystem.decode'.
- induction xs; destruct ys, bs; boring.
- + rewrite combine_truncate_r.
- do 2 rewrite Z.add_0_r; auto.
- + unfold BaseSystem.accumulate.
- apply Z.add_assoc.
- Qed.
-
- Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
- BaseSystem.decode base (set_nth n x us) =
- (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
- Proof.
- intros.
- unfold BaseSystem.decode.
- nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
- unfold splice_nth.
- rewrite <- (firstn_skipn n us) at 4.
- do 2 rewrite decode'_splice.
- remember (length (firstn n us)) as n0.
- ring_simplify.
- remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
- rewrite (skipn_nth_default n us 0) by omega.
- rewrite firstn_length in Heqn0.
- rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length base) n). {
- rewrite nth_default_out_of_bounds by auto.
- rewrite skipn_all by omega.
- do 2 rewrite BaseSystem.decode_base_nil.
- ring_simplify; auto.
- } {
- rewrite (skipn_nth_default n base 0) by omega.
- do 2 rewrite BaseSystem.decode'_cons.
- ring_simplify; ring.
- }
- Qed.
-
- Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
- BaseSystem.decode base (add_to_nth n x us) =
- x * nth_default 0 base n + BaseSystem.decode base us.
- Proof.
- unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
- Qed.
-
- Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
- nth_default 0 base i > 0.
- Proof.
- intros.
- pose proof (nth_error_length_exists_value _ _ H).
- destruct H0.
- pose proof (nth_error_value_In _ _ _ H0).
- pose proof (BaseSystem.base_positive _ H1).
- unfold nth_default.
- rewrite H0; auto.
- Qed.
-
- Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
- nth_default 0 base (S i) = nth_default 0 base i *
- (nth_default 0 base (S i) / nth_default 0 base i).
- Proof.
- intros.
- apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply base_succ; auto.
- Qed.
-
- Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
- Proof.
- pose proof (base_length_nonzero base); omega.
- Qed.
- Hint Resolve base_length_lt_pred.
-
- Lemma cap_positive: forall i, (i < length base)%nat -> cap i > 0.
- Proof.
- unfold cap; intros; break_if. {
- apply div_positive_gt_0; try (subst; apply base_tail_matches_modulus). {
- rewrite <- two_p_equiv.
- apply two_p_gt_ZERO.
- apply k_nonneg.
- } {
- apply nth_default_base_positive; subst; auto.
- }
- } {
- apply div_positive_gt_0; try (apply base_succ; omega);
- try (apply nth_default_base_positive; omega).
- }
- Qed.
-
- Lemma cap_div_mod : forall us i, (i < (pred (length base)))%nat ->
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 base i =
- (di / cap i) * nth_default 0 base (S i).
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega);
- ring_simplify.
- unfold cap; break_if; intuition.
- rewrite base_succ_div_mult at 4 by omega; ring.
- Qed.
-
- Lemma carry_simple_decode_eq : forall i us,
- (length us = length base) ->
- (i < (pred (length base)))%nat ->
- BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
- Proof.
- unfold carry_simple. intros.
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite <- cap_div_mod by auto; ring_simplify; auto.
- Qed.
-
- Lemma two_k_div_mul_last :
- 2 ^ k = nth_default 0 base (pred (length base)) *
- (2 ^ k / nth_default 0 base (pred (length base))).
- Proof.
- intros.
- pose proof base_tail_matches_modulus.
- rewrite (Z_div_mod_eq (2 ^ k) (nth_default 0 base (pred (length base)))) at 1 by
- (apply nth_default_base_positive; auto); omega.
- Qed.
-
- Lemma cap_div_mod_reduce : forall us,
- let i := pred (length base) in
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 base i =
- (di / cap i) * 2 ^ k.
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by
- (apply cap_positive; auto); ring_simplify.
- unfold cap; break_if; intuition.
- rewrite Z.mul_comm, Z.mul_assoc.
- subst i; rewrite <- two_k_div_mul_last; ring.
- Qed.
-
- Lemma carry_decode_eq_reduce : forall us,
- (length us = length base) ->
- BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
- = BaseSystem.decode base us mod modulus.
- Proof.
- unfold carry_and_reduce; intros.
- pose proof (base_length_nonzero base).
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite Zplus_comm, <- Z.mul_assoc.
- rewrite <- pseudomersenne_add.
- rewrite BaseSystem.b0_1.
- rewrite (Z.mul_comm (2 ^ k)).
- rewrite <- Zred_factor0.
- rewrite <- cap_div_mod_reduce by auto.
- do 2 rewrite Zmult_minus_distr_r.
- f_equal.
- ring.
- Qed.
-
- Lemma carry_length : forall i us,
- (length us <= length base)%nat ->
- (length (carry i us) <= length base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
- Hint Resolve carry_length.
-
- Lemma carry_rep : forall i us x,
- (length us = length base) ->
- (i < length base)%nat ->
- us ~= x -> carry i us ~= x.
- Proof.
- pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
- unfold rep, decode, carry in *; intros.
- intuition; break_if; subst; eauto;
- apply F_eq; simpl; intuition.
- Qed.
- Hint Resolve carry_rep.
-
Definition carry_sequence is us := fold_right carry us is.
- Lemma carry_sequence_length: forall is us,
- (length us <= length base)%nat ->
- (length (carry_sequence is us) <= length base)%nat.
- Proof.
- induction is; boring.
- Qed.
- Hint Resolve carry_sequence_length.
-
- Lemma carry_length_exact : forall i us,
- (length us = length base)%nat ->
- (length (carry i us) = length base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
-
- Lemma carry_sequence_length_exact: forall is us,
- (length us = length base)%nat ->
- (length (carry_sequence is us) = length base)%nat.
- Proof.
- induction is; boring.
- apply carry_length_exact; auto.
- Qed.
- Hint Resolve carry_sequence_length_exact.
-
- Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length base)%nat) ->
- (length us = length base) ->
- us ~= x -> carry_sequence is us ~= x.
- Proof.
- induction is; boring.
- Qed.
-End PseudoMersenneBase.
+End CarryBasePow2.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
new file mode 100644
index 000000000..524c2da27
--- /dev/null
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -0,0 +1,449 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import VerdiTactics.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneProofs.
+ Context `{prm :PseudoMersenneBaseParams}.
+ Existing Instance bv.
+
+ Local Hint Unfold decode.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Local Notation "u '.+' x" := (add u x) (at level 70).
+ Local Notation "u '.*' x" := (ModularBaseSystem.mul u x) (at level 70).
+ Local Hint Unfold rep.
+
+ Lemma rep_decode : forall us x, us ~= x -> decode us = x.
+ Proof.
+ autounfold; intuition.
+ Qed.
+
+ Lemma encode_rep : forall x : F modulus, encode x ~= x.
+ Proof.
+ intros. unfold encode, rep.
+ split. {
+ unfold encode; simpl.
+ apply base_length_nonzero.
+ } {
+ unfold decode.
+ rewrite encode_rep.
+ apply ZToField_FieldToZ.
+ apply bv.
+ }
+ Qed.
+
+ Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
+ Proof.
+ autounfold; intuition. {
+ unfold add.
+ rewrite add_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold decode in *.
+ rewrite add_rep.
+ rewrite ZToField_add.
+ subst; auto.
+ Qed.
+
+ Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.sub u v ~= (x-y)%F.
+ Proof.
+ autounfold; intuition. {
+ rewrite sub_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold BaseSystem.decode in *.
+ rewrite sub_rep.
+ rewrite ZToField_sub.
+ subst; auto.
+ Qed.
+
+ Lemma decode_short : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode, BaseSystem.decode'.
+ rewrite combine_truncate_r.
+ rewrite (combine_truncate_r us ext_base).
+ f_equal; f_equal.
+ unfold ext_base.
+ rewrite firstn_app_inleft; auto; omega.
+ Qed.
+
+ Lemma mul_rep_extended : forall (us vs : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ (length vs <= length base)%nat ->
+ (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs).
+ Proof.
+ intros.
+ rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
+ f_equal; rewrite decode_short; auto.
+ Qed.
+
+ Lemma modulus_nonzero : modulus <> 0.
+ pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
+ Qed.
+
+ (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
+ Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
+ Proof.
+ intros.
+ replace (2^k) with ((2^k - c) + c) by ring.
+ rewrite Z.mul_add_distr_r.
+ rewrite Zplus_mod.
+ rewrite <- modulus_pseudomersenne.
+ rewrite Z.mul_comm.
+ rewrite mod_mult_plus; auto using modulus_nonzero.
+ rewrite <- Zplus_mod; auto.
+ Qed.
+
+ Lemma extended_shiftadd: forall (us : BaseSystem.digits),
+ BaseSystem.decode ext_base us =
+ BaseSystem.decode base (firstn (length base) us)
+ + (2^k * BaseSystem.decode base (skipn (length base) us)).
+ Proof.
+ intros.
+ unfold BaseSystem.decode; rewrite <- mul_each_rep.
+ unfold ext_base.
+ replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
+ rewrite base_mul_app.
+ rewrite <- mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_rep : forall us,
+ BaseSystem.decode base (reduce us) mod modulus =
+ BaseSystem.decode ext_base us mod modulus.
+ Proof.
+ intros.
+ rewrite extended_shiftadd.
+ rewrite pseudomersenne_add.
+ unfold reduce.
+ remember (firstn (length base) us) as low.
+ remember (skipn (length base) us) as high.
+ unfold BaseSystem.decode.
+ rewrite BaseSystemProofs.add_rep.
+ replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
+ rewrite mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_length : forall us,
+ (length us <= length ext_base)%nat ->
+ (length (reduce us) <= length base)%nat.
+ Proof.
+ intros.
+ unfold reduce.
+ remember (map (Z.mul c) (skipn (length base) us)) as high.
+ remember (firstn (length base) us) as low.
+ assert (length low >= length high)%nat. {
+ subst. rewrite firstn_length.
+ rewrite map_length.
+ rewrite skipn_length.
+ destruct (le_dec (length base) (length us)). {
+ rewrite Min.min_l by omega.
+ rewrite extended_base_length in H. omega.
+ } {
+ rewrite Min.min_r; omega.
+ }
+ }
+ assert ((length low <= length base)%nat)
+ by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
+ assert (length high <= length base)%nat
+ by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
+ rewrite extended_base_length in H; omega).
+ rewrite add_trailing_zeros; auto.
+ rewrite (add_same_length _ _ (length low)); auto.
+ rewrite app_length.
+ rewrite length_zeros; intuition.
+ Qed.
+
+ Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F.
+ Proof.
+ autounfold; unfold ModularBaseSystem.mul; intuition.
+ {
+ apply reduce_length.
+ rewrite mul_length, extended_base_length.
+ omega.
+ } {
+ rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
+ rewrite mul_rep by
+ (apply ExtBaseVector || rewrite extended_base_length; omega).
+ subst.
+ do 2 rewrite decode_short by auto.
+ apply ZToField_mul.
+ }
+ Qed.
+
+ (* TODO: move to BaseSystemProofs *)
+ Lemma decode'_splice : forall xs ys bs,
+ BaseSystem.decode' bs (xs ++ ys) =
+ BaseSystem.decode' (firstn (length xs) bs) xs +
+ BaseSystem.decode' (skipn (length xs) bs) ys.
+ Proof.
+ unfold BaseSystem.decode'.
+ induction xs; destruct ys, bs; boring.
+ + rewrite combine_truncate_r.
+ do 2 rewrite Z.add_0_r; auto.
+ + unfold BaseSystem.accumulate.
+ apply Z.add_assoc.
+ Qed.
+
+ Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (set_nth n x us) =
+ (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode.
+ nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
+ unfold splice_nth.
+ rewrite <- (firstn_skipn n us) at 4.
+ do 2 rewrite decode'_splice.
+ remember (length (firstn n us)) as n0.
+ ring_simplify.
+ remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
+ rewrite (skipn_nth_default n us 0) by omega.
+ rewrite firstn_length in Heqn0.
+ rewrite Min.min_l in Heqn0 by omega; subst n0.
+ destruct (le_lt_dec (length base) n). {
+ rewrite nth_default_out_of_bounds by auto.
+ rewrite skipn_all by omega.
+ do 2 rewrite decode_base_nil.
+ ring_simplify; auto.
+ } {
+ rewrite (skipn_nth_default n base 0) by omega.
+ do 2 rewrite decode'_cons.
+ ring_simplify; ring.
+ }
+ Qed.
+
+ Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (add_to_nth n x us) =
+ x * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
+ Qed.
+
+ Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
+ nth_default 0 base i > 0.
+ Proof.
+ intros.
+ pose proof (nth_error_length_exists_value _ _ H).
+ destruct H0.
+ pose proof (nth_error_value_In _ _ _ H0).
+ pose proof (BaseSystem.base_positive _ H1).
+ unfold nth_default.
+ rewrite H0; auto.
+ Qed.
+
+ Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) = nth_default 0 base i *
+ (nth_default 0 base (S i) / nth_default 0 base i).
+ Proof.
+ intros.
+ apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
+ apply base_succ; auto.
+ Qed.
+
+End PseudoMersenneProofs.
+
+Section CarryProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+ Existing Instance bv.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Hint Unfold log_cap.
+
+ Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
+ Proof.
+ pose proof base_length_nonzero; omega.
+ Qed.
+ Hint Resolve base_length_lt_pred.
+
+ Lemma log_cap_nonneg : forall i, 0 <= log_cap i.
+ Proof.
+ unfold log_cap, nth_default; intros.
+ case_eq (nth_error limb_widths i); intros; try omega.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ (* TODO : move to ZUtil *)
+ Lemma div_pow2succ : forall n x, (0 <= x) ->
+ n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x).
+ Proof.
+ intros.
+ rewrite Z.pow_succ_r, Z.mul_comm by auto.
+ rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega).
+ rewrite Zdiv2_div.
+ reflexivity.
+ Qed.
+
+ (* TODO: move to ZUtil *)
+ Lemma shiftr_succ : forall n x,
+ Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1.
+ Proof.
+ intros.
+ rewrite Z.shiftr_shiftr by omega.
+ reflexivity.
+ Qed.
+
+ (* TODO : move to ZUtil *)
+ Lemma shiftr_div : forall n i, (0 <= i) -> Z.shiftr n i = n / (2 ^ i).
+ Proof.
+ intro.
+ apply natlike_ind; intros; [boring|].
+ rewrite div_pow2succ by auto.
+ rewrite shiftr_succ.
+ unfold Z.shiftr.
+ simpl; f_equal.
+ auto.
+ Qed.
+
+ (* TODO : move to ListUtil *)
+ Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat ->
+ nth_error l i = Some (nth_default x l i).
+ Proof.
+ intros ? ? ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [k nth_err_k].
+ unfold nth_default.
+ rewrite nth_err_k.
+ reflexivity.
+ Qed.
+
+ Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ destruct (NPeano.Nat.eq_dec i 0).
+ + subst; f_equal.
+ unfold sum_firstn, log_cap.
+ destruct limb_widths; auto.
+ + erewrite sum_firstn_succ; eauto.
+ unfold log_cap.
+ apply nth_error_Some_nth_default.
+ rewrite <- base_length; omega.
+ Qed.
+
+ Lemma carry_simple_decode_eq : forall i us,
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
+ Proof.
+ unfold carry_simple. intros.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite shiftr_div by apply log_cap_nonneg.
+ rewrite nth_default_base_succ by omega.
+ rewrite Z.mul_assoc.
+ rewrite (Z.mul_comm _ (2 ^ log_cap i)).
+ rewrite mul_div_eq; try ring.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg.
+ Qed.
+
+ Lemma carry_decode_eq_reduce : forall us,
+ (length us = length base) ->
+ BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ = BaseSystem.decode base us mod modulus.
+ Proof.
+ unfold carry_and_reduce; intros ? length_eq.
+ pose proof base_length_nonzero.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1.
+ rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0.
+ f_equal.
+ rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r.
+ f_equal.
+ destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero].
+ + apply length0_nil in length_zero.
+ pose proof (base_length) as limbs_length.
+ rewrite length_zero in length_eq, limbs_length.
+ apply length0_nil in length_eq.
+ symmetry in limbs_length.
+ apply length0_nil in limbs_length.
+ unfold log_cap.
+ subst; rewrite length_zero, limbs_length, nth_default_nil.
+ reflexivity.
+ + rewrite nth_default_base by omega.
+ rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg).
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ rewrite Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite Z.mul_assoc.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- k_matches_limb_widths.
+ replace (length limb_widths) with (S (pred (length base))) by
+ (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega).
+ rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by
+ (unfold log_cap; apply nth_error_Some_nth_default; rewrite <- base_length; omega).
+ rewrite <- Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite (Z.add_comm (log_cap (pred (length base)))).
+ ring.
+ Qed.
+
+ Lemma carry_length : forall i us,
+ (length us <= length base)%nat ->
+ (length (carry i us) <= length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+ Hint Resolve carry_length.
+
+ Lemma carry_rep : forall i us x,
+ (length us = length base) ->
+ (i < length base)%nat ->
+ us ~= x -> carry i us ~= x.
+ Proof.
+ pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
+ unfold rep, decode, carry in *; intros.
+ intuition; break_if; subst; eauto;
+ apply F_eq; simpl; intuition.
+ Qed.
+ Hint Resolve carry_rep.
+
+ Lemma carry_sequence_length: forall is us,
+ (length us <= length base)%nat ->
+ (length (carry_sequence is us) <= length base)%nat.
+ Proof.
+ induction is; boring.
+ Qed.
+ Hint Resolve carry_sequence_length.
+
+ Lemma carry_length_exact : forall i us,
+ (length us = length base)%nat ->
+ (length (carry i us) = length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+
+ Lemma carry_sequence_length_exact: forall is us,
+ (length us = length base)%nat ->
+ (length (carry_sequence is us) = length base)%nat.
+ Proof.
+ induction is; boring.
+ apply carry_length_exact; auto.
+ Qed.
+ Hint Resolve carry_sequence_length_exact.
+
+ Lemma carry_sequence_rep : forall is us x,
+ (forall i, In i is -> (i < length base)%nat) ->
+ (length us = length base) ->
+ us ~= x -> carry_sequence is us ~= x.
+ Proof.
+ induction is; boring.
+ Qed.
+
+End CarryProofs.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
new file mode 100644
index 000000000..847b8e85f
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -0,0 +1,277 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneBaseParamProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+
+ Fixpoint base_from_limb_widths limb_widths :=
+ match limb_widths with
+ | nil => nil
+ | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw)
+ end.
+
+ Definition base := base_from_limb_widths limb_widths.
+
+ Lemma base_length : length base = length limb_widths.
+ Proof.
+ unfold base.
+ induction limb_widths; try reflexivity.
+ simpl; rewrite map_length; auto.
+ Qed.
+
+ Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b ->
+ a = b.
+ Proof.
+ intros; simpl in *.
+ unfold value in *.
+ congruence.
+ Qed.
+
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ nth_error base i = Some b ->
+ nth_error limb_widths i = Some w ->
+ nth_error base (S i) = Some (two_p w * b).
+ Proof.
+ unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
+ unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
+ [rewrite (@nil_length0 Z) in *; omega | ].
+ simpl in *; rewrite map_length in *.
+ case_eq i; intros; subst.
+ + subst; apply nth_error_first in nth_err_w.
+ apply nth_error_first in nth_err_b; subst.
+ apply map_nth_error.
+ case_eq l; intros; subst; [simpl in *; omega | ].
+ unfold base_from_limb_widths; fold base_from_limb_widths.
+ reflexivity.
+ + simpl in nth_err_w.
+ apply nth_error_map in nth_err_w.
+ destruct nth_err_w as [x [A B]].
+ subst.
+ replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring.
+ apply map_nth_error.
+ apply IHl; auto; omega.
+ Qed.
+
+ Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x),
+ exists l', l = x :: l'.
+ Proof.
+ induction l; try discriminate; eexists.
+ apply nth_error_first in H.
+ subst; eauto.
+ Qed.
+
+ Lemma sum_firstn_succ : forall l i x,
+ nth_error l i = Some x ->
+ sum_firstn l (S i) = x + sum_firstn l i.
+ Proof.
+ unfold sum_firstn; induction l;
+ [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ].
+ intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0).
+ + subst; simpl in *; unfold value in *.
+ congruence.
+ + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto.
+ rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto.
+ simpl. simpl in nth_err_x.
+ specialize (IHl (pred i) x).
+ rewrite NPeano.Nat.succ_pred in IHl by auto.
+ destruct (NPeano.Nat.eq_dec (pred i) 0).
+ - replace i with 1%nat in * by omega.
+ simpl. replace (pred 1) with 0%nat in * by auto.
+ apply nth_error_exists_first in nth_err_x.
+ destruct nth_err_x as [l' ?].
+ subst; simpl; ring.
+ - rewrite IHl by auto; ring.
+ Qed.
+
+ (* TODO : move to LsitUtil *)
+ Lemma fold_right_invariant : forall {A} P (f: A -> A -> A) l x,
+ P x -> (forall y, In y l -> forall z, P z -> P (f y z)) ->
+ P (fold_right f x l).
+ Proof.
+ induction l; intros ? ? step; auto.
+ simpl.
+ apply step; try apply in_eq.
+ apply IHl; auto.
+ intros y in_y_l.
+ apply (in_cons a) in in_y_l.
+ auto.
+ Qed.
+
+ (* TODO : move to ListUtil *)
+ Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l.
+ Proof.
+ induction n; destruct l; boring.
+ Qed.
+
+ Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n.
+ Proof.
+ unfold sum_firstn; intros.
+ apply fold_right_invariant; try omega.
+ intros y In_y_lw ? ?.
+ apply Z.add_nonneg_nonneg; try assumption.
+ apply limb_widths_nonneg.
+ eapply In_firstn; eauto.
+ Qed.
+
+ Lemma k_nonneg : 0 <= k.
+ Proof.
+ rewrite <- k_matches_limb_widths.
+ apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma nth_error_base : forall i, (i < length base)%nat ->
+ nth_error base i = Some (two_p (sum_firstn limb_widths i)).
+ Proof.
+ induction i; intros.
+ + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
+ intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
+ +
+ assert (i < length base)%nat as lt_i_length by omega.
+ specialize (IHi lt_i_length).
+ rewrite base_length in lt_i_length.
+ destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
+ erewrite base_from_limb_widths_step; eauto.
+ f_equal.
+ simpl.
+ destruct (NPeano.Nat.eq_dec i 0).
+ - subst; unfold sum_firstn; simpl.
+ apply nth_error_exists_first in nth_err_w.
+ destruct nth_err_w as [l' lw_destruct]; subst.
+ rewrite lw_destruct.
+ ring_simplify.
+ f_equal; simpl; ring.
+ - erewrite sum_firstn_succ; eauto.
+ symmetry.
+ apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ nth_default d base i = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
+ unfold nth_default.
+ rewrite nth_err_x.
+ rewrite nth_error_base in nth_err_x by assumption.
+ rewrite two_p_correct in nth_err_x.
+ congruence.
+ Qed.
+
+
+ (* TODO : move to ZUtil *)
+ Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0.
+ Proof.
+ intros.
+ replace b with (b - c + c) by ring.
+ rewrite Z.pow_add_r by omega.
+ apply Z_mod_mult.
+ Qed.
+
+ Lemma base_matches_modulus: forall i j,
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
+ let b := nth_default 0 base in
+ let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
+ b i * b j = r * (2^k * b (i+j-length base)%nat).
+ Proof.
+ intros.
+ rewrite (Z.mul_comm r).
+ subst r.
+ assert (i + j - length base < length base)%nat by omega.
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos;
+ [ | subst b; rewrite nth_default_base; try assumption ];
+ apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
+ f_equal.
+ subst b.
+ repeat rewrite nth_default_base by assumption.
+ do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg).
+ symmetry.
+ apply mod_same_pow.
+ split.
+ + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg.
+ + rewrite base_length in *; apply limb_widths_match_modulus; assumption.
+ Qed.
+
+ Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) mod nth_default 0 base i = 0.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ apply mod_same_pow.
+ split; [apply sum_firstn_limb_widths_nonneg | ].
+ destruct (NPeano.Nat.eq_dec i 0); subst.
+ + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq.
+ apply Z.add_nonneg_nonneg; try omega.
+ apply limb_widths_nonneg.
+ rewrite lw_eq.
+ apply in_eq.
+ + assert (i < length base)%nat as i_lt_length by omega.
+ rewrite base_length in *.
+ apply nth_error_length_exists_value in i_lt_length.
+ destruct i_lt_length as [x nth_err_x].
+ erewrite sum_firstn_succ; eauto.
+ apply nth_error_value_In in nth_err_x.
+ apply limb_widths_nonneg in nth_err_x.
+ omega.
+ Qed.
+
+ Lemma nth_error_subst : forall i b, nth_error base i = Some b ->
+ b = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros i b nth_err_b.
+ pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite nth_error_base in nth_err_b by assumption.
+ rewrite two_p_correct in nth_err_b.
+ congruence.
+ Qed.
+
+ Lemma base_positive : forall b : Z, In b base -> b > 0.
+ Proof.
+ intros b In_b_base.
+ apply In_nth_error_value in In_b_base.
+ destruct In_b_base as [i nth_err_b].
+ apply nth_error_subst in nth_err_b.
+ rewrite nth_err_b.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma b0_1 : forall x : Z, nth_default x base 0 = 1.
+ Proof.
+ unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity].
+ Qed.
+
+ Lemma base_good : forall i j : nat,
+ (i + j < length base)%nat ->
+ let b := nth_default 0 base in
+ let r := b i * b j / b (i + j)%nat in
+ b i * b j = r * b (i + j)%nat.
+ Proof.
+ intros; subst b r.
+ repeat rewrite nth_default_base by omega.
+ rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg.
+ rewrite mod_same_pow; try ring.
+ split; [ apply sum_firstn_limb_widths_nonneg | ].
+ apply limb_widths_good.
+ rewrite <- base_length; assumption.
+ Qed.
+
+ Instance bv : BaseSystem.BaseVector base := {
+ base_positive := base_positive;
+ b0_1 := b0_1;
+ base_good := base_good
+ }.
+
+End PseudoMersenneBaseParamProofs.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v
new file mode 100644
index 000000000..122cac0ab
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v
@@ -0,0 +1,26 @@
+Require Import ZArith.
+Require Import List.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Definition sum_firstn l n := fold_right Z.add 0 (firstn n l).
+
+Class PseudoMersenneBaseParams (modulus : Z) := {
+ limb_widths : list Z;
+ limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w;
+ limb_widths_nonnil : limb_widths <> nil;
+ limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
+ sum_firstn limb_widths (i + j) <=
+ sum_firstn limb_widths i + sum_firstn limb_widths j;
+ k : Z;
+ c : Z;
+ modulus_pseudomersenne : modulus = 2^k - c;
+ prime_modulus : Znumtheory.prime modulus;
+ limb_widths_match_modulus : forall i j,
+ (i < length limb_widths)%nat ->
+ (j < length limb_widths)%nat ->
+ (i + j >= length limb_widths)%nat ->
+ let w_sum := sum_firstn limb_widths in
+ k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j;
+ k_matches_limb_widths : sum_firstn limb_widths (length limb_widths) = k
+}.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v
new file mode 100644
index 000000000..2cc12b933
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v
@@ -0,0 +1,43 @@
+Require Import ZArith.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Local Open Scope Z_scope.
+
+Class RepZMod (modulus : Z) := {
+ T : Type;
+ encode : F modulus -> T;
+ decode : T -> F modulus;
+
+ rep : T -> F modulus -> Prop;
+ encode_rep : forall x, rep (encode x) x;
+ rep_decode : forall u x, rep u x -> decode u = x;
+
+ add : T -> T -> T;
+ add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
+
+ sub : T -> T -> T;
+ sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
+
+ mul : T -> T -> T;
+ mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
+}.
+
+Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) : RepZMod m := {
+ T := list Z;
+ encode := ModularBaseSystem.encode;
+ decode := ModularBaseSystem.decode;
+
+ rep := ModularBaseSystem.rep;
+ encode_rep := ModularBaseSystemProofs.encode_rep;
+ rep_decode := ModularBaseSystemProofs.rep_decode;
+
+ add := BaseSystem.add;
+ add_rep := ModularBaseSystemProofs.add_rep;
+
+ sub := BaseSystem.sub;
+ sub_rep := ModularBaseSystemProofs.sub_rep;
+
+ mul := ModularBaseSystem.mul;
+ mul_rep := ModularBaseSystemProofs.mul_rep
+}.