aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--_CoqProject17
-rwxr-xr-xetc/freshen-bedrock-files.sh2
-rw-r--r--src/BaseSystem.v550
-rw-r--r--src/BaseSystemProofs.v503
-rw-r--r--src/BoundedIterOp.v94
-rw-r--r--src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v381
-rw-r--r--src/CompleteEdwardsCurve/DoubleAndAdd.v67
-rw-r--r--src/CompleteEdwardsCurve/ExtendedCoordinates.v121
-rw-r--r--src/EdDSAProofs.v57
-rw-r--r--src/Encoding/EncodingTheorems.v2
-rw-r--r--src/Encoding/ModularWordEncodingPre.v45
-rw-r--r--src/Encoding/ModularWordEncodingTheorems.v54
-rw-r--r--src/Encoding/PointEncodingPre.v275
-rw-r--r--src/Encoding/PointEncodingTheorems.v207
-rw-r--r--src/ModularArithmetic/ExtendedBaseVector.v162
-rw-r--r--src/ModularArithmetic/ModularArithmeticTheorems.v44
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v678
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v463
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v1179
-rw-r--r--src/ModularArithmetic/PrimeFieldTheorems.v185
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v246
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParams.v24
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseRep.v50
-rw-r--r--src/Rep.v13
-rw-r--r--src/Spec/CompleteEdwardsCurve.v70
-rw-r--r--src/Spec/Ed25519.v19
-rw-r--r--src/Spec/EdDSA.v22
-rw-r--r--src/Spec/Encoding.v61
-rw-r--r--src/Spec/ModularWordEncoding.v40
-rw-r--r--src/Spec/PointEncoding.v200
-rw-r--r--src/Specific/Ed25519.v598
-rw-r--r--src/Specific/GF1305.v74
-rw-r--r--src/Specific/GF25519.v629
-rw-r--r--src/Testbit.v212
-rw-r--r--src/Util/CaseUtil.v6
-rw-r--r--src/Util/IterAssocOp.v76
-rw-r--r--src/Util/ListUtil.v49
-rw-r--r--src/Util/NatUtil.v11
-rw-r--r--src/Util/Tactics.v25
-rw-r--r--src/Util/WordUtil.v34
-rw-r--r--src/Util/ZUtil.v228
-rw-r--r--to_gallina.md7
42 files changed, 5414 insertions, 2366 deletions
diff --git a/_CoqProject b/_CoqProject
index d4a8857a1..bb92c783a 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -1,26 +1,40 @@
-R src Crypto
src/BaseSystem.v
-src/BoundedIterOp.v
+src/BaseSystemProofs.v
src/EdDSAProofs.v
+src/Rep.v
+src/Testbit.v
src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
src/CompleteEdwardsCurve/DoubleAndAdd.v
src/CompleteEdwardsCurve/ExtendedCoordinates.v
src/CompleteEdwardsCurve/Pre.v
src/Encoding/EncodingTheorems.v
+src/Encoding/ModularWordEncodingPre.v
+src/Encoding/ModularWordEncodingTheorems.v
+src/Encoding/PointEncodingPre.v
+src/Encoding/PointEncodingTheorems.v
+src/ModularArithmetic/ExtendedBaseVector.v
src/ModularArithmetic/FField.v
src/ModularArithmetic/FNsatz.v
src/ModularArithmetic/ModularArithmeticTheorems.v
src/ModularArithmetic/ModularBaseSystem.v
+src/ModularArithmetic/ModularBaseSystemOpt.v
+src/ModularArithmetic/ModularBaseSystemProofs.v
src/ModularArithmetic/Pre.v
src/ModularArithmetic/PrimeFieldTheorems.v
+src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
+src/ModularArithmetic/PseudoMersenneBaseParams.v
+src/ModularArithmetic/PseudoMersenneBaseRep.v
src/ModularArithmetic/Tutorial.v
src/Spec/CompleteEdwardsCurve.v
src/Spec/Ed25519.v
src/Spec/EdDSA.v
src/Spec/Encoding.v
src/Spec/ModularArithmetic.v
+src/Spec/ModularWordEncoding.v
src/Spec/PointEncoding.v
src/Specific/Ed25519.v
+src/Specific/GF1305.v
src/Specific/GF25519.v
src/Tactics/VerdiTactics.v
src/Util/CaseUtil.v
@@ -28,6 +42,7 @@ src/Util/IterAssocOp.v
src/Util/ListUtil.v
src/Util/NatUtil.v
src/Util/NumTheoryUtil.v
+src/Util/Tactics.v
src/Util/WordUtil.v
src/Util/ZUtil.v
src/Assembly/QhasmCommon.v
diff --git a/etc/freshen-bedrock-files.sh b/etc/freshen-bedrock-files.sh
index 08e0435d7..c2daf7e87 100755
--- a/etc/freshen-bedrock-files.sh
+++ b/etc/freshen-bedrock-files.sh
@@ -42,10 +42,10 @@ for VFILE in $FILES; do
break
fi
done
- )
if [ -z "$FOUND" ]; then
echo "WARNING: Could not find $VOFILE, which $VFILE depends on"
fi
+ )
done
done
diff --git a/src/BaseSystem.v b/src/BaseSystem.v
index 4e07c4564..e6ad55f18 100644
--- a/src/BaseSystem.v
+++ b/src/BaseSystem.v
@@ -5,19 +5,18 @@ Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith.
Local Open Scope Z.
-Module Type BaseCoefs.
- (** [BaseCoefs] represent the weights of each digit in a positional number system, with the weight of least significant digit presented first. The following requirements on the base are preconditions for using it with BaseSystem. *)
- Parameter base : list Z.
- Axiom base_positive : forall b, In b base -> b > 0. (* nonzero would probably work too... *)
- Axiom b0_1 : forall x, nth_default x base 0 = 1.
- Axiom base_good :
+Class BaseVector (base : list Z):= {
+ base_positive : forall b, In b base -> b > 0; (* nonzero would probably work too... *)
+ b0_1 : forall x, nth_default x base 0 = 1;
+ base_good :
forall i j, (i+j < length base)%nat ->
let b := nth_default 0 base in
let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
-End BaseCoefs.
+ b i * b j = r * b (i+j)%nat
+}.
-Module BaseSystem (Import B:BaseCoefs).
+Section BaseSystem.
+ Context (base : list Z).
(** [BaseSystem] implements an constrained positional number system.
A wide variety of bases are supported: the base coefficients are not
required to be powers of 2, and it is NOT necessarily the case that
@@ -33,7 +32,6 @@ Module BaseSystem (Import B:BaseCoefs).
Definition accumulate p acc := fst p * snd p + acc.
Definition decode' bs u := fold_right accumulate 0 (combine u bs).
Definition decode := decode' base.
- Hint Unfold accumulate.
(* Does not carry; z becomes the lowest and only digit. *)
Definition encode (z : Z) := z :: nil.
@@ -52,50 +50,7 @@ Module BaseSystem (Import B:BaseCoefs).
Hint Extern 1 (@eq Z _ _) => ring.
- Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
- Proof.
- unfold decode, decode'; induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma decode_nil : forall bs, decode' bs nil = 0.
- auto.
- Qed.
- Hint Rewrite decode_nil.
-
- Lemma decode_base_nil : forall us, decode' nil us = 0.
- Proof.
- intros; rewrite decode'_truncate; auto.
- Qed.
- Hint Rewrite decode_base_nil.
-
Definition mul_each u := map (Z.mul u).
- Lemma mul_each_rep : forall bs u vs,
- decode' bs (mul_each u vs) = u * decode' bs vs.
- Proof.
- unfold decode'; induction bs; destruct vs; boring.
- Qed.
-
- Lemma base_eq_1cons: base = 1 :: skipn 1 base.
- Proof.
- pose proof (b0_1 0) as H.
- destruct base; compute in H; try discriminate; boring.
- Qed.
-
- Lemma decode'_cons : forall x1 x2 xs1 xs2,
- decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
- Proof.
- unfold decode'; boring.
- Qed.
- Hint Rewrite decode'_cons.
-
- Lemma decode_cons : forall x us,
- decode (x :: us) = x + decode (0 :: us).
- Proof.
- unfold decode; intros.
- rewrite base_eq_1cons.
- autorewrite with core; ring_simplify; auto.
- Qed.
-
Fixpoint sub (us vs:digits) : digits :=
match us,vs with
| u::us', v::vs' => u-v :: sub us' vs'
@@ -103,76 +58,13 @@ Module BaseSystem (Import B:BaseCoefs).
| nil, v::vs' => (0-v)::sub nil vs'
end.
- Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
- Proof.
- induction bs; destruct us; destruct vs; boring.
- Qed.
-
- Lemma encode_rep : forall z, decode (encode z) = z.
- Proof.
- pose proof base_eq_1cons.
- unfold decode, encode; destruct z; boring.
- Qed.
-
- Lemma mul_each_base : forall us bs c,
- decode' bs (mul_each c us) = decode' (mul_each c bs) us.
- Proof.
- induction us; destruct bs; boring.
- Qed.
-
- Hint Rewrite (@nth_default_nil Z).
- Hint Rewrite (@firstn_nil Z).
- Hint Rewrite (@skipn_nil Z).
-
- Lemma base_app : forall us low high,
- decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
- Proof.
- induction us; destruct low; boring.
- Qed.
-
- Lemma base_mul_app : forall low c us,
- decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
- c * decode' low (skipn (length low) us).
- Proof.
- intros.
- rewrite base_app; f_equal.
- rewrite <- mul_each_rep.
- rewrite mul_each_base.
- reflexivity.
- Qed.
-
Definition crosscoef i j : Z :=
let b := nth_default 0 base in
(b(i) * b(j)) / b(i+j)%nat.
Hint Unfold crosscoef.
Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end.
- Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
- induction bs; destruct n; boring.
- Qed.
- Lemma length_zeros : forall n, length (zeros n) = n.
- induction n; boring.
- Qed.
- Hint Rewrite length_zeros.
-
- Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite app_zeros_zeros.
-
- Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
- Proof.
- induction m; boring.
- Qed.
- Hint Rewrite zeros_app0.
-
- Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
- Proof.
- induction n; boring.
- Qed.
- Hint Rewrite rev_zeros.
-
+
(* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *)
Fixpoint mul_bi' (i:nat) (vsr:digits) :=
match vsr with
@@ -181,311 +73,7 @@ Module BaseSystem (Import B:BaseCoefs).
end.
Definition mul_bi (i:nat) (vs:digits) : digits :=
zeros i ++ rev (mul_bi' i (rev vs)).
-
- Hint Unfold nth_default.
-
- Lemma decode_single : forall n bs x,
- decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
- Proof.
- induction n; destruct bs; boring.
- Qed.
- Hint Rewrite decode_single.
-
- Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
- Proof.
- boring.
- Qed.
- Hint Rewrite zeros_rep peel_decode.
-
- Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
- Proof.
- induction xs; destruct bs; boring.
- Qed.
-
- Lemma mul_bi'_zeros : forall n m, mul_bi' n (zeros m) = zeros m.
- induction m; boring.
- Qed.
- Hint Rewrite mul_bi'_zeros.
-
- Lemma nth_error_base_nonzero : forall n x,
- nth_error base n = Some x -> x <> 0.
- Proof.
- eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
- Qed.
-
- Hint Rewrite plus_0_r.
-
- Lemma mul_bi_single : forall m n x,
- (n + m < length base)%nat ->
- decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
- Proof.
- unfold mul_bi, decode.
- destruct m; simpl; simpl_list; simpl; intros. {
- pose proof nth_error_base_nonzero.
- boring; destruct base; nth_tac.
- rewrite Z_div_mul'; eauto.
- } {
- ssimpl_list.
- autorewrite with core.
- rewrite app_assoc.
- autorewrite with core.
- unfold crosscoef; simpl; ring_simplify.
- rewrite Nat.add_1_r.
- rewrite base_good by auto.
- rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_comm.
- rewrite <- Z.mul_assoc.
- rewrite <- Z.mul_assoc.
- destruct (Z.eq_dec x 0); subst; try ring.
- rewrite Z.mul_cancel_l by auto.
- rewrite <- base_good by auto.
- ring.
- }
- Qed.
-
- Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
- induction vs; boring; f_equal; ring.
- Qed.
-
- Lemma set_higher : forall bs vs x,
- decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
- Proof.
- intros.
- rewrite set_higher'.
- rewrite add_rep.
- f_equal.
- apply decode_single.
- Qed.
-
- Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
- induction n; auto.
- simpl; f_equal; auto.
- Qed.
-
- Lemma mul_bi'_n_nil : forall n, mul_bi' n nil = nil.
- Proof.
- unfold mul_bi; auto.
- Qed.
- Hint Rewrite mul_bi'_n_nil.
-
- Lemma add_nil_l : forall us, nil .+ us = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_l.
-
- Lemma add_nil_r : forall us, us .+ nil = us.
- induction us; auto.
- Qed.
- Hint Rewrite add_nil_r.
- Lemma add_first_terms : forall us vs a b,
- (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
- auto.
- Qed.
- Hint Rewrite add_first_terms.
-
- Lemma mul_bi'_cons : forall n x us,
- mul_bi' n (x :: us) = x * crosscoef n (length us) :: mul_bi' n us.
- Proof.
- unfold mul_bi'; auto.
- Qed.
-
- Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
- length (us .+ vs) = l.
- Proof.
- induction us, vs; boring.
- erewrite (IHus vs (pred l)); boring.
- Qed.
-
- Hint Rewrite app_nil_l.
- Hint Rewrite app_nil_r.
-
- Lemma add_snoc_same_length : forall l us vs a b,
- (length us = l) -> (length vs = l) ->
- (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
- Proof.
- induction l, us, vs; boring; discriminate.
- Qed.
-
- Lemma mul_bi'_add : forall us n vs l
- (Hlus: length us = l)
- (Hlvs: length vs = l),
- mul_bi' n (rev (us .+ vs)) =
- mul_bi' n (rev us) .+ mul_bi' n (rev vs).
- Proof.
- (* TODO(adamc): please help prettify this *)
- induction us using rev_ind;
- try solve [destruct vs; boring; congruence].
- destruct vs using rev_ind; boring; clear IHvs; simpl_list.
- erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
- repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
- rewrite (IHus n vs (pred l)).
- replace (length us) with (pred l).
- replace (length vs) with (pred l).
- rewrite (add_same_length us vs (pred l)).
- f_equal; ring.
-
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- erewrite length_snoc; eauto.
- Qed.
-
- Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
- auto.
- Qed.
-
- Lemma add_leading_zeros : forall n us vs,
- (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
- Proof.
- induction n; boring.
- Qed.
-
- Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
- (rev us) .+ (rev vs) = rev (us .+ vs).
- Proof.
- induction us, vs; boring; try solve [subst; discriminate].
- rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
- rewrite (IHus vs (pred l)) by omega; auto.
- Qed.
- Hint Rewrite rev_add_rev.
-
- Lemma mul_bi'_length : forall us n, length (mul_bi' n us) = length us.
- Proof.
- induction us, n; boring.
- Qed.
- Hint Rewrite mul_bi'_length.
-
- Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
- Proof.
- induction us, vs; boring; f_equal; auto.
- Qed.
-
- Hint Rewrite rev_length.
-
- Lemma mul_bi_add_same_length : forall n us vs l,
- (length us = l) -> (length vs = l) ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- unfold mul_bi; boring.
- rewrite add_leading_zeros.
- erewrite mul_bi'_add; boring.
- erewrite rev_add_rev; boring.
- Qed.
-
- Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
- Proof.
- induction us; boring; f_equal; omega.
- Qed.
-
- Hint Rewrite add_zeros_same_length.
- Hint Rewrite minus_diag.
-
- Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
- us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
- Proof.
- induction us, vs; boring; f_equal; boring.
- Qed.
-
- Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
- (length (us .+ vs) <= length us)%nat.
- Proof.
- intros.
- rewrite add_trailing_zeros by trivial.
- erewrite add_same_length by (pose proof app_length; boring); omega.
- Qed.
-
- Lemma add_length_le_max : forall us vs,
- (length (us .+ vs) <= max (length us) (length vs))%nat.
- Proof.
- intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
- (apply length_add_ge; omega) .
- Qed.
-
- Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
- Proof.
- induction us; boring.
- Qed.
-
- Lemma sub_length_le_max : forall us vs,
- (length (sub us vs) <= max (length us) (length vs))%nat.
- Proof.
- induction us, vs; boring.
- rewrite sub_nil_length; auto.
- Qed.
-
- Lemma mul_bi_length : forall us n, length (mul_bi n us) = (length us + n)%nat.
- Proof.
- pose proof mul_bi'_length; unfold mul_bi.
- destruct us; repeat progress (simpl_list; boring).
- Qed.
- Hint Rewrite mul_bi_length.
-
- Lemma mul_bi_trailing_zeros : forall m n us,
- mul_bi n us ++ zeros m = mul_bi n (us ++ zeros m).
- Proof.
- unfold mul_bi.
- induction m; intros; try solve [boring].
- rewrite <- zeros_app0.
- rewrite app_assoc.
- repeat progress (boring; rewrite rev_app_distr).
- Qed.
-
- Lemma mul_bi_add_longer : forall n us vs,
- (length us >= length vs)%nat ->
- mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs.
- Proof.
- boring.
- rewrite add_trailing_zeros by auto.
- rewrite (add_trailing_zeros (mul_bi n us) (mul_bi n vs))
- by (repeat (rewrite mul_bi_length); omega).
- erewrite mul_bi_add_same_length by
- (eauto; simpl_list; rewrite length_zeros; omega).
- rewrite mul_bi_trailing_zeros.
- repeat (f_equal; boring).
- Qed.
-
- Lemma mul_bi_add : forall n us vs,
- mul_bi n (us .+ vs) = (mul_bi n us) .+ (mul_bi n vs).
- Proof.
- intros; pose proof mul_bi_add_longer.
- destruct (le_ge_dec (length us) (length vs)). {
- replace (mul_bi n us .+ mul_bi n vs)
- with (mul_bi n vs .+ mul_bi n us)
- by (apply add_comm).
- replace (us .+ vs)
- with (vs .+ us)
- by (apply add_comm).
- boring.
- } {
- boring.
- }
- Qed.
-
- Lemma mul_bi_rep : forall i vs,
- (i + length vs < length base)%nat ->
- decode (mul_bi i vs) = decode vs * nth_default 0 base i.
- Proof.
- unfold decode.
- induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
- assert (i + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- rewrite set_higher.
- ring_simplify.
- rewrite <- IHvs by auto; clear IHvs.
- rewrite <- mul_bi_single by auto.
- rewrite <- add_rep.
- rewrite <- mul_bi_add.
- rewrite set_higher'.
- auto.
- Qed.
-
(* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
Fixpoint mul' (usr vs:digits) : digits :=
match usr with
@@ -495,66 +83,17 @@ Module BaseSystem (Import B:BaseCoefs).
end.
Definition mul us := mul' (rev us).
- Lemma mul'_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul' (rev us) vs) = decode us * decode vs.
- Proof.
- unfold decode.
- induction us using rev_ind; boring.
-
- assert (length us + length vs < length base)%nat by
- (rewrite app_length in *; boring).
-
- ssimpl_list.
- rewrite add_rep.
- boring.
- rewrite set_higher.
- rewrite mul_each_rep.
- rewrite mul_bi_rep by auto.
- unfold decode; ring.
- Qed.
-
- Lemma mul_rep : forall us vs,
- (length us + length vs <= length base)%nat ->
- decode (mul us vs) = decode us * decode vs.
- Proof.
- exact mul'_rep.
- Qed.
-
- Lemma mul'_length: forall us vs,
- (length (mul' us vs) <= length us + length vs)%nat.
- Proof.
- pose proof add_length_le_max.
- induction us; boring.
- unfold mul_each.
- simpl_list; case_max; boring; omega.
- Qed.
-
- Lemma mul_length: forall us vs,
- (length (mul us vs) <= length us + length vs)%nat.
- Proof.
- intros; unfold mul.
- rewrite mul'_length.
- rewrite rev_length; omega.
- Qed.
-
-(* Print Assumptions mul_rep. *)
-
End BaseSystem.
-Module Type PolynomialBaseParams.
- Parameter b1 : positive. (* the value at which the polynomial is evaluated *)
- Parameter baseLength : nat. (* 1 + degree of the polynomial *)
- Axiom baseLengthNonzero : NPeano.ltb 0 baseLength = true.
-End PolynomialBaseParams.
-
-Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs.
- (** PolynomialBaseCoeffs generates base vectors for [BaseSystem] using the extra assumption that $b_{i+j} = b_j b_j$. *)
+(* Example : polynomial base system *)
+Section PolynomialBaseCoefs.
+ Context (b1 : positive) (baseLength : nat) (baseLengthNonzero : NPeano.ltb 0 baseLength = true).
+ (** PolynomialBaseCoefs generates base vectors for [BaseSystem]. *)
Definition bi i := (Zpos b1)^(Z.of_nat i).
- Definition base := map bi (seq 0 baseLength).
+ Definition poly_base := map bi (seq 0 baseLength).
- Lemma b0_1 : forall x, nth_default x base 0 = 1.
- unfold base, bi, nth_default.
+ Lemma poly_b0_1 : forall x, nth_default x poly_base 0 = 1.
+ unfold poly_base, bi, nth_default.
case_eq baseLength; intros. {
assert ((0 < baseLength)%nat) by
(rewrite <-NPeano.ltb_lt; apply baseLengthNonzero).
@@ -563,9 +102,9 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs.
auto.
Qed.
- Lemma base_positive : forall b, In b base -> b > 0.
+ Lemma poly_base_positive : forall b, In b poly_base -> b > 0.
Proof.
- unfold base.
+ unfold poly_base.
intros until 0; intro H.
rewrite in_map_iff in *.
destruct H; destruct H.
@@ -573,20 +112,20 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs.
apply pos_pow_nat_pos.
Qed.
- Lemma base_defn : forall i, (i < length base)%nat ->
- nth_default 0 base i = bi i.
+ Lemma poly_base_defn : forall i, (i < length poly_base)%nat ->
+ nth_default 0 poly_base i = bi i.
Proof.
- unfold base, nth_default; nth_tac.
+ unfold poly_base, nth_default; nth_tac.
Qed.
- Lemma base_succ :
- forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
+ Lemma poly_base_succ :
+ forall i, ((S i) < length poly_base)%nat ->
+ let b := nth_default 0 poly_base in
let r := (b (S i) / b i) in
b (S i) = r * b i.
Proof.
intros; subst b; subst r.
- repeat rewrite base_defn in * by omega.
+ repeat rewrite poly_base_defn in * by omega.
unfold bi.
replace (Z.pos b1 ^ Z.of_nat (S i))
with (Z.pos b1 * (Z.pos b1 ^ Z.of_nat i)) by
@@ -598,13 +137,13 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs.
pose proof (Zgt_pos_0 b1); omega.
Qed.
- Lemma base_good:
- forall i j, (i + j < length base)%nat ->
- let b := nth_default 0 base in
+ Lemma poly_base_good:
+ forall i j, (i + j < length poly_base)%nat ->
+ let b := nth_default 0 poly_base in
let r := (b i * b j) / b (i+j)%nat in
b i * b j = r * b (i+j)%nat.
Proof.
- unfold base, nth_default; nth_tac.
+ unfold poly_base, nth_default; nth_tac.
clear.
unfold bi.
@@ -614,24 +153,25 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs.
rewrite <- Z.neq_mul_0.
split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H].
Qed.
+
+ Instance PolyBaseVector : BaseVector poly_base := {
+ base_positive := poly_base_positive;
+ b0_1 := poly_b0_1;
+ base_good := poly_base_good
+ }.
+
End PolynomialBaseCoefs.
-Module BasePoly2Degree32Params <: PolynomialBaseParams.
- Definition b1 := 2%positive.
+Import ListNotations.
+
+Section BaseSystemExample.
Definition baseLength := 32%nat.
Lemma baseLengthNonzero : NPeano.ltb 0 baseLength = true.
compute; reflexivity.
Qed.
-End BasePoly2Degree32Params.
-
-Import ListNotations.
-
-Module BaseSystemExample.
- Module BasePoly2Degree32Coefs := PolynomialBaseCoefs BasePoly2Degree32Params.
- Module BasePoly2Degree32 := BaseSystem BasePoly2Degree32Coefs.
- Import BasePoly2Degree32.
+ Definition base2 := poly_base 2 baseLength.
- Example three_times_two : mul [1;1;0] [0;1;0] = [0;1;1;0;0].
+ Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0].
Proof.
reflexivity.
Qed.
@@ -639,20 +179,20 @@ Module BaseSystemExample.
(* python -c "e = lambda x: '['+''.join(reversed(bin(x)[2:])).replace('1','1;').replace('0','0;')[:-1]+']'; print(e(19259)); print(e(41781))" *)
Definition a := [1;1;0;1;1;1;0;0;1;1;0;1;0;0;1].
Definition b := [1;0;1;0;1;1;0;0;1;1;0;0;0;1;0;1].
- Example da : decode a = 19259.
+ Example da : decode base2 a = 19259.
Proof.
reflexivity.
Qed.
- Example db : decode b = 41781.
+ Example db : decode base2 b = 41781.
Proof.
reflexivity.
Qed.
Example encoded_ab :
- mul a b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1].
+ mul base2 a b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1].
Proof.
reflexivity.
Qed.
- Example dab : decode (mul a b) = 804660279.
+ Example dab : decode base2 (mul base2 a b) = 804660279.
Proof.
reflexivity.
Qed.
diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v
new file mode 100644
index 000000000..ab56cb711
--- /dev/null
+++ b/src/BaseSystemProofs.v
@@ -0,0 +1,503 @@
+Require Import List.
+Require Import Util.ListUtil Util.CaseUtil Util.ZUtil.
+Require Import ZArith.ZArith ZArith.Zdiv.
+Require Import Omega NPeano Arith.
+Require Import Crypto.BaseSystem.
+Local Open Scope Z.
+
+Local Infix ".+" := add (at level 50).
+
+Local Hint Extern 1 (@eq Z _ _) => ring.
+
+Section BaseSystemProofs.
+ Context `(base_vector : BaseVector).
+
+ Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us).
+ Proof.
+ unfold decode'; intros; f_equal; apply combine_truncate_l.
+ Qed.
+
+ Lemma decode'_splice : forall xs ys bs,
+ decode' bs (xs ++ ys) =
+ decode' (firstn (length xs) bs) xs + decode' (skipn (length xs) bs) ys.
+ Proof.
+ unfold decode'.
+ induction xs; destruct ys, bs; boring.
+ + rewrite combine_truncate_r.
+ do 2 rewrite Z.add_0_r; auto.
+ + unfold accumulate.
+ apply Z.add_assoc.
+ Qed.
+
+ Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct us, vs; boring; ring.
+ Qed.
+
+ Lemma decode_nil : forall bs, decode' bs nil = 0.
+ auto.
+ Qed.
+ Hint Rewrite decode_nil.
+
+ Lemma decode_base_nil : forall us, decode' nil us = 0.
+ Proof.
+ intros; rewrite decode'_truncate; auto.
+ Qed.
+ Hint Rewrite decode_base_nil.
+
+ Lemma mul_each_rep : forall bs u vs,
+ decode' bs (mul_each u vs) = u * decode' bs vs.
+ Proof.
+ unfold decode', accumulate; induction bs; destruct vs; boring; ring.
+ Qed.
+
+ Lemma base_eq_1cons: base = 1 :: skipn 1 base.
+ Proof.
+ pose proof (b0_1 0) as H.
+ destruct base; compute in H; try discriminate; boring.
+ Qed.
+
+ Lemma decode'_cons : forall x1 x2 xs1 xs2,
+ decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2.
+ Proof.
+ unfold decode', accumulate; boring; ring.
+ Qed.
+ Hint Rewrite decode'_cons.
+
+ Lemma decode_cons : forall x us,
+ decode base (x :: us) = x + decode base (0 :: us).
+ Proof.
+ unfold decode; intros.
+ rewrite base_eq_1cons.
+ autorewrite with core; ring_simplify; auto.
+ Qed.
+
+ Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs.
+ Proof.
+ induction bs; destruct us; destruct vs; boring; ring.
+ Qed.
+
+ Lemma encode_rep : forall z, decode base (encode z) = z.
+ Proof.
+ pose proof base_eq_1cons.
+ unfold decode, encode; destruct z; boring.
+ Qed.
+
+ Lemma mul_each_base : forall us bs c,
+ decode' bs (mul_each c us) = decode' (mul_each c bs) us.
+ Proof.
+ induction us; destruct bs; boring; ring.
+ Qed.
+
+ Hint Rewrite (@nth_default_nil Z).
+ Hint Rewrite (@firstn_nil Z).
+ Hint Rewrite (@skipn_nil Z).
+
+ Lemma base_app : forall us low high,
+ decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us).
+ Proof.
+ induction us; destruct low; boring.
+ Qed.
+
+ Lemma base_mul_app : forall low c us,
+ decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) +
+ c * decode' low (skipn (length low) us).
+ Proof.
+ intros.
+ rewrite base_app; f_equal.
+ rewrite <- mul_each_rep.
+ rewrite mul_each_base.
+ reflexivity.
+ Qed.
+
+ Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0.
+ induction bs; destruct n; boring.
+ Qed.
+ Lemma length_zeros : forall n, length (zeros n) = n.
+ induction n; boring.
+ Qed.
+ Hint Rewrite length_zeros.
+
+ Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m).
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite app_zeros_zeros.
+
+ Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m).
+ Proof.
+ induction m; boring.
+ Qed.
+ Hint Rewrite zeros_app0.
+
+ Lemma rev_zeros : forall n, rev (zeros n) = zeros n.
+ Proof.
+ induction n; boring.
+ Qed.
+ Hint Rewrite rev_zeros.
+
+ Hint Unfold nth_default.
+
+ Lemma decode_single : forall n bs x,
+ decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x.
+ Proof.
+ induction n; destruct bs; boring.
+ Qed.
+ Hint Rewrite decode_single.
+
+ Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys.
+ Proof.
+ boring.
+ Qed.
+ Hint Rewrite zeros_rep peel_decode.
+
+ Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs.
+ Proof.
+ induction xs; destruct bs; boring.
+ Qed.
+
+ Lemma mul_bi'_zeros : forall n m, mul_bi' base n (zeros m) = zeros m.
+ induction m; boring.
+ Qed.
+ Hint Rewrite mul_bi'_zeros.
+
+ Lemma nth_error_base_nonzero : forall n x,
+ nth_error base n = Some x -> x <> 0.
+ Proof.
+ eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive.
+ Qed.
+
+ Hint Rewrite plus_0_r.
+
+ Lemma mul_bi_single : forall m n x,
+ (n + m < length base)%nat ->
+ decode base (mul_bi base n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n.
+ Proof.
+ unfold mul_bi, decode.
+ destruct m; simpl; simpl_list; simpl; intros. {
+ pose proof nth_error_base_nonzero as nth_nonzero.
+ case_eq base; [intros; boring | intros z l base_eq].
+ specialize (b0_1 0); intro b0_1'.
+ rewrite base_eq in *.
+ rewrite nth_default_cons in b0_1'.
+ rewrite b0_1' in *.
+ unfold crosscoef.
+ autounfold; autorewrite with core.
+ unfold nth_default.
+ nth_tac.
+ rewrite Z.mul_1_r.
+ rewrite Z_div_same_full.
+ destruct x; ring.
+ eapply nth_nonzero; eauto.
+ } {
+ ssimpl_list.
+ autorewrite with core.
+ rewrite app_assoc.
+ autorewrite with core.
+ unfold crosscoef; simpl; ring_simplify.
+ rewrite Nat.add_1_r.
+ rewrite base_good by auto.
+ rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto).
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_comm.
+ rewrite <- Z.mul_assoc.
+ rewrite <- Z.mul_assoc.
+ destruct (Z.eq_dec x 0); subst; try ring.
+ rewrite Z.mul_cancel_l by auto.
+ rewrite <- base_good by auto.
+ ring.
+ }
+ Qed.
+
+ Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil).
+ induction vs; boring; f_equal; ring.
+ Qed.
+
+ Lemma set_higher : forall bs vs x,
+ decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x.
+ Proof.
+ intros.
+ rewrite set_higher'.
+ rewrite add_rep.
+ f_equal.
+ apply decode_single.
+ Qed.
+
+ Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n.
+ induction n; auto.
+ simpl; f_equal; auto.
+ Qed.
+
+ Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil.
+ Proof.
+ unfold mul_bi; auto.
+ Qed.
+ Hint Rewrite mul_bi'_n_nil.
+
+ Lemma add_nil_l : forall us, nil .+ us = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_l.
+
+ Lemma add_nil_r : forall us, us .+ nil = us.
+ induction us; auto.
+ Qed.
+ Hint Rewrite add_nil_r.
+
+ Lemma add_first_terms : forall us vs a b,
+ (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs).
+ auto.
+ Qed.
+ Hint Rewrite add_first_terms.
+
+ Lemma mul_bi'_cons : forall n x us,
+ mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us.
+ Proof.
+ unfold mul_bi'; auto.
+ Qed.
+
+ Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) ->
+ length (us .+ vs) = l.
+ Proof.
+ induction us, vs; boring.
+ erewrite (IHus vs (pred l)); boring.
+ Qed.
+
+ Hint Rewrite app_nil_l.
+ Hint Rewrite app_nil_r.
+
+ Lemma add_snoc_same_length : forall l us vs a b,
+ (length us = l) -> (length vs = l) ->
+ (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil.
+ Proof.
+ induction l, us, vs; boring; discriminate.
+ Qed.
+
+ Lemma mul_bi'_add : forall us n vs l
+ (Hlus: length us = l)
+ (Hlvs: length vs = l),
+ mul_bi' base n (rev (us .+ vs)) =
+ mul_bi' base n (rev us) .+ mul_bi' base n (rev vs).
+ Proof.
+ (* TODO(adamc): please help prettify this *)
+ induction us using rev_ind;
+ try solve [destruct vs; boring; congruence].
+ destruct vs using rev_ind; boring; clear IHvs; simpl_list.
+ erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list.
+ repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list.
+ rewrite (IHus n vs (pred l)).
+ replace (length us) with (pred l).
+ replace (length vs) with (pred l).
+ rewrite (add_same_length us vs (pred l)).
+ f_equal; ring.
+
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ erewrite length_snoc; eauto.
+ Qed.
+
+ Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n).
+ auto.
+ Qed.
+
+ Lemma add_leading_zeros : forall n us vs,
+ (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs).
+ Proof.
+ induction n; boring.
+ Qed.
+
+ Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) ->
+ (rev us) .+ (rev vs) = rev (us .+ vs).
+ Proof.
+ induction us, vs; boring; try solve [subst; discriminate].
+ rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega).
+ rewrite (IHus vs (pred l)) by omega; auto.
+ Qed.
+ Hint Rewrite rev_add_rev.
+
+ Lemma mul_bi'_length : forall us n, length (mul_bi' base n us) = length us.
+ Proof.
+ induction us, n; boring.
+ Qed.
+ Hint Rewrite mul_bi'_length.
+
+ Lemma add_comm : forall us vs, us .+ vs = vs .+ us.
+ Proof.
+ induction us, vs; boring; f_equal; auto.
+ Qed.
+
+ Hint Rewrite rev_length.
+
+ Lemma mul_bi_add_same_length : forall n us vs l,
+ (length us = l) -> (length vs = l) ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ unfold mul_bi; boring.
+ rewrite add_leading_zeros.
+ erewrite mul_bi'_add; boring.
+ erewrite rev_add_rev; boring.
+ Qed.
+
+ Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us.
+ Proof.
+ induction us; boring; f_equal; omega.
+ Qed.
+
+ Hint Rewrite add_zeros_same_length.
+ Hint Rewrite minus_diag.
+
+ Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat ->
+ us .+ vs = us .+ (vs ++ (zeros (length us - length vs))).
+ Proof.
+ induction us, vs; boring; f_equal; boring.
+ Qed.
+
+ Lemma length_add_ge : forall us vs, (length us >= length vs)%nat ->
+ (length (us .+ vs) <= length us)%nat.
+ Proof.
+ intros.
+ rewrite add_trailing_zeros by trivial.
+ erewrite add_same_length by (pose proof app_length; boring); omega.
+ Qed.
+
+ Lemma add_length_le_max : forall us vs,
+ (length (us .+ vs) <= max (length us) (length vs))%nat.
+ Proof.
+ intros; case_max; (rewrite add_comm; apply length_add_ge; omega) ||
+ (apply length_add_ge; omega) .
+ Qed.
+
+ Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us.
+ Proof.
+ induction us; boring.
+ Qed.
+
+ Lemma sub_length_le_max : forall us vs,
+ (length (sub us vs) <= max (length us) (length vs))%nat.
+ Proof.
+ induction us, vs; boring.
+ rewrite sub_nil_length; auto.
+ Qed.
+
+ Lemma mul_bi_length : forall us n, length (mul_bi base n us) = (length us + n)%nat.
+ Proof.
+ pose proof mul_bi'_length; unfold mul_bi.
+ destruct us; repeat progress (simpl_list; boring).
+ Qed.
+ Hint Rewrite mul_bi_length.
+
+ Lemma mul_bi_trailing_zeros : forall m n us,
+ mul_bi base n us ++ zeros m = mul_bi base n (us ++ zeros m).
+ Proof.
+ unfold mul_bi.
+ induction m; intros; try solve [boring].
+ rewrite <- zeros_app0.
+ rewrite app_assoc.
+ repeat progress (boring; rewrite rev_app_distr).
+ Qed.
+
+ Lemma mul_bi_add_longer : forall n us vs,
+ (length us >= length vs)%nat ->
+ mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs.
+ Proof.
+ boring.
+ rewrite add_trailing_zeros by auto.
+ rewrite (add_trailing_zeros (mul_bi base n us) (mul_bi base n vs))
+ by (repeat (rewrite mul_bi_length); omega).
+ erewrite mul_bi_add_same_length by
+ (eauto; simpl_list; rewrite length_zeros; omega).
+ rewrite mul_bi_trailing_zeros.
+ repeat (f_equal; boring).
+ Qed.
+
+ Lemma mul_bi_add : forall n us vs,
+ mul_bi base n (us .+ vs) = (mul_bi base n us) .+ (mul_bi base n vs).
+ Proof.
+ intros; pose proof mul_bi_add_longer.
+ destruct (le_ge_dec (length us) (length vs)). {
+ rewrite add_comm.
+ rewrite (add_comm (mul_bi base n us)).
+ boring.
+ } {
+ boring.
+ }
+ Qed.
+
+ Lemma mul_bi_rep : forall i vs,
+ (i + length vs < length base)%nat ->
+ decode base (mul_bi base i vs) = decode base vs * nth_default 0 base i.
+ Proof.
+ unfold decode.
+ induction vs using rev_ind; intros; try solve [unfold mul_bi; boring].
+ assert (i + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ rewrite set_higher.
+ ring_simplify.
+ rewrite <- IHvs by auto; clear IHvs.
+ rewrite <- mul_bi_single by auto.
+ rewrite <- add_rep.
+ rewrite <- mul_bi_add.
+ rewrite set_higher'.
+ auto.
+ Qed.
+
+ (* mul' is multiplication with the FIRST ARGUMENT REVERSED *)
+ Fixpoint mul' (usr vs:digits) : digits :=
+ match usr with
+ | u::usr' =>
+ mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs
+ | _ => nil
+ end.
+ Definition mul us := mul' (rev us).
+
+ Lemma mul'_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul' (rev us) vs) = decode base us * decode base vs.
+ Proof.
+ unfold decode.
+ induction us using rev_ind; boring.
+
+ assert (length us + length vs < length base)%nat by
+ (rewrite app_length in *; boring).
+
+ ssimpl_list.
+ rewrite add_rep.
+ boring.
+ rewrite set_higher.
+ rewrite mul_each_rep.
+ rewrite mul_bi_rep by auto.
+ unfold decode; ring.
+ Qed.
+
+ Lemma mul_rep : forall us vs,
+ (length us + length vs <= length base)%nat ->
+ decode base (mul us vs) = decode base us * decode base vs.
+ Proof.
+ exact mul'_rep.
+ Qed.
+
+ Lemma mul'_length: forall us vs,
+ (length (mul' us vs) <= length us + length vs)%nat.
+ Proof.
+ pose proof add_length_le_max.
+ induction us; boring.
+ unfold mul_each.
+ simpl_list; case_max; boring; omega.
+ Qed.
+
+ Lemma mul_length: forall us vs,
+ (length (mul us vs) <= length us + length vs)%nat.
+ Proof.
+ intros; unfold mul.
+ rewrite mul'_length.
+ rewrite rev_length; omega.
+ Qed.
+
+
+End BaseSystemProofs.
diff --git a/src/BoundedIterOp.v b/src/BoundedIterOp.v
deleted file mode 100644
index bd4f7d66e..000000000
--- a/src/BoundedIterOp.v
+++ /dev/null
@@ -1,94 +0,0 @@
-Require Import Crypto.Tactics.VerdiTactics.
-Require Import Coq.Numbers.BinNums Coq.NArith.NArith Coq.NArith.Nnat Coq.ZArith.ZArith.
-
-Definition testbit_rev p i bound := Pos.testbit_nat p (bound - i).
-
-Lemma testbit_rev_succ : forall p i b, (i < S b) ->
- testbit_rev p i (S b) =
- match p with
- | xI p' => testbit_rev p' i b
- | xO p' => testbit_rev p' i b
- | 1%positive => false
- end.
-Proof.
- unfold testbit_rev; intros; destruct p; rewrite <- minus_Sn_m by omega; auto.
-Qed.
-
-(* implements Pos.iter_op only using testbit, not destructing the positive *)
-Definition iter_op {A} (op : A -> A -> A) (zero : A) (bound : nat) (p : positive) :=
- fix iter (i : nat) (a : A) {struct i} : A :=
- match i with
- | O => zero
- | S i' => let ret := iter i' (op a a) in
- if testbit_rev p i bound
- then op a ret
- else ret
- end.
-
-Lemma iter_op_step : forall {A} op z b p i (a : A), (i < S b) ->
- iter_op op z (S b) p i a =
- match p with
- | xI p' => iter_op op z b p' i a
- | xO p' => iter_op op z b p' i a
- | 1%positive => z
- end.
-Proof.
- destruct p; unfold iter_op; (induction i; [ auto |]); intros; rewrite testbit_rev_succ by omega; rewrite IHi by omega; reflexivity.
-Qed.
-
-Lemma pos_size_gt0 : forall p, 0 < Pos.size_nat p.
-Proof.
- destruct p; intros; auto; try apply Lt.lt_0_Sn.
-Qed.
-Hint Resolve pos_size_gt0.
-
-Lemma iter_op_spec : forall b p {A} op z (a : A) (zero_id : forall x : A, op x z = x), (Pos.size_nat p <= b) ->
- iter_op op z b p b a = Pos.iter_op op p a.
-Proof.
- induction b; intros; [pose proof (pos_size_gt0 p); omega |].
- destruct p; simpl; rewrite iter_op_step by omega; unfold testbit_rev; rewrite Minus.minus_diag; try rewrite IHb; simpl in *; auto; omega.
-Qed.
-
-Lemma xO_neq1 : forall p, (1 < p~0)%positive.
-Proof.
- induction p; auto; apply Pos.lt_succ_diag_r.
-Qed.
-
-Lemma xI_neq1 : forall p, (1 < p~1)%positive.
-Proof.
- induction p; auto; eapply Pos.lt_trans; apply Pos.lt_succ_diag_r.
-Qed.
-
-Lemma xI_is_succ : forall n p, Pos.of_succ_nat n = p~1%positive ->
- (Pos.to_nat (2 * p))%nat = n.
-Proof.
- induction n; intros; try discriminate.
- rewrite <- Pnat.Nat2Pos.id by apply NPeano.Nat.neq_succ_0.
- rewrite Pnat.Pos2Nat.inj_iff.
- rewrite <- Pos.of_nat_succ.
- apply Pos.succ_inj.
- rewrite <- Pos.xI_succ_xO.
- auto.
-Qed.
-
-Lemma xO_is_succ : forall n p, Pos.of_succ_nat n = p~0%positive ->
- Pos.to_nat (Pos.pred_double p) = n.
-Proof.
- induction n; intros; try discriminate.
- rewrite Pos.pred_double_spec.
- rewrite <- Pnat.Pos2Nat.inj_iff in *.
- rewrite Pnat.Pos2Nat.inj_xO in *.
- rewrite Pnat.SuccNat2Pos.id_succ in *.
- rewrite Pnat.Pos2Nat.inj_pred by apply xO_neq1.
- rewrite <- NPeano.Nat.succ_inj_wd.
- rewrite Pnat.Pos2Nat.inj_xO.
- omega.
-Qed.
-
-Lemma size_of_succ : forall n,
- Pos.size_nat (Pos.of_nat n) <= Pos.size_nat (Pos.of_nat (S n)).
-Proof.
- intros; induction n; [simpl; auto|].
- apply Pos.size_nat_monotone.
- apply Pos.lt_succ_diag_r.
-Qed.
diff --git a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
index 3740f5a29..f70479c3a 100644
--- a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
+++ b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v
@@ -7,107 +7,294 @@ Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Coq.Logic.Eqdep_dec.
Require Import Crypto.Tactics.VerdiTactics.
-Section CompleteEdwardsCurveTheorems.
- Context {prm:TwistedEdwardsParams}.
- Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *)
- Existing Instance prime_q.
+Module E.
+ Section CompleteEdwardsCurveTheorems.
+ Context {prm:TwistedEdwardsParams}.
+ Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *)
+ Existing Instance prime_q.
+
+ Add Field Ffield_p' : (@Ffield_theory q _)
+ (morphism (@Fring_morph q),
+ preprocess [Fpreprocess],
+ postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
+ constants [Fconstant],
+ div (@Fmorph_div_theory q),
+ power_tac (@Fpower_theory q) [Fexp_tac]).
+
+ Add Field Ffield_notConstant : (OpaqueFieldTheory q)
+ (constants [notConstant]).
+
+ Ltac clear_prm :=
+ generalize dependent a; intro a; intros;
+ generalize dependent d; intro d; intros;
+ generalize dependent prime_q; intro prime_q; intros;
+ generalize dependent q; intro q; intros;
+ clear prm.
+
+ Lemma point_eq : forall xy1 xy2 pf1 pf2,
+ xy1 = xy2 -> exist E.onCurve xy1 pf1 = exist E.onCurve xy2 pf2.
+ Proof.
+ destruct xy1, xy2; intros; find_injection; intros; subst. apply f_equal.
+ apply UIP_dec, F_eq_dec. (* this is a hack. We actually don't care about the equality of the proofs. However, we *can* prove it, and knowing it lets us use the universal equality instead of a type-specific equivalence, which makes many things nicer. *)
+ Qed. Hint Resolve point_eq.
+
+ Definition point_eqb (p1 p2:E.point) : bool := andb
+ (F_eqb (fst (proj1_sig p1)) (fst (proj1_sig p2)))
+ (F_eqb (snd (proj1_sig p1)) (snd (proj1_sig p2))).
+
+ Local Ltac t :=
+ unfold point_eqb;
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress simpl in *
+ | _ => progress subst
+ | [P:E.point |- _ ] => destruct P
+ | [x: (F q * F q)%type |- _ ] => destruct x
+ | [H: _ /\ _ |- _ ] => destruct H
+ | [H: _ |- _ ] => rewrite Bool.andb_true_iff in H
+ | [H: _ |- _ ] => apply F_eqb_eq in H
+ | _ => rewrite F_eqb_refl
+ end; eauto.
+
+ Lemma point_eqb_sound : forall p1 p2, point_eqb p1 p2 = true -> p1 = p2.
+ Proof.
+ t.
+ Qed.
+
+ Lemma point_eqb_complete : forall p1 p2, p1 = p2 -> point_eqb p1 p2 = true.
+ Proof.
+ t.
+ Qed.
+
+ Lemma point_eqb_neq : forall p1 p2, point_eqb p1 p2 = false -> p1 <> p2.
+ Proof.
+ intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition.
+ apply point_eqb_complete in H0; congruence.
+ Qed.
+
+ Lemma point_eqb_neq_complete : forall p1 p2, p1 <> p2 -> point_eqb p1 p2 = false.
+ Proof.
+ intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition.
+ apply point_eqb_sound in Hneq. congruence.
+ Qed.
+
+ Lemma point_eqb_refl : forall p, point_eqb p p = true.
+ Proof.
+ t.
+ Qed.
+
+ Definition point_eq_dec (p1 p2:E.point) : {p1 = p2} + {p1 <> p2}.
+ destruct (point_eqb p1 p2) eqn:H; match goal with
+ | [ H: _ |- _ ] => apply point_eqb_sound in H
+ | [ H: _ |- _ ] => apply point_eqb_neq in H
+ end; eauto.
+ Qed.
+
+ Lemma point_eqb_correct : forall p1 p2, point_eqb p1 p2 = if point_eq_dec p1 p2 then true else false.
+ Proof.
+ intros. destruct (point_eq_dec p1 p2); eauto using point_eqb_complete, point_eqb_neq_complete.
+ Qed.
+
+ Ltac Edefn := unfold E.add, E.add', E.zero; intros;
+ repeat match goal with
+ | [ p : E.point |- _ ] =>
+ let x := fresh "x" p in
+ let y := fresh "y" p in
+ let pf := fresh "pf" p in
+ destruct p as [[x y] pf]; unfold E.onCurve in pf
+ | _ => eapply point_eq, (f_equal2 pair)
+ | _ => eapply point_eq
+ end.
+ Lemma add_comm : forall A B, (A+B = B+A)%E.
+ Proof.
+ Edefn; apply (f_equal2 div); ring.
+ Qed.
+
+ Ltac unifiedAdd_nonzero := match goal with
+ | [ |- (?op 1 (d * _ * _ * _ * _ *
+ inv (1 - d * ?xA * ?xB * ?yA * ?yB) * inv (1 + d * ?xA * ?xB * ?yA * ?yB)))%F <> 0%F]
+ => let Hadd := fresh "Hadd" in
+ pose proof (@unifiedAdd'_onCurve _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d (xA, yA) (xB, yB)) as Hadd;
+ simpl in Hadd;
+ match goal with
+ | [H : (1 - d * ?xC * xB * ?yC * yB)%F <> 0%F |- (?op 1 ?other)%F <> 0%F] =>
+ replace other with
+ (d * xC * ((xA * yB + yA * xB) / (1 + d * xA * xB * yA * yB))
+ * yC * ((yA * yB - a * xA * xB) / (1 - d * xA * xB * yA * yB)))%F by (subst; unfold div; ring);
+ auto
+ end
+ end.
+
+ Lemma add_assoc : forall A B C, (A+(B+C) = (A+B)+C)%E.
+ Proof.
+ Edefn; F_field_simplify_eq; try abstract (rewrite ?@F_pow_2_r in *; clear_prm; F_nsatz);
+ pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d);
+ pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d);
+ cbv beta iota in *;
+ repeat split; field_nonzero idtac; unifiedAdd_nonzero.
+ Qed.
+
+ Lemma add_0_r : forall P, (P + E.zero = P)%E.
+ Proof.
+ Edefn; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r,
+ ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; exact eq_refl.
+ Qed.
- Add Field Ffield_p' : (@Ffield_theory q _)
- (morphism (@Fring_morph q),
- preprocess [Fpreprocess],
- postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
- constants [Fconstant],
- div (@Fmorph_div_theory q),
- power_tac (@Fpower_theory q) [Fexp_tac]).
+ Lemma add_0_l : forall P, (E.zero + P)%E = P.
+ Proof.
+ intros; rewrite add_comm. apply add_0_r.
+ Qed.
- Add Field Ffield_notConstant : (OpaqueFieldTheory q)
- (constants [notConstant]).
+ Lemma mul_0_l : forall P, (0 * P = E.zero)%E.
+ Proof.
+ auto.
+ Qed.
- Ltac clear_prm :=
- generalize dependent a; intro a; intros;
- generalize dependent d; intro d; intros;
- generalize dependent prime_q; intro prime_q; intros;
- generalize dependent q; intro q; intros;
- clear prm.
+ Lemma mul_S_l : forall n P, (S n * P)%E = (P + n * P)%E.
+ Proof.
+ auto.
+ Qed.
- Lemma point_eq : forall p1 p2, p1 = p2 -> forall pf1 pf2,
- mkPoint p1 pf1 = mkPoint p2 pf2.
- Proof.
- destruct p1, p2; intros; find_injection; intros; subst; apply f_equal.
- apply UIP_dec, F_eq_dec. (* this is a hack. We actually don't care about the equality of the proofs. However, we *can* prove it, and knowing it lets us use the universal equality instead of a type-specific equivalence, which makes many things nicer. *)
- Qed.
- Hint Resolve point_eq.
+ Lemma mul_add_l : forall a b P, ((a + b)%nat * P)%E = E.add (a * P)%E (b * P)%E.
+ Proof.
+ induction a; intros; rewrite ?plus_Sn_m, ?plus_O_n, ?mul_S_l, ?mul_0_l, ?add_0_l, ?mul_S_, ?IHa, ?add_assoc; auto.
+ Qed.
- Ltac Edefn := unfold unifiedAdd, unifiedAdd', zero; intros;
- repeat match goal with
- | [ p : point |- _ ] =>
- let x := fresh "x" p in
- let y := fresh "y" p in
- let pf := fresh "pf" p in
- destruct p as [[x y] pf]; unfold onCurve in pf
- | _ => eapply point_eq, (f_equal2 pair)
- | _ => eapply point_eq
- end.
- Lemma twistedAddComm : forall A B, (A+B = B+A)%E.
- Proof.
- Edefn; apply (f_equal2 div); ring.
- Qed.
+ Lemma mul_assoc : forall (n m : nat) P, (n * (m * P) = (n * m)%nat * P)%E.
+ Proof.
+ induction n; intros; auto.
+ rewrite ?mul_S_l, ?Mult.mult_succ_l, ?mul_add_l, ?IHn, add_comm. reflexivity.
+ Qed.
- Lemma twistedAddAssoc : forall A B C, (A+(B+C) = (A+B)+C)%E.
- Proof.
- (* The Ltac takes ~15s, the Qed no longer takes longer than I have had patience for *)
- Edefn; F_field_simplify_eq; try abstract (rewrite ?@F_pow_2_r in *; clear_prm; F_nsatz);
- repeat split; match goal with [ |- _ = 0%F -> False ] => admit end;
- fail "unreachable".
- Qed.
-
- Lemma zeroIsIdentity : forall P, (P + zero = P)%E.
- Proof.
- Edefn; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r,
- ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; exact eq_refl.
- Qed.
-
- (* solve for x ^ 2 *)
- Definition solve_for_x2 (y : F q) := ((y ^ 2 - 1) / (d * (y ^ 2) - a))%F.
-
- Lemma d_y2_a_nonzero : (forall y, 0 <> d * y ^ 2 - a)%F.
- intros ? eq_zero.
- pose proof prime_q.
- destruct square_a as [sqrt_a sqrt_a_id].
- rewrite <- sqrt_a_id in eq_zero.
- destruct (Fq_square_mul_sub _ _ _ eq_zero) as [ [sqrt_d sqrt_d_id] | a_zero].
- + pose proof (nonsquare_d sqrt_d); auto.
- + subst.
- rewrite Fq_pow_zero in sqrt_a_id by congruence.
- auto using nonzero_a.
- Qed.
-
- Lemma a_d_y2_nonzero : (forall y, a - d * y ^ 2 <> 0)%F.
- Proof.
- intros y eq_zero.
- pose proof prime_q.
- eapply F_minus_swap in eq_zero.
- eauto using (d_y2_a_nonzero y).
- Qed.
-
- Lemma solve_correct : forall x y, onCurve (x, y) <->
- (x ^ 2 = solve_for_x2 y)%F.
- Proof.
- split.
- + intro onCurve_x_y.
+ Lemma mul_zero_r : forall m, (m * E.zero = E.zero)%E.
+ Proof.
+ induction m; rewrite ?mul_S_l, ?add_0_l; auto.
+ Qed.
+
+ (* solve for x ^ 2 *)
+ Definition solve_for_x2 (y : F q) := ((y ^ 2 - 1) / (d * (y ^ 2) - a))%F.
+
+ Lemma d_y2_a_nonzero : (forall y, 0 <> d * y ^ 2 - a)%F.
+ intros ? eq_zero.
pose proof prime_q.
- unfold onCurve in onCurve_x_y.
- eapply F_div_mul; auto using (d_y2_a_nonzero y).
- replace (x ^ 2 * (d * y ^ 2 - a))%F with ((d * x ^ 2 * y ^ 2) - (a * x ^ 2))%F by ring.
- rewrite F_sub_add_swap.
- replace (y ^ 2 + a * x ^ 2)%F with (a * x ^ 2 + y ^ 2)%F by ring.
- rewrite onCurve_x_y.
- ring.
- + intro x2_eq.
- unfold onCurve, solve_for_x2 in *.
- rewrite x2_eq.
- field.
- auto using d_y2_a_nonzero.
- Qed.
-
-End CompleteEdwardsCurveTheorems.
+ destruct square_a as [sqrt_a sqrt_a_id].
+ rewrite <- sqrt_a_id in eq_zero.
+ destruct (Fq_square_mul_sub _ _ _ eq_zero) as [ [sqrt_d sqrt_d_id] | a_zero].
+ + pose proof (nonsquare_d sqrt_d); auto.
+ + subst.
+ rewrite Fq_pow_zero in sqrt_a_id by congruence.
+ auto using nonzero_a.
+ Qed.
+
+ Lemma a_d_y2_nonzero : (forall y, a - d * y ^ 2 <> 0)%F.
+ Proof.
+ intros y eq_zero.
+ pose proof prime_q.
+ eapply F_minus_swap in eq_zero.
+ eauto using (d_y2_a_nonzero y).
+ Qed.
+
+ Lemma solve_correct : forall x y, E.onCurve (x, y) <->
+ (x ^ 2 = solve_for_x2 y)%F.
+ Proof.
+ split.
+ + intro onCurve_x_y.
+ pose proof prime_q.
+ unfold E.onCurve in onCurve_x_y.
+ eapply F_div_mul; auto using (d_y2_a_nonzero y).
+ replace (x ^ 2 * (d * y ^ 2 - a))%F with ((d * x ^ 2 * y ^ 2) - (a * x ^ 2))%F by ring.
+ rewrite F_sub_add_swap.
+ replace (y ^ 2 + a * x ^ 2)%F with (a * x ^ 2 + y ^ 2)%F by ring.
+ rewrite onCurve_x_y.
+ ring.
+ + intro x2_eq.
+ unfold E.onCurve, solve_for_x2 in *.
+ rewrite x2_eq.
+ field.
+ auto using d_y2_a_nonzero.
+ Qed.
+
+
+ Program Definition opp (P:E.point) : E.point := let '(x, y) := proj1_sig P in (opp x, y).
+ Next Obligation. Proof.
+ pose (proj2_sig P) as H; rewrite <-Heq_anonymous in H; simpl in H.
+ rewrite F_square_opp; trivial.
+ Qed.
+
+ Definition sub P Q := (P + opp Q)%E.
+
+ Lemma opp_zero : opp E.zero = E.zero.
+ Proof.
+ pose proof @F_opp_0.
+ unfold opp, E.zero; eapply point_eq; congruence.
+ Qed.
+
+ Lemma add_opp_r : forall P, (P + opp P = E.zero)%E.
+ Proof.
+ unfold opp; Edefn; rewrite ?@F_pow_2_r in *; (F_field_simplify_eq; [clear_prm; F_nsatz|..]);
+ rewrite <-?@F_pow_2_r in *;
+ pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP);
+ pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP);
+ field_nonzero idtac.
+ Qed.
+
+ Lemma add_opp_l : forall P, (opp P + P = E.zero)%E.
+ Proof.
+ intros. rewrite add_comm. eapply add_opp_r.
+ Qed.
+
+ Lemma add_cancel_r : forall A B C, (B+A = C+A -> B = C)%E.
+ Proof.
+ intros.
+ assert ((B + A) + opp A = (C + A) + opp A)%E as Hc by congruence.
+ rewrite <-!add_assoc, !add_opp_r, !add_0_r in Hc; exact Hc.
+ Qed.
+
+ Lemma add_cancel_l : forall A B C, (A+B = A+C -> B = C)%E.
+ Proof.
+ intros.
+ rewrite (add_comm A C) in H.
+ rewrite (add_comm A B) in H.
+ eauto using add_cancel_r.
+ Qed.
+
+ Lemma shuffle_eq_add_opp : forall P Q R, (P + Q = R <-> Q = opp P + R)%E.
+ Proof.
+ split; intros.
+ { assert (opp P + (P + Q) = opp P + R)%E as Hc by congruence.
+ rewrite add_assoc, add_opp_l, add_comm, add_0_r in Hc; exact Hc. }
+ { subst. rewrite add_assoc, add_opp_r, add_comm, add_0_r; reflexivity. }
+ Qed.
+
+ Lemma opp_opp : forall P, opp (opp P) = P.
+ Proof.
+ intros.
+ pose proof (add_opp_r P%E) as H.
+ rewrite add_comm in H.
+ rewrite shuffle_eq_add_opp in H.
+ rewrite add_0_r in H.
+ congruence.
+ Qed.
+
+ Lemma opp_add : forall P Q, opp (P + Q)%E = (opp P + opp Q)%E.
+ Proof.
+ intros.
+ pose proof (add_opp_r (P+Q)%E) as H.
+ rewrite <-!add_assoc in H.
+ rewrite add_comm in H.
+ rewrite <-!add_assoc in H.
+ rewrite shuffle_eq_add_opp in H.
+ rewrite add_comm in H.
+ rewrite shuffle_eq_add_opp in H.
+ rewrite add_0_r in H.
+ assumption.
+ Qed.
+
+ Lemma opp_mul : forall n P, opp (E.mul n P) = E.mul n (opp P).
+ Proof.
+ pose proof opp_add; pose proof opp_zero.
+ induction n; simpl; intros; congruence.
+ Qed.
+ End CompleteEdwardsCurveTheorems.
+End E.
+Infix "-" := E.sub : E_scope. \ No newline at end of file
diff --git a/src/CompleteEdwardsCurve/DoubleAndAdd.v b/src/CompleteEdwardsCurve/DoubleAndAdd.v
index 84c1289f6..50027349d 100644
--- a/src/CompleteEdwardsCurve/DoubleAndAdd.v
+++ b/src/CompleteEdwardsCurve/DoubleAndAdd.v
@@ -1,71 +1,30 @@
Require Import Crypto.Tactics.VerdiTactics.
-Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.BoundedIterOp.
+Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Util.IterAssocOp.
Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
Require Import Coq.Numbers.BinNums Coq.NArith.NArith Coq.NArith.Nnat Coq.ZArith.ZArith.
Section EdwardsDoubleAndAdd.
Context {prm:TwistedEdwardsParams}.
- Definition doubleAndAdd (b n : nat) (P : point) : point :=
- match N.of_nat n with
- | 0%N => zero
- | N.pos p => iter_op unifiedAdd zero b p b P
- end.
+ Definition doubleAndAdd (bound n : nat) (P : E.point) : E.point :=
+ iter_op E.add E.zero N.testbit_nat (N.of_nat n) P bound.
- Lemma scalarMult_double : forall n P, scalarMult (n + n) P = scalarMult n (P + P)%E.
+ Lemma scalarMult_double : forall n P, E.mul (n + n) P = E.mul n (P + P)%E.
Proof.
intros.
replace (n + n)%nat with (n * 2)%nat by omega.
induction n; simpl; auto.
- rewrite twistedAddAssoc.
+ rewrite E.add_assoc.
f_equal; auto.
Qed.
- Lemma iter_op_double : forall p P,
- Pos.iter_op unifiedAdd (p + p) P = Pos.iter_op unifiedAdd p (P + P)%E.
+ Lemma doubleAndAdd_spec : forall bound n P, N.size_nat (N.of_nat n) <= bound ->
+ E.mul n P = doubleAndAdd bound n P.
Proof.
- intros.
- rewrite Pos.add_diag.
- unfold Pos.iter_op; simpl.
- auto.
- Qed.
-
- Lemma doubleAndAdd_spec : forall n b P, (Pos.size_nat (Pos.of_nat n) <= b) ->
- scalarMult n P = doubleAndAdd b n P.
- Proof.
- induction n; auto; intros.
- unfold doubleAndAdd; simpl.
- rewrite Pos.of_nat_succ.
- rewrite iter_op_spec by (auto using zeroIsIdentity).
- case_eq (Pos.of_nat (S n)); intros. {
- simpl; f_equal.
- rewrite (IHn b) by (pose proof (size_of_succ n); omega).
- unfold doubleAndAdd.
- rewrite H0 in H.
- rewrite <- Pos.of_nat_succ in H0.
- rewrite <- (xI_is_succ n p) by apply H0.
- rewrite Znat.positive_nat_N; simpl.
- rewrite iter_op_spec; auto using zeroIsIdentity.
- } {
- simpl; f_equal.
- rewrite (IHn b) by (pose proof (size_of_succ n); omega).
- unfold doubleAndAdd.
- rewrite <- (xO_is_succ n p) by (rewrite Pos.of_nat_succ; auto).
- rewrite Znat.positive_nat_N; simpl.
- rewrite <- Pos.succ_pred_double in H0.
- rewrite H0 in H.
- rewrite iter_op_spec by (auto using zeroIsIdentity;
- pose proof (Pos.lt_succ_diag_r (Pos.pred_double p));
- apply Pos.size_nat_monotone in H1; omega; auto).
- rewrite <- iter_op_double.
- rewrite Pos.add_diag.
- rewrite <- Pos.succ_pred_double.
- rewrite Pos.iter_op_succ by apply twistedAddAssoc; auto.
- } {
- rewrite <- Pnat.Pos2Nat.inj_iff in H0.
- rewrite Pnat.Nat2Pos.id in H0 by auto.
- rewrite Pnat.Pos2Nat.inj_1 in H0.
- assert (n = 0)%nat by omega; subst.
- auto using zeroIsIdentity.
- }
+ induction n; auto; intros; unfold doubleAndAdd;
+ rewrite iter_op_spec with (scToN := fun x => x); (
+ unfold Morphisms.Proper, Morphisms.respectful, Equivalence.equiv;
+ intros; subst; try rewrite Nat2N.id;
+ reflexivity || assumption || apply E.add_assoc
+ || rewrite E.add_comm; apply E.add_0_r).
Qed.
End EdwardsDoubleAndAdd. \ No newline at end of file
diff --git a/src/CompleteEdwardsCurve/ExtendedCoordinates.v b/src/CompleteEdwardsCurve/ExtendedCoordinates.v
index e918ac128..e91bc084b 100644
--- a/src/CompleteEdwardsCurve/ExtendedCoordinates.v
+++ b/src/CompleteEdwardsCurve/ExtendedCoordinates.v
@@ -3,7 +3,7 @@ Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.FField.
Require Import Crypto.Tactics.VerdiTactics.
-Require Import Util.IterAssocOp BinNat NArith.
+Require Import Util.IterAssocOp BinNat NArith.
Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence.
Local Open Scope equiv_scope.
Local Open Scope F_scope.
@@ -19,10 +19,10 @@ Section ExtendedCoordinates.
postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
constants [Fconstant],
div (@Fmorph_div_theory q),
- power_tac (@Fpower_theory q) [Fexp_tac]).
+ power_tac (@Fpower_theory q) [Fexp_tac]).
Add Field Ffield_notConstant : (OpaqueFieldTheory q)
- (constants [notConstant]).
+ (constants [notConstant]).
(** [extended] represents a point on an elliptic curve using extended projective
* Edwards coordinates with twist a=-1 (see <https://eprint.iacr.org/2008/522.pdf>). *)
@@ -44,8 +44,8 @@ Section ExtendedCoordinates.
Local Hint Unfold twistedToExtended extendedToTwisted rep.
Local Notation "P '~=' rP" := (rep P rP) (at level 70).
- Ltac unfoldExtended :=
- repeat progress (autounfold; unfold onCurve, unifiedAdd, unifiedAdd', rep in *; intros);
+ Ltac unfoldExtended :=
+ repeat progress (autounfold; unfold E.onCurve, E.add, E.add', rep in *; intros);
repeat match goal with
| [ p : (F q*F q)%type |- _ ] =>
let x := fresh "x" p in
@@ -61,7 +61,7 @@ Section ExtendedCoordinates.
| [ H: @eq (F q * F q)%type _ _ |- _ ] => invcs H
| [ H: @eq F q ?x _ |- _ ] => isVar x; rewrite H; clear H
end.
-
+
Ltac solveExtended := unfoldExtended;
repeat match goal with
| [ |- _ /\ _ ] => split
@@ -83,14 +83,14 @@ Section ExtendedCoordinates.
solveExtended.
Qed.
- Definition extendedPoint := { P:extended | rep P (extendedToTwisted P) /\ onCurve (extendedToTwisted P) }.
+ Definition extendedPoint := { P:extended | rep P (extendedToTwisted P) /\ E.onCurve (extendedToTwisted P) }.
- Program Definition mkExtendedPoint : point -> extendedPoint := twistedToExtended.
+ Program Definition mkExtendedPoint : E.point -> extendedPoint := twistedToExtended.
Next Obligation.
destruct x; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep.
Qed.
- Program Definition unExtendedPoint : extendedPoint -> point := extendedToTwisted.
+ Program Definition unExtendedPoint : extendedPoint -> E.point := extendedToTwisted.
Next Obligation.
destruct x; simpl; intuition.
Qed.
@@ -103,7 +103,7 @@ Section ExtendedCoordinates.
Lemma unExtendedPoint_mkExtendedPoint : forall P, unExtendedPoint (mkExtendedPoint P) = P.
Proof.
- destruct P; eapply point_eq; simpl; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep.
+ destruct P; eapply E.point_eq; simpl; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep.
Qed.
Global Instance Proper_mkExtendedPoint : Proper (eq==>equiv) mkExtendedPoint.
@@ -116,6 +116,8 @@ Section ExtendedCoordinates.
repeat (econstructor || intro); unfold extendedPoint_eq in *; congruence.
Qed.
+ Definition twice_d := d + d.
+
Section TwistMinus1.
Context (a_eq_minus1 : a = opp 1).
(** Second equation from <http://eprint.iacr.org/2008/522.pdf> section 3.1, also <https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#addition-add-2008-hwcd-3> and <https://tools.ietf.org/html/draft-josefsson-eddsa-ed25519-03> *)
@@ -124,8 +126,8 @@ Section ExtendedCoordinates.
let '(X2, Y2, Z2, T2) := P2 in
let A := (Y1-X1)*(Y2-X2) in
let B := (Y1+X1)*(Y2+X2) in
- let C := T1*ZToField 2*d*T2 in
- let D := Z1*ZToField 2 *Z2 in
+ let C := T1*twice_d*T2 in
+ let D := Z1*(Z2+Z2) in
let E := B-A in
let F := D-C in
let G := D+C in
@@ -135,47 +137,30 @@ Section ExtendedCoordinates.
let T3 := E*H in
let Z3 := F*G in
(X3, Y3, Z3, T3).
- Local Hint Unfold unifiedAdd.
+ Local Hint Unfold E.add.
+
+ Local Ltac tnz := repeat apply Fq_mul_nonzero_nonzero; auto using (@char_gt_2 q two_lt_q).
- Lemma unifiedAddM1'_rep: forall P Q rP rQ, onCurve rP -> onCurve rQ ->
- P ~= rP -> Q ~= rQ -> (unifiedAddM1' P Q) ~= (unifiedAdd' rP rQ).
+ Lemma F_mul_2_l : forall x : F q, ZToField 2 * x = x + x.
+ intros. ring.
+ Qed.
+
+ Lemma unifiedAddM1'_rep: forall P Q rP rQ, E.onCurve rP -> E.onCurve rQ ->
+ P ~= rP -> Q ~= rQ -> (unifiedAddM1' P Q) ~= (E.add' rP rQ).
Proof.
intros P Q rP rQ HoP HoQ HrP HrQ.
pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d).
pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d).
- unfoldExtended; rewrite a_eq_minus1 in *; simpl in *.
+ unfoldExtended; unfold twice_d; rewrite a_eq_minus1 in *; simpl in *. repeat rewrite <-F_mul_2_l.
repeat split; repeat apply (f_equal2 pair); try F_field; repeat split; auto;
repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r,
- ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r.
-
- Ltac tnz := repeat apply Fq_mul_nonzero_nonzero; auto using (@char_gt_2 q two_lt_q).
- (* If we we had reasoning modulo associativity and commutativity,
- * the following tactic would probably solve all remaining goals here:
- repeat match goal with [H1: @eq (F p) _ _, H2: @eq (F p) _ _ |- _ ] =>
- let H := fresh "H" in (
- pose proof (edwardsAddCompletePlus _ _ _ _ H1 H2) as H;
- match type of H with ?xs <> 0 => ac_rewrite (eq_refl xs) end
- ) || (
- pose proof (edwardsAddCompleteMinus _ _ _ _ H1 H2) as H;
- match type of H with ?xs <> 0 => ac_rewrite (eq_refl xs) end
- ); tnz
- end. *)
-
- - replace (ZP * ZQ * ZP * ZQ + d * XP * XQ * YP * YQ) with (ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZQ * ZP * ZQ - d * XP * XQ * YP * YQ) with (ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ - XP * YP / ZP * ZToField 2 * d * (XQ * YQ / ZQ) ) with (ZToField 2*ZQ*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- replace (ZP * ZToField 2 * ZQ + XP * YP / ZP * ZToField 2 * d * (XQ * YQ / ZQ) ) with (ZToField 2*ZQ*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
- - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz.
+ ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r;
+ field_nonzero tnz.
Qed.
- Lemma unifiedAdd'_onCurve : forall P Q, onCurve P -> onCurve Q -> onCurve (unifiedAdd' P Q).
+ Lemma unifiedAdd'_onCurve : forall P Q, E.onCurve P -> E.onCurve Q -> E.onCurve (E.add' P Q).
Proof.
- intros; pose proof (proj2_sig (unifiedAdd (mkPoint _ H) (mkPoint _ H0))); eauto.
+ intros; pose proof (proj2_sig (E.add (exist _ _ H) (exist _ _ H0))); eauto.
Qed.
Program Definition unifiedAddM1 : extendedPoint -> extendedPoint -> extendedPoint := unifiedAddM1'.
@@ -188,9 +173,9 @@ Section ExtendedCoordinates.
eauto using unifiedAdd'_onCurve.
Qed.
- Lemma unifiedAddM1_rep : forall P Q, unifiedAdd (unExtendedPoint P) (unExtendedPoint Q) = unExtendedPoint (unifiedAddM1 P Q).
+ Lemma unifiedAddM1_rep : forall P Q, E.add (unExtendedPoint P) (unExtendedPoint Q) = unExtendedPoint (unifiedAddM1 P Q).
Proof.
- destruct P, Q; unfold unExtendedPoint, unifiedAdd, unifiedAddM1; eapply point_eq; simpl in *; intuition.
+ destruct P, Q; unfold unExtendedPoint, E.add, unifiedAddM1; eapply E.point_eq; simpl in *; intuition.
pose proof (unifiedAddM1'_rep x x0 (extendedToTwisted x) (extendedToTwisted x0));
destruct (unifiedAddM1' x x0);
unfold rep in *; intuition.
@@ -201,34 +186,54 @@ Section ExtendedCoordinates.
repeat (econstructor || intro).
repeat match goal with [H: _ === _ |- _ ] => inversion H; clear H end; unfold equiv, extendedPoint_eq.
rewrite <-!unifiedAddM1_rep.
- destruct x, y, x0, y0; simpl in *; eapply point_eq; congruence.
+ destruct x, y, x0, y0; simpl in *; eapply E.point_eq; congruence.
Qed.
- Lemma unifiedAddM1_0_r : forall P, unifiedAddM1 P (mkExtendedPoint zero) === P.
+ Lemma unifiedAddM1_0_r : forall P, unifiedAddM1 P (mkExtendedPoint E.zero) === P.
unfold equiv, extendedPoint_eq; intros.
- rewrite <-!unifiedAddM1_rep, unExtendedPoint_mkExtendedPoint, zeroIsIdentity; auto.
+ rewrite <-!unifiedAddM1_rep, unExtendedPoint_mkExtendedPoint, E.add_0_r; auto.
Qed.
- Lemma unifiedAddM1_0_l : forall P, unifiedAddM1 (mkExtendedPoint zero) P === P.
+ Lemma unifiedAddM1_0_l : forall P, unifiedAddM1 (mkExtendedPoint E.zero) P === P.
unfold equiv, extendedPoint_eq; intros.
- rewrite <-!unifiedAddM1_rep, twistedAddComm, unExtendedPoint_mkExtendedPoint, zeroIsIdentity; auto.
+ rewrite <-!unifiedAddM1_rep, E.add_comm, unExtendedPoint_mkExtendedPoint, E.add_0_r; auto.
Qed.
Lemma unifiedAddM1_assoc : forall a b c, unifiedAddM1 a (unifiedAddM1 b c) === unifiedAddM1 (unifiedAddM1 a b) c.
Proof.
unfold equiv, extendedPoint_eq; intros.
- rewrite <-!unifiedAddM1_rep, twistedAddAssoc; auto.
+ rewrite <-!unifiedAddM1_rep, E.add_assoc; auto.
+ Qed.
+
+ Lemma testbit_conversion_identity : forall x i, N.testbit_nat x i = N.testbit_nat ((fun a => a) x) i.
+ Proof.
+ trivial.
Qed.
-
- Definition scalarMultM1 := iter_op unifiedAddM1 (mkExtendedPoint zero).
- Definition scalarMultM1_spec := iter_op_spec unifiedAddM1 unifiedAddM1_assoc (mkExtendedPoint zero) unifiedAddM1_0_l.
- Lemma scalarMultM1_rep : forall n P, unExtendedPoint (scalarMultM1 (N.of_nat n) P) = scalarMult n (unExtendedPoint P).
- intros; rewrite scalarMultM1_spec, Nat2N.id.
+
+ Definition scalarMultM1 := iter_op unifiedAddM1 (mkExtendedPoint E.zero) N.testbit_nat.
+ Definition scalarMultM1_spec :=
+ iter_op_spec unifiedAddM1 unifiedAddM1_assoc (mkExtendedPoint E.zero) unifiedAddM1_0_l
+ N.testbit_nat (fun x => x) testbit_conversion_identity.
+ Lemma scalarMultM1_rep : forall n P, unExtendedPoint (scalarMultM1 (N.of_nat n) P (N.size_nat (N.of_nat n))) = E.mul n (unExtendedPoint P).
+ intros; rewrite scalarMultM1_spec, Nat2N.id; auto.
induction n; [simpl; rewrite !unExtendedPoint_mkExtendedPoint; reflexivity|].
- unfold scalarMult; fold scalarMult.
+ unfold E.mul; fold E.mul.
rewrite <-IHn, unifiedAddM1_rep; auto.
Qed.
End TwistMinus1.
-End ExtendedCoordinates. \ No newline at end of file
+ Definition negateExtended' P := let '(X, Y, Z, T) := P in (opp X, Y, Z, opp T).
+ Program Definition negateExtended (P:extendedPoint) : extendedPoint := negateExtended' (proj1_sig P).
+ Next Obligation.
+ Proof.
+ unfold negateExtended', rep; destruct P as [[X Y Z T] H]; simpl. destruct H as [[[] []] ?]; subst.
+ repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; repeat split; trivial.
+ Qed.
+
+ Lemma negateExtended_correct : forall P, E.opp (unExtendedPoint P) = unExtendedPoint (negateExtended P).
+ Proof.
+ unfold E.opp, unExtendedPoint, negateExtended; destruct P as [[]]; simpl; intros.
+ eapply E.point_eq; repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; trivial.
+ Qed.
+End ExtendedCoordinates.
diff --git a/src/EdDSAProofs.v b/src/EdDSAProofs.v
index 83467bf6d..dba71b49c 100644
--- a/src/EdDSAProofs.v
+++ b/src/EdDSAProofs.v
@@ -37,7 +37,7 @@ Section EdDSAProofs.
Lemma decode_sign_split2 : forall sk {n} (M : word n),
split2 b b (sign (public sk) sk M) =
let r : nat := H (prngKey sk ++ M) in (* secret nonce *)
- let R : point := (r * B)%E in (* commitment to nonce *)
+ let R : E.point := (r * B)%E in (* commitment to nonce *)
let s : nat := curveKey sk in (* secret scalar *)
let S : F (Z.of_nat l) := ZToField (Z.of_nat (r + H (enc R ++ public sk ++ M) * s)) in
enc S.
@@ -46,62 +46,21 @@ Section EdDSAProofs.
Qed.
Hint Rewrite decode_sign_split2.
- Lemma zero_times : forall P, (0 * P = zero)%E.
- Proof.
- auto.
- Qed.
-
- Lemma zero_plus : forall P, (zero + P = P)%E.
- Proof.
- intros; rewrite twistedAddComm; apply zeroIsIdentity.
- Qed.
-
- Lemma times_S : forall n m, S n * m = m + n * m.
- Proof.
- auto.
- Qed.
-
- Lemma times_S_nat : forall n m, (S n * m = m + n * m)%nat.
- Proof.
- auto.
- Qed.
-
- Hint Rewrite plus_O_n plus_Sn_m times_S times_S_nat.
- Hint Rewrite zeroIsIdentity zero_times zero_plus twistedAddAssoc.
-
- Lemma scalarMult_distr : forall n0 m, ((n0 + m)%nat * B)%E = unifiedAdd (n0 * B)%E (m * B)%E.
- Proof.
- unfold scalarMult; induction n0; arith.
- Qed.
- Hint Rewrite scalarMult_distr.
-
- Lemma scalarMult_assoc : forall (n0 m : nat), (n0 * (m * B) = (n0 * m)%nat * B)%E.
- Proof.
- induction n0; arith; simpl; arith.
- Qed.
- Hint Rewrite scalarMult_assoc.
-
- Lemma scalarMult_zero : forall m, (m * zero = zero)%E.
- Proof.
- unfold scalarMult; induction m; arith.
- Qed.
- Hint Rewrite scalarMult_zero.
+ Hint Rewrite E.add_0_r E.add_0_l E.add_assoc.
+ Hint Rewrite E.mul_assoc E.mul_add_l E.mul_0_l E.mul_zero_r.
+ Hint Rewrite plus_O_n plus_Sn_m mult_0_l mult_succ_l.
Hint Rewrite l_order_B.
-
- Lemma l_order_B' : forall x, (l * x * B = zero)%E.
+ Lemma l_order_B' : forall x, (l * x * B = E.zero)%E.
Proof.
- intros; rewrite Mult.mult_comm. rewrite <- scalarMult_assoc. arith.
- Qed.
-
- Hint Rewrite l_order_B'.
+ intros; rewrite Mult.mult_comm. rewrite <- E.mul_assoc. arith.
+ Qed. Hint Rewrite l_order_B'.
Lemma scalarMult_mod_l : forall n0, (n0 mod l * B = n0 * B)%E.
Proof.
intros.
rewrite (div_mod n0 l) at 2 by (generalize l_odd; omega).
arith.
- Qed.
- Hint Rewrite scalarMult_mod_l.
+ Qed. Hint Rewrite scalarMult_mod_l.
Hint Rewrite @encoding_valid.
Hint Rewrite @FieldToZ_ZToField.
diff --git a/src/Encoding/EncodingTheorems.v b/src/Encoding/EncodingTheorems.v
index f53ad0319..52ac91ada 100644
--- a/src/Encoding/EncodingTheorems.v
+++ b/src/Encoding/EncodingTheorems.v
@@ -1,7 +1,7 @@
Require Import Crypto.Spec.Encoding.
Section EncodingTheorems.
- Context {A B : Type} {E : encoding of A as B}.
+ Context {A B : Type} {E : canonical encoding of A as B}.
Lemma encoding_inj : forall x y, enc x = enc y -> x = y.
Proof.
diff --git a/src/Encoding/ModularWordEncodingPre.v b/src/Encoding/ModularWordEncodingPre.v
new file mode 100644
index 000000000..417344b43
--- /dev/null
+++ b/src/Encoding/ModularWordEncodingPre.v
@@ -0,0 +1,45 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Bedrock.Word.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Util.ZUtil Crypto.Util.WordUtil.
+Require Import Crypto.Spec.Encoding.
+
+Local Open Scope nat_scope.
+
+Section ModularWordEncodingPre.
+ Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}.
+
+ Let Fm_enc (x : F m) : word sz := NToWord sz (Z.to_N (FieldToZ x)).
+
+ Let Fm_dec (x_ : word sz) : option (F m) :=
+ let z := Z.of_N (wordToN (x_)) in
+ if Z_lt_dec z m
+ then Some (ZToField z)
+ else None
+ .
+
+ Lemma Fm_encoding_valid : forall x, Fm_dec (Fm_enc x) = Some x.
+ Proof.
+ unfold Fm_dec, Fm_enc; intros.
+ pose proof (FieldToZ_range x m_pos).
+ rewrite wordToN_NToWord_idempotent by (apply bound_check_nat_N;
+ assert (Z.to_nat x < Z.to_nat m) by (apply Z2Nat.inj_lt; omega); omega).
+ rewrite Z2N.id by omega.
+ rewrite ZToField_idempotent.
+ break_if; auto; omega.
+ Qed.
+
+ Lemma Fm_encoding_canonical : forall w x, Fm_dec w = Some x -> Fm_enc x = w.
+ Proof.
+ unfold Fm_dec, Fm_enc; intros ? ? dec_Some.
+ break_if; [ | congruence ].
+ inversion dec_Some.
+ rewrite FieldToZ_ZToField.
+ rewrite Z.mod_small by (pose proof (N2Z.is_nonneg (wordToN w)); try omega).
+ rewrite N2Z.id.
+ apply NToWord_wordToN.
+ Qed.
+
+End ModularWordEncodingPre.
diff --git a/src/Encoding/ModularWordEncodingTheorems.v b/src/Encoding/ModularWordEncodingTheorems.v
new file mode 100644
index 000000000..7251ac1e6
--- /dev/null
+++ b/src/Encoding/ModularWordEncodingTheorems.v
@@ -0,0 +1,54 @@
+Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems.
+Require Import Bedrock.Word.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Spec.Encoding.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Spec.ModularWordEncoding.
+
+
+Local Open Scope F_scope.
+
+Section SignBit.
+ Context {m : Z} {prime_m : prime m} {two_lt_m : (2 < m)%Z} {sz : nat} {bound_check : (Z.to_nat m < 2 ^ sz)%nat}.
+
+ Lemma sign_bit_parity : forall x, @sign_bit m sz x = Z.odd x.
+ Proof.
+ unfold sign_bit, Fm_enc; intros.
+ pose proof (shatter_word (NToWord sz (Z.to_N x))) as shatter.
+ case_eq sz; intros; subst; rewrite shatter.
+ + pose proof (prime_ge_2 m prime_m).
+ simpl in bound_check.
+ assert (m < 1)%Z by (apply Z2Nat.inj_lt; try omega; assumption).
+ omega.
+ + assert (0 < m)%Z as m_pos by (pose proof prime_ge_2 m prime_m; omega).
+ pose proof (FieldToZ_range x m_pos).
+ destruct (FieldToZ x); auto.
+ - destruct p; auto.
+ - pose proof (Pos2Z.neg_is_neg p); omega.
+ Qed.
+
+ Lemma sign_bit_zero : @sign_bit m sz 0 = false.
+ Proof.
+ rewrite sign_bit_parity; auto.
+ Qed.
+
+ Lemma sign_bit_opp : forall (x : F m), x <> 0 -> negb (@sign_bit m sz x) = @sign_bit m sz (opp x).
+ Proof.
+ intros.
+ pose proof sign_bit_zero as sign_zero.
+ rewrite !sign_bit_parity in *.
+ pose proof (F_opp_spec x) as opp_spec_x.
+ apply F_eq in opp_spec_x.
+ rewrite FieldToZ_add in opp_spec_x.
+ rewrite <-opp_spec_x, Z_odd_mod in sign_zero by (pose proof prime_ge_2 m prime_m; omega).
+ replace (Z.odd m) with true in sign_zero by (destruct (ZUtil.prime_odd_or_2 m prime_m); auto || omega).
+ rewrite Z.odd_add, F_FieldToZ_add_opp, Z.div_same, Bool.xorb_true_r in sign_zero by assumption || omega.
+ apply Bool.xorb_eq.
+ rewrite <-Bool.negb_xorb_l.
+ assumption.
+ Qed.
+
+End SignBit. \ No newline at end of file
diff --git a/src/Encoding/PointEncodingPre.v b/src/Encoding/PointEncodingPre.v
new file mode 100644
index 000000000..73ced869b
--- /dev/null
+++ b/src/Encoding/PointEncodingPre.v
@@ -0,0 +1,275 @@
+Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Coq.Program.Equality.
+Require Import Crypto.Encoding.EncodingTheorems.
+Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Bedrock.Word.
+Require Import Crypto.Encoding.ModularWordEncodingTheorems.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Util.ZUtil.
+
+Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic.
+
+Local Open Scope F_scope.
+
+Section PointEncoding.
+ Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat}
+ {bound_check : (Z.to_nat q < 2 ^ sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z}
+ {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1}
+ {FqEncoding : canonical encoding of (F q) as (word sz)}
+ {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false}
+ {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}.
+ Existing Instance prime_q.
+
+ Add Field Ffield : (@Ffield_theory q _)
+ (morphism (@Fring_morph q),
+ preprocess [Fpreprocess],
+ postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
+ constants [Fconstant],
+ div (@Fmorph_div_theory q),
+ power_tac (@Fpower_theory q) [Fexp_tac]).
+
+ Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F.
+
+ Lemma solve_sqrt_valid : forall p, E.onCurve p ->
+ sqrt_valid (E.solve_for_x2 (snd p)).
+ Proof.
+ intros ? onCurve_xy.
+ destruct p as [x y]; simpl.
+ rewrite (E.solve_correct x y) in onCurve_xy.
+ rewrite <- onCurve_xy.
+ unfold sqrt_valid.
+ eapply sqrt_mod_q_valid; eauto.
+ unfold isSquare; eauto.
+ Grab Existential Variables. eauto.
+ Qed.
+
+ Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
+ E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y).
+ Proof.
+ intros.
+ unfold sqrt_valid in *.
+ apply E.solve_correct; auto.
+ Qed.
+
+ Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
+ E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y).
+ Proof.
+ intros y sqrt_valid_x2.
+ unfold sqrt_valid in *.
+ apply E.solve_correct.
+ rewrite <- sqrt_valid_x2 at 2.
+ ring.
+ Qed.
+
+ Definition point_enc_coordinates (p : (F q * F q)) : Word.word (S sz) := let '(x,y) := p in
+ Word.WS (sign_bit x) (enc y).
+
+ Let point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in
+ Word.WS (sign_bit x) (enc y).
+
+ Definition point_dec_coordinates (sign_bit : F q -> bool) (w : Word.word (S sz)) : option (F q * F q) :=
+ match dec (Word.wtl w) with
+ | None => None
+ | Some y => let x2 := E.solve_for_x2 y in
+ let x := sqrt_mod_q x2 in
+ if F_eq_dec (x ^ 2) x2
+ then
+ let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in
+ if (andb (F_eqb x 0) (whd w))
+ then None (* special case for 0, since its opposite has the same sign; if the sign bit of 0 is 1, produce None.*)
+ else Some p
+ else None
+ end.
+
+ Ltac inversion_Some_eq := match goal with [H: Some ?x = Some ?y |- _] => inversion H; subst end.
+
+ Lemma point_dec_coordinates_onCurve : forall w p, point_dec_coordinates sign_bit w = Some p -> E.onCurve p.
+ Proof.
+ unfold point_dec_coordinates; intros.
+ edestruct dec; [ | congruence].
+ break_if; [ | congruence].
+ break_if; [ congruence | ].
+ break_if; inversion_Some_eq; auto using solve_onCurve, solve_opp_onCurve.
+ Qed.
+
+ Lemma prod_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'})
+ (x y : (A * A)), {x = y} + {x <> y}.
+ Proof.
+ decide equality.
+ Qed.
+
+ Lemma option_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'})
+ (x y : option A), {x = y} + {x <> y}.
+ Proof.
+ decide equality.
+ Qed.
+
+ Definition point_dec' w p : option E.point :=
+ match (option_eq_dec (prod_eq_dec F_eq_dec) (point_dec_coordinates sign_bit w) (Some p)) with
+ | left EQ => Some (exist _ p (point_dec_coordinates_onCurve w p EQ))
+ | right _ => None (* this case is never reached *)
+ end.
+
+ Definition point_dec (w : word (S sz)) : option E.point :=
+ match (point_dec_coordinates sign_bit w) with
+ | Some p => point_dec' w p
+ | None => None
+ end.
+
+ Lemma point_coordinates_encoding_canonical : forall w p,
+ point_dec_coordinates sign_bit w = Some p -> point_enc_coordinates p = w.
+ Proof.
+ unfold point_dec_coordinates, point_enc_coordinates; intros ? ? coord_dec_Some.
+ case_eq (dec (wtl w)); [ intros ? dec_Some | intros dec_None; rewrite dec_None in *; congruence ].
+ destruct p.
+ rewrite (shatter_word w).
+ f_equal; rewrite dec_Some in *;
+ do 2 (break_if; try congruence); inversion coord_dec_Some; subst.
+ + destruct (F_eq_dec (sqrt_mod_q (E.solve_for_x2 f1)) 0%F) as [sqrt_0 | ?].
+ - rewrite sqrt_0 in *.
+ apply sqrt_mod_q_root_0 in sqrt_0; try assumption.
+ rewrite sqrt_0 in *.
+ break_if; [symmetry; auto using Bool.eqb_prop | ].
+ rewrite sign_bit_zero in *.
+ simpl in Heqb; rewrite Heqb in *.
+ discriminate.
+ - break_if.
+ symmetry; auto using Bool.eqb_prop.
+ rewrite <- sign_bit_opp by assumption.
+ destruct (whd w); inversion Heqb0; break_if; auto.
+ + inversion coord_dec_Some; subst.
+ auto using encoding_canonical.
+Qed.
+
+ Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w.
+ Proof.
+ (*
+ unfold point_enc; intros.
+ unfold point_dec in *.
+ assert (point_dec_coordinates w = Some (proj1_sig x)). {
+ set (y := point_dec_coordinates w) in *.
+ revert H.
+ dependent destruction y. intros.
+ rewrite H0 in H.
+ *)
+ Admitted.
+
+Lemma point_dec_coordinates_correct w
+ : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates sign_bit w.
+Proof.
+ unfold point_dec, option_map.
+ do 2 break_match; try congruence; unfold point_dec' in *;
+ break_match; try congruence.
+ inversion_Some_eq.
+ reflexivity.
+Qed.
+
+Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p).
+Proof.
+ intros.
+ destruct p as [x y]; simpl.
+ exact (encoding_valid y).
+Qed.
+
+Lemma sign_bit_opp_eq_iff : forall x y, y <> 0 ->
+ (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)).
+Proof.
+ split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y);
+ try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp in * by auto;
+ rewrite y_sign, x_sign in *; reflexivity || discriminate.
+Qed.
+
+Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 ->
+ sign_bit x = sign_bit y -> x = y.
+Proof.
+ intros ? ? y_nonzero squares_eq sign_match.
+ destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto.
+ assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto).
+ apply sign_bit_opp_eq_iff in sign_mismatch; auto.
+ congruence.
+Qed.
+
+Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) ->
+ sign_bit x = sign_bit x' -> x = x'.
+Proof.
+ intros ? ? ? onCurve_x onCurve_x' sign_match.
+ apply E.solve_correct in onCurve_x.
+ apply E.solve_correct in onCurve_x'.
+ destruct (F_eq_dec x' 0).
+ + subst.
+ rewrite Fq_pow_zero in onCurve_x' by congruence.
+ rewrite <- onCurve_x' in *.
+ eapply Fq_root_zero; eauto.
+ + apply sign_bit_squares; auto.
+ rewrite onCurve_x, onCurve_x'.
+ reflexivity.
+Qed.
+
+Lemma point_encoding_coordinates_valid : forall p, E.onCurve p ->
+ point_dec_coordinates sign_bit (point_enc_coordinates p) = Some p.
+Proof.
+ intros p onCurve_p.
+ unfold point_dec_coordinates.
+ rewrite y_decode.
+ pose proof (solve_sqrt_valid p onCurve_p) as solve_sqrt_valid_p.
+ destruct p as [x y].
+ unfold sqrt_valid in *.
+ simpl.
+ replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption).
+ case_eq (F_eqb x 0); intro eqb_x_0.
+ + apply F_eqb_eq in eqb_x_0; rewrite eqb_x_0 in *.
+ rewrite !Fq_pow_zero, sqrt_mod_q_of_0, Fq_pow_zero by congruence.
+ rewrite if_F_eq_dec_if_F_eqb, sign_bit_zero.
+ reflexivity.
+ + assert (sqrt_mod_q (x ^ 2) <> 0) by (intro false_eq; apply sqrt_mod_q_root_0 in false_eq; try assumption;
+ apply Fq_root_zero in false_eq; rewrite false_eq, F_eqb_refl in eqb_x_0; congruence).
+ replace (F_eqb (sqrt_mod_q (x ^ 2)) 0) with false by (symmetry;
+ apply F_eqb_neq_complete; assumption).
+ break_if.
+ - simpl.
+ f_equal.
+ break_if.
+ * rewrite Bool.eqb_true_iff in Heqb.
+ pose proof (solve_onCurve y solve_sqrt_valid_p).
+ f_equal.
+ apply (sign_bit_match _ _ y); auto.
+ apply E.solve_correct in onCurve_p; rewrite onCurve_p in *.
+ assumption.
+ * rewrite Bool.eqb_false_iff in Heqb.
+ pose proof (solve_opp_onCurve y solve_sqrt_valid_p).
+ f_equal.
+ apply sign_bit_opp_eq_iff in Heqb; try assumption.
+ apply (sign_bit_match _ _ y); auto.
+ apply E.solve_correct in onCurve_p.
+ rewrite onCurve_p; auto.
+ - simpl in solve_sqrt_valid_p.
+ replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption).
+ congruence.
+Qed.
+
+Lemma point_dec'_valid : forall p,
+ point_dec' (point_enc_coordinates (proj1_sig p)) (proj1_sig p) = Some p.
+Proof.
+ unfold point_dec'; intros.
+ break_match.
+ + f_equal.
+ destruct p.
+ apply E.point_eq.
+ reflexivity.
+ + rewrite point_encoding_coordinates_valid in n by apply (proj2_sig p).
+ congruence.
+Qed.
+
+Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p.
+Proof.
+ intros.
+ unfold point_dec.
+ replace (point_enc p) with (point_enc_coordinates (proj1_sig p)) by reflexivity.
+ break_match; rewrite point_encoding_coordinates_valid in * by apply (proj2_sig p); try congruence.
+ inversion_Some_eq.
+ eapply point_dec'_valid.
+Qed.
+
+End PointEncoding.
diff --git a/src/Encoding/PointEncodingTheorems.v b/src/Encoding/PointEncodingTheorems.v
new file mode 100644
index 000000000..ccea1d81b
--- /dev/null
+++ b/src/Encoding/PointEncodingTheorems.v
@@ -0,0 +1,207 @@
+Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Coq.Program.Equality.
+Require Import Crypto.Encoding.EncodingTheorems.
+Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Bedrock.Word.
+Require Import Crypto.Tactics.VerdiTactics.
+
+Require Import Crypto.Spec.Encoding Crypto.Spec.ModularArithmetic Crypto.Spec.CompleteEdwardsCurve.
+
+Local Open Scope F_scope.
+
+Section PointEncoding.
+ Context {prm: CompleteEdwardsCurve.TwistedEdwardsParams} {sz : nat}
+ {FqEncoding : canonical encoding of ModularArithmetic.F (CompleteEdwardsCurve.q) as Word.word sz}
+ {q_5mod8 : (CompleteEdwardsCurve.q mod 8 = 5)%Z}
+ {sqrt_minus1_valid : (@ZToField CompleteEdwardsCurve.q 2 ^ BinInt.Z.to_N (CompleteEdwardsCurve.q / 4)) ^ 2 = opp 1}.
+ Existing Instance CompleteEdwardsCurve.prime_q.
+
+ Add Field Ffield : (@PrimeFieldTheorems.Ffield_theory CompleteEdwardsCurve.q _)
+ (morphism (@ModularArithmeticTheorems.Fring_morph CompleteEdwardsCurve.q),
+ preprocess [ModularArithmeticTheorems.Fpreprocess],
+ postprocess [ModularArithmeticTheorems.Fpostprocess; try exact PrimeFieldTheorems.Fq_1_neq_0; try assumption],
+ constants [ModularArithmeticTheorems.Fconstant],
+ div (@ModularArithmeticTheorems.Fmorph_div_theory CompleteEdwardsCurve.q),
+ power_tac (@ModularArithmeticTheorems.Fpower_theory CompleteEdwardsCurve.q) [ModularArithmeticTheorems.Fexp_tac]).
+
+ Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F.
+
+ Lemma solve_sqrt_valid : forall (p : E.point),
+ sqrt_valid (E.solve_for_x2 (snd (proj1_sig p))).
+ Proof.
+ intros.
+ destruct p as [[x y] onCurve_xy]; simpl.
+ rewrite (E.solve_correct x y) in onCurve_xy.
+ rewrite <- onCurve_xy.
+ unfold sqrt_valid.
+ eapply sqrt_mod_q_valid; eauto.
+ unfold isSquare; eauto.
+ Grab Existential Variables. eauto.
+ Qed.
+
+ Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
+ E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y).
+ Proof.
+ intros.
+ unfold sqrt_valid in *.
+ apply E.solve_correct; auto.
+ Qed.
+
+ Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) ->
+ E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y).
+ Proof.
+ intros y sqrt_valid_x2.
+ unfold sqrt_valid in *.
+ apply E.solve_correct.
+ rewrite <- sqrt_valid_x2 at 2.
+ ring.
+ Qed.
+
+Definition sign_bit (x : F q) := (wordToN (enc (opp x)) <? wordToN (enc x))%N.
+Definition point_enc (p : E.point) : word (S sz) := let '(x,y) := proj1_sig p in
+ WS (sign_bit x) (enc y).
+Definition point_dec_coordinates (w : word (S sz)) : option (F q * F q) :=
+ match dec (wtl w) with
+ | None => None
+ | Some y => let x2 := E.solve_for_x2 y in
+ let x := sqrt_mod_q x2 in
+ if F_eq_dec (x ^ 2) x2
+ then
+ let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in
+ Some p
+ else None
+ end.
+
+Definition point_dec (w : word (S sz)) : option E.point :=
+ match dec (wtl w) with
+ | None => None
+ | Some y => let x2 := E.solve_for_x2 y in
+ let x := sqrt_mod_q x2 in
+ match (F_eq_dec (x ^ 2) x2) with
+ | right _ => None
+ | left EQ => if Bool.eqb (whd w) (sign_bit x)
+ then Some (exist _ (x, y) (solve_onCurve y EQ))
+ else Some (exist _ (opp x, y) (solve_opp_onCurve y EQ))
+ end
+ end.
+
+Lemma point_dec_coordinates_correct w
+ : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates w.
+Proof.
+ unfold point_dec, point_dec_coordinates.
+ edestruct dec; [ | reflexivity ].
+ edestruct @F_eq_dec; [ | reflexivity ].
+ edestruct @Bool.eqb; reflexivity.
+Qed.
+
+Lemma y_decode : forall p, dec (wtl (point_enc p)) = Some (snd (proj1_sig p)).
+Proof.
+ intros.
+ destruct p as [[x y] onCurve_p]; simpl.
+ exact (encoding_valid y).
+Qed.
+
+
+Lemma wordToN_enc_neq_opp : forall x, x <> 0 -> (wordToN (enc (opp x)) <> wordToN (enc x))%N.
+Proof.
+ intros x x_nonzero.
+ intro false_eq.
+ apply x_nonzero.
+ apply F_eq_opp_zero; try apply two_lt_q.
+ apply wordToN_inj in false_eq.
+ apply encoding_inj in false_eq.
+ auto.
+Qed.
+
+Lemma sign_bit_opp_negb : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x).
+Proof.
+ intros x x_nonzero.
+ unfold sign_bit.
+ rewrite <- N.leb_antisym.
+ rewrite N.ltb_compare, N.leb_compare.
+ rewrite F_opp_involutive.
+ case_eq (wordToN (enc x) ?= wordToN (enc (opp x)))%N; auto.
+ intro wordToN_enc_eq.
+ pose proof (wordToN_enc_neq_opp x x_nonzero).
+ apply N.compare_eq_iff in wordToN_enc_eq.
+ congruence.
+Qed.
+
+Lemma sign_bit_opp : forall x y, y <> 0 ->
+ (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)).
+Proof.
+ split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y);
+ try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp_negb in * by auto;
+ rewrite y_sign, x_sign in *; reflexivity || discriminate.
+Qed.
+
+Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 ->
+ sign_bit x = sign_bit y -> x = y.
+Proof.
+ intros ? ? y_nonzero squares_eq sign_match.
+ destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto.
+ assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto).
+ apply sign_bit_opp in sign_mismatch; auto.
+ congruence.
+Qed.
+
+Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) ->
+ sign_bit x = sign_bit x' -> x = x'.
+Proof.
+ intros ? ? ? onCurve_x onCurve_x' sign_match.
+ apply E.solve_correct in onCurve_x.
+ apply E.solve_correct in onCurve_x'.
+ destruct (F_eq_dec x' 0).
+ + subst.
+ rewrite Fq_pow_zero in onCurve_x' by congruence.
+ rewrite <- onCurve_x' in *.
+ eapply Fq_root_zero; eauto.
+ + apply sign_bit_squares; auto.
+ rewrite onCurve_x, onCurve_x'.
+ reflexivity.
+Qed.
+
+Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p.
+Proof.
+ intros.
+ unfold point_dec.
+ rewrite y_decode.
+ pose proof solve_sqrt_valid p as solve_sqrt_valid_p.
+ unfold sqrt_valid in *.
+ destruct p as [[x y] onCurve_p].
+ simpl in *.
+ destruct (F_eq_dec ((sqrt_mod_q (E.solve_for_x2 y)) ^ 2) (E.solve_for_x2 y)); intuition.
+ break_if; f_equal; apply E.point_eq.
+ + rewrite Bool.eqb_true_iff in Heqb.
+ pose proof (solve_onCurve y solve_sqrt_valid_p).
+ f_equal.
+ apply (sign_bit_match _ _ y); auto.
+ + rewrite Bool.eqb_false_iff in Heqb.
+ pose proof (solve_opp_onCurve y solve_sqrt_valid_p).
+ f_equal.
+ apply sign_bit_opp in Heqb.
+ apply (sign_bit_match _ _ y); auto.
+ intro eq_zero.
+ apply E.solve_correct in onCurve_p.
+ rewrite eq_zero in *.
+ rewrite Fq_pow_zero in solve_sqrt_valid_p by congruence.
+ rewrite <- solve_sqrt_valid_p in onCurve_p.
+ apply Fq_root_zero in onCurve_p.
+ rewrite onCurve_p in Heqb; auto.
+Qed.
+
+(* Waiting on canonicalization *)
+Lemma point_encoding_canonical : forall (x_enc : word (S sz)) (x : E.point),
+point_dec x_enc = Some x -> point_enc x = x_enc.
+Admitted.
+
+Instance point_encoding : canonical encoding of E.point as (word (S sz)) := {
+ enc := point_enc;
+ dec := point_dec;
+ encoding_valid := point_encoding_valid;
+ encoding_canonical := point_encoding_canonical
+}.
+
+End PointEncoding.
diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v
new file mode 100644
index 000000000..2e65df9bd
--- /dev/null
+++ b/src/ModularArithmetic/ExtendedBaseVector.v
@@ -0,0 +1,162 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section ExtendedBaseVector.
+ Context `{prm : PseudoMersenneBaseParams}.
+
+ (* This section defines a new BaseVector that has double the length of the BaseVector
+ * used to construct [params]. The coefficients of the new vector are as follows:
+ *
+ * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i]
+ *
+ * The purpose of this construction is that it allows us to multiply numbers expressed
+ * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as
+ * vectors of digits; the value of a digit vector is obtained by doing a dot product with
+ * the base vector.) So if x, y are digit vectors:
+ *
+ * (x \dot base) * (y \dot base) = (z \dot ext_base)
+ *
+ * Then we can separate z into its first and second halves:
+ *
+ * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base)
+ *
+ * Now, if we want to reduce the product modulo 2 ^ k - c:
+ *
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c)
+ * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c)
+ *
+ * This sum may be short enough to express using base; if not, we can reduce again.
+ *)
+ Definition ext_base := base ++ (map (Z.mul (2^k)) base).
+
+ Lemma ext_base_positive : forall b, In b ext_base -> b > 0.
+ Proof.
+ unfold ext_base. intros b In_b_base.
+ rewrite in_app_iff in In_b_base.
+ destruct In_b_base as [In_b_base | In_b_extbase].
+ + eapply BaseSystem.base_positive.
+ eapply In_b_base.
+ + eapply in_map_iff in In_b_extbase.
+ destruct In_b_extbase as [b' [b'_2k_b In_b'_base]].
+ subst.
+ specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos.
+ replace 0 with (2 ^ k * 0) by ring.
+ apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition].
+ rewrite Z.gt_lt_iff.
+ apply Z.pow_pos_nonneg; intuition.
+ pose proof k_nonneg; omega.
+ Qed.
+
+ Lemma base_length_nonzero : (0 < length base)%nat.
+ Proof.
+ assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1).
+ unfold nth_default in H.
+ case_eq (nth_error base 0); intros;
+ try (rewrite H0 in H; omega).
+ apply (nth_error_value_length _ 0 base z); auto.
+ Qed.
+
+ Lemma b0_1 : forall x, nth_default x ext_base 0 = 1.
+ Proof.
+ intros. unfold ext_base.
+ rewrite nth_default_app.
+ assert (0 < length base)%nat by (apply base_length_nonzero).
+ destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega.
+ Qed.
+
+ Lemma two_k_nonzero : 2^k <> 0.
+ Proof.
+ pose proof (Z.pow_eq_0 2 k k_nonneg).
+ intuition.
+ Qed.
+
+ Lemma map_nth_default_base_high : forall n, (n < (length base))%nat ->
+ nth_default 0 (map (Z.mul (2 ^ k)) base) n =
+ (2 ^ k) * (nth_default 0 base n).
+ Proof.
+ intros.
+ erewrite map_nth_default; auto.
+ Qed.
+
+ Lemma base_good_over_boundary : forall
+ (i : nat)
+ (l : (i < length base)%nat)
+ (j' : nat)
+ (Hj': (i + j' < length base)%nat)
+ ,
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') =
+ 2 ^ k * (nth_default 0 base i * nth_default 0 base j') /
+ (2 ^ k * nth_default 0 base (i + j')) *
+ (2 ^ k * nth_default 0 base (i + j'))
+ .
+ Proof.
+ intros.
+ remember (nth_default 0 base) as b.
+ rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
+ replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat))
+ with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
+ rewrite Z.mul_cancel_l by (exact two_k_nonzero).
+ replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
+ with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
+ subst b.
+ apply (BaseSystem.base_good i j'); omega.
+ Qed.
+
+ Lemma ext_base_good :
+ forall i j, (i+j < length ext_base)%nat ->
+ let b := nth_default 0 ext_base in
+ let r := (b i * b j) / b (i+j)%nat in
+ b i * b j = r * b (i+j)%nat.
+ Proof.
+ intros.
+ subst b. subst r.
+ unfold ext_base in *.
+ rewrite app_length in H; rewrite map_length in H.
+ repeat rewrite nth_default_app.
+ destruct (lt_dec i (length base));
+ destruct (lt_dec j (length base));
+ destruct (lt_dec (i + j) (length base));
+ try omega.
+ { (* i < length base, j < length base, i + j < length base *)
+ apply BaseSystem.base_good; auto.
+ } { (* i < length base, j < length base, i + j >= length base *)
+ rewrite (map_nth_default _ _ _ _ 0) by omega.
+ apply base_matches_modulus; omega.
+ } { (* i < length base, j >= length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (j - length base)%nat as j'.
+ replace (i + j - length base)%nat with (i + j')%nat by omega.
+ replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j'))
+ with (2 ^ k * (nth_default 0 base i * nth_default 0 base j'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ } { (* i >= length base, j < length base, i + j >= length base *)
+ do 2 rewrite map_nth_default_base_high by omega.
+ remember (i - length base)%nat as i'.
+ replace (i + j - length base)%nat with (j + i')%nat by omega.
+ replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j)
+ with (2 ^ k * (nth_default 0 base j * nth_default 0 base i'))
+ by ring.
+ eapply base_good_over_boundary; eauto; omega.
+ }
+ Qed.
+
+ Instance ExtBaseVector : BaseSystem.BaseVector ext_base := {
+ base_positive := ext_base_positive;
+ b0_1 := b0_1;
+ base_good := ext_base_good
+ }.
+
+ Lemma extended_base_length:
+ length ext_base = (length base + length base)%nat.
+ Proof.
+ unfold ext_base; rewrite app_length; rewrite map_length; auto.
+ Qed.
+End ExtendedBaseVector.
+
diff --git a/src/ModularArithmetic/ModularArithmeticTheorems.v b/src/ModularArithmetic/ModularArithmeticTheorems.v
index ddb689547..dabfcf883 100644
--- a/src/ModularArithmetic/ModularArithmeticTheorems.v
+++ b/src/ModularArithmetic/ModularArithmeticTheorems.v
@@ -6,6 +6,7 @@ Require Import Crypto.Tactics.VerdiTactics.
Require Import Coq.ZArith.BinInt Coq.ZArith.Zdiv Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *)
Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid.
Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Field_theory Coq.setoid_ring.Field_tac.
+Require Export Crypto.Util.IterAssocOp.
Section ModularArithmeticPreliminaries.
Context {m:Z}.
@@ -209,6 +210,14 @@ Section FandZ.
reflexivity.
Qed.
+ Lemma pow_nat_iter_op_correct: forall (x:F m) n, (nat_iter_op mul 1) (N.to_nat n) x = x^n.
+ Proof.
+ induction n using N.peano_ind;
+ destruct (F_pow_spec x) as [pow_0 pow_succ];
+ rewrite ?N2Nat.inj_succ, ?pow_0, <-?N.add_1_l, ?pow_succ;
+ simpl; congruence.
+ Qed.
+
Lemma mod_plus_zero_subproof a b : 0 mod m = (a + b) mod m ->
b mod m = (- a) mod m.
Proof.
@@ -513,6 +522,11 @@ Section VariousModulo.
ring.
Qed.
+ Lemma F_opp_0 : opp (0 : F m) = 0%F.
+ Proof.
+ intros; ring.
+ Qed.
+
Lemma F_opp_swap : forall x y : F m, opp x = y <-> x = opp y.
Proof.
split; intro; subst; ring.
@@ -523,6 +537,23 @@ Section VariousModulo.
intros; ring.
Qed.
+ Lemma F_square_opp : forall x : F m, (opp x ^ 2 = x ^ 2)%F.
+ Proof.
+ intros; ring.
+ Qed.
+
+ Lemma F_mul_opp_r : forall x y : F m, (x * opp y = opp (x * y))%F.
+ intros; ring.
+ Qed.
+
+ Lemma F_mul_opp_l : forall x y : F m, (opp x * y = opp (x * y))%F.
+ intros; ring.
+ Qed.
+
+ Lemma F_mul_opp_both : forall x y : F m, (opp x * opp y = x * y)%F.
+ intros; ring.
+ Qed.
+
Lemma F_add_0_r : forall x : F m, (x + 0)%F = x.
Proof.
intros; ring.
@@ -675,4 +706,17 @@ Section VariousModulo.
replace y with ((y - b) + b) by ring.
rewrite Hxayb; ring.
Qed.
+
+ Lemma F_FieldToZ_add_opp : forall x : F m, x <> 0 -> (FieldToZ x + FieldToZ (opp x) = m)%Z.
+ Proof.
+ intros.
+ rewrite FieldToZ_opp.
+ rewrite Z_mod_nz_opp_full, mod_FieldToZ; try omega.
+ rewrite mod_FieldToZ.
+ replace 0%Z with (@FieldToZ m 0) by auto.
+ intro false_eq.
+ rewrite <-F_eq in false_eq.
+ congruence.
+ Qed.
+
End VariousModulo.
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index b2bea21cf..558b9a5a2 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -3,623 +3,115 @@ Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.
-Module Type PseudoMersenneBaseParams (Import B:BaseCoefs) (Import M:Modulus).
- Parameter k : Z.
- Parameter c : Z.
- Axiom modulus_pseudomersenne : modulus = 2^k - c.
-
- Axiom base_matches_modulus :
- forall i j,
- (i < length base)%nat ->
- (j < length base)%nat ->
- (i+j >= length base)%nat->
- let b := nth_default 0 base in
- let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
- b i * b j = r * (2^k * b (i+j-length base)%nat).
-
- Axiom base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0.
-
- Axiom base_tail_matches_modulus:
- 2^k mod nth_default 0 base (pred (length base)) = 0.
-
- (* Probably implied by modulus_pseudomersenne. *)
- Axiom k_nonneg : 0 <= k.
-
-End PseudoMersenneBaseParams.
-
-Module Type RepZMod (Import M:Modulus).
- Parameter T : Type.
- Parameter encode : F modulus -> T.
- Parameter decode : T -> F modulus.
-
- Parameter rep : T -> F modulus -> Prop.
- Local Notation "u '~=' x" := (rep u x) (at level 70).
- Axiom encode_rep : forall x, encode x ~= x.
- Axiom rep_decode : forall u x, u ~= x -> decode u = x.
-
- Parameter add : T -> T -> T.
- Axiom add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F.
-
- Parameter sub : T -> T -> T.
- Axiom sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F.
-
- Parameter mul : T -> T -> T.
- Axiom mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
-End RepZMod.
-
-Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersenneBaseParams BC M) <: RepZMod M.
- Module EC <: BaseCoefs.
- Definition base := BC.base ++ (map (Z.mul (2^(P.k))) BC.base).
-
- Lemma base_positive : forall b, In b base -> b > 0.
- Proof.
- unfold base. intros.
- rewrite in_app_iff in H.
- destruct H. {
- apply BC.base_positive; auto.
- } {
- specialize BC.base_positive.
- induction BC.base; intros. {
- simpl in H; intuition.
- } {
- simpl in H.
- destruct H; subst.
- replace 0 with (2 ^ P.k * 0) by auto.
- apply (Zmult_gt_compat_l a 0 (2 ^ P.k)).
- rewrite Z.gt_lt_iff.
- apply Z.pow_pos_nonneg; intuition.
- pose proof P.k_nonneg; omega.
- apply H0; left; auto.
- apply IHl; auto.
- intros. apply H0; auto. right; auto.
- }
- }
- Qed.
-
- Lemma base_length_nonzero : (0 < length BC.base)%nat.
- Proof.
- assert (nth_default 0 BC.base 0 = 1) by (apply BC.b0_1).
- unfold nth_default in H.
- case_eq (nth_error BC.base 0); intros;
- try (rewrite H0 in H; omega).
- apply (nth_error_value_length _ 0 BC.base z); auto.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x base 0 = 1.
- Proof.
- intros. unfold base.
- rewrite nth_default_app.
- assert (0 < length BC.base)%nat by (apply base_length_nonzero).
- destruct (lt_dec 0 (length BC.base)); try apply BC.b0_1; try omega.
- Qed.
-
- Lemma two_k_nonzero : 2^P.k <> 0.
- pose proof (Z.pow_eq_0 2 P.k P.k_nonneg).
- intuition.
- Qed.
-
- Lemma map_nth_default_base_high : forall n, (n < (length BC.base))%nat ->
- nth_default 0 (map (Z.mul (2 ^ P.k)) BC.base) n =
- (2 ^ P.k) * (nth_default 0 BC.base n).
- Proof.
- intros.
- erewrite map_nth_default; auto.
- Qed.
-
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0.
- Proof.
- intros; subst b; unfold base.
- repeat rewrite nth_default_app.
- do 2 break_if; try apply P.base_succ; try omega. {
- destruct (lt_eq_lt_dec (S i) (length BC.base)). {
- destruct s; intuition.
- rewrite map_nth_default_base_high by omega.
- replace i with (pred(length BC.base)) by omega.
- rewrite <- Zmult_mod_idemp_l.
- rewrite P.base_tail_matches_modulus; simpl.
- rewrite Zmod_0_l; auto.
- } {
- assert (length BC.base <= i)%nat by (apply lt_n_Sm_le; auto); omega.
- }
- } {
- unfold base in H; rewrite app_length, map_length in H.
- repeat rewrite map_nth_default_base_high by omega.
- rewrite Zmult_mod_distr_l.
- rewrite <- minus_Sn_m by omega.
- rewrite P.base_succ by omega; auto.
- }
- Qed.
-
- Lemma base_good_over_boundary : forall
- (i : nat)
- (l : (i < length BC.base)%nat)
- (j' : nat)
- (Hj': (i + j' < length BC.base)%nat)
- ,
- 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') =
- 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') /
- (2 ^ P.k * nth_default 0 BC.base (i + j')) *
- (2 ^ P.k * nth_default 0 BC.base (i + j'))
- .
- Proof. intros.
- remember (nth_default 0 BC.base) as b.
- rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero).
- replace (b i * b j' / b (i + j')%nat * (2 ^ P.k * b (i + j')%nat))
- with ((2 ^ P.k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring.
- rewrite Z.mul_cancel_l by (exact two_k_nonzero).
- replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat))
- with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring.
- subst b.
- apply (BC.base_good i j'); omega.
- Qed.
-
- Lemma base_good :
- forall i j, (i+j < length base)%nat ->
- let b := nth_default 0 base in
- let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
- Proof.
- intros.
- subst b. subst r.
- unfold base in *.
- rewrite app_length in H; rewrite map_length in H.
- repeat rewrite nth_default_app.
- destruct (lt_dec i (length BC.base));
- destruct (lt_dec j (length BC.base));
- destruct (lt_dec (i + j) (length BC.base));
- try omega.
- { (* i < length BC.base, j < length BC.base, i + j < length BC.base *)
- apply BC.base_good; auto.
- } { (* i < length BC.base, j < length BC.base, i + j >= length BC.base *)
- rewrite (map_nth_default _ _ _ _ 0) by omega.
- apply P.base_matches_modulus; omega.
- } { (* i < length BC.base, j >= length BC.base, i + j >= length BC.base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (j - length BC.base)%nat as j'.
- replace (i + j - length BC.base)%nat with (i + j')%nat by omega.
- replace (nth_default 0 BC.base i * (2 ^ P.k * nth_default 0 BC.base j'))
- with (2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- } { (* i >= length BC.base, j < length BC.base, i + j >= length BC.base *)
- do 2 rewrite map_nth_default_base_high by omega.
- remember (i - length BC.base)%nat as i'.
- replace (i + j - length BC.base)%nat with (j + i')%nat by omega.
- replace (2 ^ P.k * nth_default 0 BC.base i' * nth_default 0 BC.base j)
- with (2 ^ P.k * (nth_default 0 BC.base j * nth_default 0 BC.base i'))
- by ring.
- eapply base_good_over_boundary; eauto; omega.
- }
- Qed.
- End EC.
-
- Module E := BaseSystem EC.
- Module B := BaseSystem BC.
-
- Definition T := B.digits.
- Local Hint Unfold T.
- Definition decode (us : T) : F modulus := ZToField (B.decode us).
- Local Hint Unfold decode.
- Definition rep (us : T) (x : F modulus) := (length us <= length BC.base)%nat /\ decode us = x.
+Section PseudoMersenneBase.
+ Context `{prm :PseudoMersenneBaseParams}.
+
+ Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us).
+
+ Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x.
Local Notation "u '~=' x" := (rep u x) (at level 70).
Local Hint Unfold rep.
- Lemma rep_decode : forall us x, us ~= x -> decode us = x.
- Proof.
- autounfold; intuition.
- Qed.
-
- Definition encode (x : F modulus) := B.encode x.
-
- Lemma encode_rep : forall x : F modulus, encode x ~= x.
- Proof.
- intros. unfold encode, rep.
- split. {
- unfold B.encode; simpl.
- apply EC.base_length_nonzero.
- } {
- unfold decode.
- rewrite B.encode_rep.
- apply ZToField_idempotent. (* TODO: rename this lemma *)
- }
- Qed.
-
- Definition add (us vs : T) := B.add us vs.
- Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F.
- Proof.
- autounfold; intuition. {
- unfold add.
- rewrite B.add_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold B.decode in *.
- rewrite B.add_rep.
- rewrite ZToField_add.
- subst; auto.
- Qed.
+ Definition encode (x : F modulus) := encode x.
- Definition sub (us vs : T) := B.sub us vs.
- Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F.
- Proof.
- autounfold; intuition. {
- rewrite B.sub_length_le_max.
- case_max; try rewrite Max.max_r; omega.
- }
- unfold decode in *; unfold B.decode in *.
- rewrite B.sub_rep.
- rewrite ZToField_sub.
- subst; auto.
- Qed.
+ (* Converts from length of extended base to length of base by reduction modulo M.*)
+ Definition reduce (us : digits) : digits :=
+ let high := skipn (length base) us in
+ let low := firstn (length base) us in
+ let wrap := map (Z.mul c) high in
+ BaseSystem.add low wrap.
- Lemma decode_short : forall (us : B.digits),
- (length us <= length BC.base)%nat -> B.decode us = E.decode us.
- Proof.
- intros.
- unfold B.decode, B.decode', E.decode, E.decode'.
- rewrite combine_truncate_r.
- rewrite (combine_truncate_r us EC.base).
- f_equal; f_equal.
- unfold EC.base.
- rewrite firstn_app_inleft; auto; omega.
- Qed.
+ Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs).
- Lemma extended_base_length:
- length EC.base = (length BC.base + length BC.base)%nat.
- Proof.
- unfold EC.base; rewrite app_length; rewrite map_length; auto.
- Qed.
+ Definition sub (xs : digits) (xs_0_mod : (BaseSystem.decode base xs) mod modulus = 0) (us vs : digits) :=
+ BaseSystem.sub (add xs us) vs.
- Lemma mul_rep_extended : forall (us vs : B.digits),
- (length us <= length BC.base)%nat ->
- (length vs <= length BC.base)%nat ->
- B.decode us * B.decode vs = E.decode (E.mul us vs).
- Proof.
- intros.
- rewrite E.mul_rep by (unfold EC.base; simpl_list; omega).
- f_equal; rewrite decode_short; auto.
- Qed.
+End PseudoMersenneBase.
- (* Converts from length of E.base to length of B.base by reduction modulo M.*)
- Definition reduce (us : E.digits) : B.digits :=
- let high := skipn (length BC.base) us in
- let low := firstn (length BC.base) us in
- let wrap := map (Z.mul P.c) high in
- B.add low wrap.
+Section CarryBasePow2.
+ Context `{prm :PseudoMersenneBaseParams}.
- Lemma two_k_nonzero : 2^P.k <> 0.
- pose proof (Z.pow_eq_0 2 P.k P.k_nonneg).
- intuition.
- Qed.
-
- Lemma modulus_nonzero : modulus <> 0.
- pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
- Qed.
-
- (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
- Lemma pseudomersenne_add: forall x y, (x + ((2^P.k) * y)) mod modulus = (x + (P.c * y)) mod modulus.
- Proof.
- intros.
- replace (2^P.k) with (((2^P.k) - P.c) + P.c) by auto.
- rewrite Z.mul_add_distr_r.
- rewrite Zplus_mod.
- rewrite <- P.modulus_pseudomersenne.
- rewrite Z.mul_comm.
- rewrite mod_mult_plus; auto using modulus_nonzero.
- rewrite <- Zplus_mod; auto.
- Qed.
-
- Lemma extended_shiftadd: forall (us : E.digits), E.decode us =
- B.decode (firstn (length BC.base) us) +
- (2^P.k * B.decode (skipn (length BC.base) us)).
- Proof.
- intros.
- unfold B.decode, E.decode; rewrite <- B.mul_each_rep.
- replace B.decode' with E.decode' by auto.
- unfold EC.base.
- replace (map (Z.mul (2 ^ P.k)) BC.base) with (E.mul_each (2 ^ P.k) BC.base) by auto.
- rewrite E.base_mul_app.
- rewrite <- E.mul_each_rep; auto.
- Qed.
-
- Lemma reduce_rep : forall us, B.decode (reduce us) mod modulus = (E.decode us) mod modulus.
- Proof.
- intros.
- rewrite extended_shiftadd.
- rewrite pseudomersenne_add.
- unfold reduce.
- remember (firstn (length BC.base) us) as low.
- remember (skipn (length BC.base) us) as high.
- unfold B.decode.
- rewrite B.add_rep.
- replace (map (Z.mul P.c) high) with (B.mul_each P.c high) by auto.
- rewrite B.mul_each_rep; auto.
- Qed.
-
- Lemma reduce_length : forall us,
- (length us <= length EC.base)%nat ->
- (length (reduce us) <= length (BC.base))%nat.
- Proof.
- intros.
- unfold reduce.
- remember (map (Z.mul P.c) (skipn (length BC.base) us)) as high.
- remember (firstn (length BC.base) us) as low.
- assert (length low >= length high)%nat. {
- subst. rewrite firstn_length.
- rewrite map_length.
- rewrite skipn_length.
- destruct (le_dec (length BC.base) (length us)). {
- rewrite Min.min_l by omega.
- rewrite extended_base_length in H. omega.
- } {
- rewrite Min.min_r by omega. omega.
- }
- }
- assert ((length low <= length BC.base)%nat)
- by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
- assert (length high <= length BC.base)%nat
- by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
- rewrite extended_base_length in H; omega).
- rewrite B.add_trailing_zeros; auto.
- rewrite (B.add_same_length _ _ (length low)); auto.
- rewrite app_length.
- rewrite B.length_zeros; intuition.
- Qed.
-
- Definition mul (us vs : T) := reduce (E.mul us vs).
- Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F.
- Proof.
- autounfold; unfold mul; intuition. {
- rewrite reduce_length; try omega.
- rewrite E.mul_length.
- rewrite extended_base_length.
- omega.
- }
- rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
- rewrite E.mul_rep; try (rewrite extended_base_length; omega).
- subst; auto.
- replace (E.decode u) with (B.decode u) by (apply decode_short; omega).
- replace (E.decode v) with (B.decode v) by (apply decode_short; omega).
- apply ZToField_mul.
- Qed.
+ Definition log_cap i := nth_default 0 limb_widths i.
Definition add_to_nth n (x:Z) xs :=
set_nth n (x + nth_default 0 xs n) xs.
- Hint Unfold add_to_nth.
-
- (* i must be in the domain of BC.base *)
- Definition cap i :=
- if eq_nat_dec i (pred (length BC.base))
- then (2^P.k) / nth_default 0 BC.base i
- else nth_default 0 BC.base (S i) / nth_default 0 BC.base i.
+
+ Definition pow2_mod n i := Z.land n (Z.ones i).
Definition carry_simple i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth (S i) ( (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'.
Definition carry_and_reduce i := fun us =>
let di := nth_default 0 us i in
- let us' := set_nth i (di mod cap i) us in
- add_to_nth 0 (P.c * (di / cap i)) us'.
+ let us' := set_nth i (pow2_mod di (log_cap i)) us in
+ add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'.
- Definition carry i : B.digits -> B.digits :=
- if eq_nat_dec i (pred (length BC.base))
+ Definition carry i : digits -> digits :=
+ if eq_nat_dec i (pred (length base))
then carry_and_reduce i
else carry_simple i.
- Lemma decode'_splice : forall xs ys bs,
- B.decode' bs (xs ++ ys) =
- B.decode' (firstn (length xs) bs) xs +
- B.decode' (skipn (length xs) bs) ys.
- Proof.
- induction xs; destruct ys, bs; boring.
- unfold B.decode'.
- rewrite combine_truncate_r.
- ring.
- Qed.
-
- Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
- B.decode (set_nth n x us) =
- (x - nth_default 0 us n) * nth_default 0 BC.base n + B.decode us.
- Proof.
- intros.
- unfold B.decode.
- nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
- unfold splice_nth.
- rewrite <- (firstn_skipn n us) at 4.
- do 2 rewrite decode'_splice.
- remember (length (firstn n us)) as n0.
- ring_simplify.
- remember (B.decode' (firstn n0 BC.base) (firstn n us)).
- rewrite (skipn_nth_default n us 0) by omega.
- rewrite firstn_length in Heqn0.
- rewrite Min.min_l in Heqn0 by omega; subst n0.
- destruct (le_lt_dec (length BC.base) n). {
- rewrite nth_default_out_of_bounds by auto.
- rewrite skipn_all by omega.
- do 2 rewrite B.decode_base_nil.
- ring_simplify; auto.
- } {
- rewrite (skipn_nth_default n BC.base 0) by omega.
- do 2 rewrite B.decode'_cons.
- ring_simplify; ring.
- }
- Qed.
-
- Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
- B.decode (add_to_nth n x us) =
- x * nth_default 0 BC.base n + B.decode us.
- Proof.
- unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
- Qed.
-
- Lemma nth_default_base_positive : forall i, (i < length BC.base)%nat ->
- nth_default 0 BC.base i > 0.
- Proof.
- intros.
- pose proof (nth_error_length_exists_value _ _ H).
- destruct H0.
- pose proof (nth_error_value_In _ _ _ H0).
- pose proof (BC.base_positive _ H1).
- unfold nth_default.
- rewrite H0; auto.
- Qed.
-
- Lemma base_succ_div_mult : forall i, ((S i) < length BC.base)%nat ->
- nth_default 0 BC.base (S i) = nth_default 0 BC.base i *
- (nth_default 0 BC.base (S i) / nth_default 0 BC.base i).
- Proof.
- intros.
- apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply P.base_succ; auto.
- Qed.
-
- Lemma base_length_lt_pred : (pred (length BC.base) < length BC.base)%nat.
- Proof.
- pose proof EC.base_length_nonzero; omega.
- Qed.
- Hint Resolve base_length_lt_pred.
-
- Lemma cap_positive: forall i, (i < length BC.base)%nat -> cap i > 0.
- Proof.
- unfold cap; intros; break_if. {
- apply div_positive_gt_0; try (subst; apply P.base_tail_matches_modulus). {
- rewrite <- two_p_equiv.
- apply two_p_gt_ZERO.
- apply P.k_nonneg.
- } {
- apply nth_default_base_positive; subst; auto.
- }
- } {
- apply div_positive_gt_0; try (apply P.base_succ; omega);
- try (apply nth_default_base_positive; omega).
- }
- Qed.
-
- Lemma cap_div_mod : forall us i, (i < (pred (length BC.base)))%nat ->
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 BC.base i =
- (di / cap i) * nth_default 0 BC.base (S i).
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega);
- ring_simplify.
- unfold cap; break_if; intuition.
- rewrite base_succ_div_mult at 4 by omega; ring.
- Qed.
-
- Lemma carry_simple_decode_eq : forall i us,
- (length us = length BC.base) ->
- (i < (pred (length BC.base)))%nat ->
- B.decode (carry_simple i us) = B.decode us.
- Proof.
- unfold carry_simple. intros.
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite <- cap_div_mod by auto; ring_simplify; auto.
- Qed.
-
- Lemma two_k_div_mul_last :
- 2 ^ P.k = nth_default 0 BC.base (pred (length BC.base)) *
- (2 ^ P.k / nth_default 0 BC.base (pred (length BC.base))).
- Proof.
- intros.
- pose proof P.base_tail_matches_modulus.
- rewrite (Z_div_mod_eq (2 ^ P.k) (nth_default 0 BC.base (pred (length BC.base)))) at 1 by
- (apply nth_default_base_positive; auto); omega.
- Qed.
-
- Lemma cap_div_mod_reduce : forall us,
- let i := pred (length BC.base) in
- let di := nth_default 0 us i in
- (di - (di mod cap i)) * nth_default 0 BC.base i =
- (di / cap i) * 2 ^ P.k.
- Proof.
- intros.
- rewrite (Z_div_mod_eq di (cap i)) at 1 by
- (apply cap_positive; auto); ring_simplify.
- unfold cap; break_if; intuition.
- rewrite Z.mul_comm, Z.mul_assoc.
- subst i; rewrite <- two_k_div_mul_last; auto.
- Qed.
-
- Lemma carry_decode_eq_reduce : forall us,
- (length us = length BC.base) ->
- B.decode (carry_and_reduce (pred (length BC.base)) us) mod modulus
- = B.decode us mod modulus.
- Proof.
- unfold carry_and_reduce; intros.
- pose proof EC.base_length_nonzero.
- rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
- rewrite set_nth_sum by omega.
- rewrite Zplus_comm; rewrite <- Z.mul_assoc.
- rewrite <- pseudomersenne_add.
- rewrite BC.b0_1.
- rewrite (Z.mul_comm (2 ^ P.k)).
- rewrite <- Zred_factor0.
- rewrite <- cap_div_mod_reduce by auto; auto.
- Qed.
-
- Lemma carry_length : forall i us,
- (length us <= length BC.base)%nat ->
- (length (carry i us) <= length BC.base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
- Hint Resolve carry_length.
-
- Lemma carry_rep : forall i us x,
- (length us = length BC.base) ->
- (i < length BC.base)%nat ->
- us ~= x -> carry i us ~= x.
- Proof.
- pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
- unfold rep, decode, carry in *; intros.
- intuition; break_if; subst; eauto;
- apply F_eq; simpl; intuition.
- Qed.
- Hint Resolve carry_rep.
-
Definition carry_sequence is us := fold_right carry us is.
- Lemma carry_sequence_length: forall is us,
- (length us <= length BC.base)%nat ->
- (length (carry_sequence is us) <= length BC.base)%nat.
- Proof.
- induction is; boring.
- Qed.
- Hint Resolve carry_sequence_length.
-
- Lemma carry_length_exact : forall i us,
- (length us = length BC.base)%nat ->
- (length (carry i us) = length BC.base)%nat.
- Proof.
- unfold carry, carry_simple, carry_and_reduce, add_to_nth.
- intros; break_if; subst; repeat (rewrite length_set_nth); auto.
- Qed.
-
- Lemma carry_sequence_length_exact: forall is us,
- (length us = length BC.base)%nat ->
- (length (carry_sequence is us) = length BC.base)%nat.
- Proof.
- induction is; boring.
- apply carry_length_exact; auto.
- Qed.
- Hint Resolve carry_sequence_length_exact.
-
- Lemma carry_sequence_rep : forall is us x,
- (forall i, In i is -> (i < length BC.base)%nat) ->
- (length us = length BC.base) ->
- us ~= x -> carry_sequence is us ~= x.
- Proof.
- induction is; boring.
- Qed.
-End PseudoMersenneBase. \ No newline at end of file
+End CarryBasePow2.
+
+Section Canonicalization.
+ Context `{prm :PseudoMersenneBaseParams}.
+
+ Fixpoint make_chain i :=
+ match i with
+ | O => nil
+ | S i' => i' :: make_chain i'
+ end.
+
+ (* compute at compile time *)
+ Definition full_carry_chain := make_chain (length limb_widths).
+
+ (* compute at compile time *)
+ Definition max_ones := Z.ones
+ ((fix loop current_max lw :=
+ match lw with
+ | nil => current_max
+ | w :: lw' => loop (Z.max w current_max) lw'
+ end
+ ) 0 limb_widths).
+
+ (* compute at compile time? *)
+ Definition carry_full := carry_sequence full_carry_chain.
+
+ Definition max_bound i := Z.ones (log_cap i).
+
+ Definition isFull us :=
+ (fix loop full i :=
+ match i with
+ | O => full (* don't test 0; the test for 0 is the initial value of [full]. *)
+ | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i'
+ end
+ ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat.
+
+ Fixpoint range' n m :=
+ match m with
+ | O => nil
+ | S m' => (n - m)%nat :: range' n m'
+ end.
+
+ Definition range n := range' n n.
+
+ Definition land_max_bound and_term i := Z.land and_term (max_bound i).
+
+ Definition freeze us :=
+ let us' := carry_full (carry_full (carry_full us)) in
+ let and_term := if isFull us' then max_ones else 0 in
+ (* [and_term] is all ones if us' is full, so the subtractions subtract q overall.
+ Otherwise, it's all zeroes, and the subtractions do nothing. *)
+ map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us').
+
+End Canonicalization.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
new file mode 100644
index 000000000..981680b4a
--- /dev/null
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -0,0 +1,463 @@
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
+Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
+Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Coq.Lists.List.
+Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
+Import ListNotations.
+Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Require Import Coq.QArith.QArith Coq.QArith.Qround.
+Require Import Crypto.Tactics.VerdiTactics.
+Local Open Scope Z.
+
+(* Computed versions of some functions. *)
+
+Definition Z_add_opt := Eval compute in Z.add.
+Definition Z_sub_opt := Eval compute in Z.sub.
+Definition Z_mul_opt := Eval compute in Z.mul.
+Definition Z_div_opt := Eval compute in Z.div.
+Definition Z_pow_opt := Eval compute in Z.pow.
+Definition Z_opp_opt := Eval compute in Z.opp.
+Definition Z_shiftl_opt := Eval compute in Z.shiftl.
+Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by.
+
+Definition nth_default_opt {A} := Eval compute in @nth_default A.
+Definition set_nth_opt {A} := Eval compute in @set_nth A.
+Definition map_opt {A B} := Eval compute in @map A B.
+Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths.
+
+Definition Let_In {A P} (x : A) (f : forall y : A, P y)
+ := let y := x in f y.
+
+(* Some automation that comes in handy when constructing base parameters *)
+Ltac opt_step :=
+ match goal with
+ | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
+ => refine (_ : match e with nil => _ | _ => _ end = _);
+ destruct e
+ end.
+
+Ltac brute_force_indices limb_widths := intros; unfold sum_firstn, limb_widths; simpl in *;
+ repeat match goal with
+ | _ => progress simpl in *
+ | _ => reflexivity
+ | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H
+ | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x
+ | [H : (?x < _)%nat |- _ ] => is_var x; destruct x
+ | _ => omega
+ end.
+
+
+Definition limb_widths_from_len len k := Eval compute in
+ (fix loop i prev :=
+ match i with
+ | O => nil
+ | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
+ then (k * Z.of_nat (len - i + 1)) / Z.of_nat len
+ else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in
+ x - prev:: (loop i' x)
+ end) len 0.
+
+Ltac construct_params prime_modulus len k :=
+ let lw := fresh "lw" in set (lw := limb_widths_from_len len k);
+ cbv in lw;
+ eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
+ [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
+ | abstract (unfold limb_widths; cbv; congruence)
+ | abstract brute_force_indices lw
+ | abstract apply prime_modulus
+ | abstract brute_force_indices lw].
+
+Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
+ match limb_widths with
+ | nil => nil
+ | x :: tail =>
+ 2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail
+ end.
+
+Ltac subst_precondition := match goal with
+ | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
+end.
+
+Ltac kill_precondition H :=
+ forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
+ subst_precondition.
+
+Ltac compute_formula :=
+ match goal with
+ | [H : _ -> _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula
+ | [H : _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula
+ | [H : @PseudoMersenneBaseRep.rep ?M ?P _ ?result |- @PseudoMersenneBaseRep.rep ?M ?P _ ?result] =>
+ let m := fresh "m" in set (m := M) in H at 1; change M with m at 1;
+ let p := fresh "p" in set (p := P) in H at 1; change P with p at 1;
+ let r := fresh "r" in set (r := result) in H |- *;
+ cbv -[m p r PseudoMersenneBaseRep.rep] in H;
+ repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H;
+ exact H
+ end.
+
+Section Carries.
+ Context `{prm : PseudoMersenneBaseParams}
+ (* allows caller to precompute k and c *)
+ (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
+
+ Definition carry_opt_sig
+ (i : nat) (b : digits)
+ : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
+ Proof.
+ eexists ; intros.
+ cbv [carry].
+ rewrite <- pull_app_if_sumbool.
+ cbv beta delta
+ [carry carry_and_reduce carry_simple add_to_nth log_cap
+ pow2_mod Z.ones Z.pred base
+ PseudoMersenneBaseParams.limb_widths].
+ change @nth_default with @nth_default_opt in *.
+ change @set_nth with @set_nth_opt in *.
+ lazymatch goal with
+ | [ |- _ = (if ?br then ?c else ?d) ]
+ => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
+ end.
+ 2:cbv zeta.
+ 2:break_if; reflexivity.
+
+ change @nth_default with @nth_default_opt.
+ rewrite c_subst.
+ change @set_nth with @set_nth_opt.
+ change @map with @map_opt.
+ rewrite <- @beq_nat_eq_nat_dec.
+ change base_from_limb_widths with base_from_limb_widths_opt.
+ reflexivity.
+ Defined.
+
+ Definition carry_opt i b
+ := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b).
+
+ Definition carry_opt_correct i b : (i < length limb_widths)%nat -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b).
+
+ Definition carry_sequence_opt_sig (is : list nat) (us : digits)
+ : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Proof.
+ eexists. intros H.
+ cbv [carry_sequence].
+ transitivity (fold_right carry_opt us is).
+ Focus 2.
+ { induction is; [ reflexivity | ].
+ simpl; rewrite IHis, carry_opt_correct.
+ - reflexivity.
+ - rewrite base_length in H.
+ apply H; apply in_eq.
+ - intros. apply H. right. auto.
+ }
+ Unfocus.
+ reflexivity.
+ Defined.
+
+ Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
+ proj1_sig (carry_sequence_opt_sig is us).
+
+ Definition carry_sequence_opt_correct is us
+ : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
+ := proj2_sig (carry_sequence_opt_sig is us).
+
+ Definition carry_opt_cps_sig
+ {T}
+ (i : nat)
+ (f : digits -> T)
+ (b : digits)
+ : { d : T | (i < length base)%nat -> d = f (carry i b) }.
+ Proof.
+ eexists. intros H.
+ rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
+ cbv beta iota delta [carry_opt].
+ let LHS := match goal with |- ?LHS = ?RHS => LHS end in
+ let RHS := match goal with |- ?LHS = ?RHS => RHS end in
+ let RHSf := match (eval pattern (nth_default_opt 0%Z b i) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf).
+ change Z.shiftl with Z_shiftl_opt.
+ change (-1) with (Z_opp_opt 1).
+ change Z.add with Z_add_opt at 8 12 20 24.
+ reflexivity.
+ Defined.
+
+ Definition carry_opt_cps {T} i f b
+ := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
+
+ Definition carry_opt_cps_correct {T} i f b :
+ (i < length base)%nat ->
+ @carry_opt_cps T i f b = f (carry i b)
+ := proj2_sig (carry_opt_cps_sig i f b).
+
+ Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
+ : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
+ Proof.
+ eexists.
+ cbv [carry_sequence].
+ transitivity (fold_right carry_opt_cps id (List.rev is) us).
+ Focus 2.
+ {
+ assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
+ subst. intros. rewrite <- in_rev in *. auto. }
+ remember (rev is) as ris eqn:Heq.
+ rewrite <- (rev_involutive is), <- Heq.
+ clear H Heq is.
+ rewrite fold_left_rev_right.
+ revert us; induction ris; [ reflexivity | ]; intros.
+ { simpl.
+ rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
+ rewrite carry_opt_cps_correct; [reflexivity|].
+ apply Hr; left; reflexivity.
+ } }
+ Unfocus.
+ reflexivity.
+ Defined.
+
+ Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
+ proj1_sig (carry_sequence_opt_cps_sig is us).
+
+ Definition carry_sequence_opt_cps_correct is us
+ : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
+ := proj2_sig (carry_sequence_opt_cps_sig is us).
+
+
+ Lemma carry_sequence_opt_cps_rep
+ : forall (is : list nat) (us : list Z) (x : F modulus),
+ (forall i : nat, In i is -> i < length base)%nat ->
+ length us = length base ->
+ rep us x -> rep (carry_sequence_opt_cps is us) x.
+ Proof.
+ intros.
+ rewrite carry_sequence_opt_cps_correct by assumption.
+ apply carry_sequence_rep; assumption.
+ Qed.
+
+End Carries.
+
+Section Addition.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
+
+ Definition add_opt_sig (us vs : T) : { b : digits | b = add us vs }.
+ Proof.
+ eexists.
+ cbv [BaseSystem.add].
+ reflexivity.
+ Defined.
+
+ Definition add_opt (us vs : T) : digits
+ := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs).
+
+ Definition add_opt_correct us vs
+ : add_opt us vs = add us vs
+ := proj2_sig (add_opt_sig us vs).
+
+ Lemma add_opt_rep: forall (u v : T) (x y : F modulus),
+ PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
+ PseudoMersenneBaseRep.rep (add_opt u v) (x + y)%F.
+ Proof.
+ intros.
+ rewrite add_opt_correct.
+ auto using PseudoMersenneBaseRep.add_rep.
+ Qed.
+
+End Addition.
+
+Section Subtraction.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
+
+ Definition sub_opt_sig (us vs : T) : { b : digits | b = sub coeff coeff_mod us vs }.
+ Proof.
+ eexists.
+ cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
+ reflexivity.
+ Defined.
+
+ Definition sub_opt (us vs : T) : digits
+ := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).
+
+ Definition sub_opt_correct us vs
+ : sub_opt us vs = sub coeff coeff_mod us vs
+ := proj2_sig (sub_opt_sig us vs).
+
+ Lemma sub_opt_rep: forall (u v : T) (x y : F modulus),
+ PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
+ PseudoMersenneBaseRep.rep (sub_opt u v) (x - y)%F.
+ Proof.
+ intros.
+ rewrite sub_opt_correct.
+ change (sub coeff coeff_mod) with PseudoMersenneBaseRep.sub.
+ apply PseudoMersenneBaseRep.sub_rep; auto using coeff_length.
+ Qed.
+
+End Subtraction.
+
+Section Multiplication.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
+ (* allows caller to precompute k and c *)
+ (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
+ Definition mul_bi'_step
+ (mul_bi' : nat -> digits -> list Z -> list Z)
+ (i : nat) (vsr : digits) (bs : list Z)
+ : list Z
+ := match vsr with
+ | [] => []
+ | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs
+ end.
+
+ Definition mul_bi'_opt_step_sig
+ (mul_bi' : nat -> digits -> list Z -> list Z)
+ (i : nat) (vsr : digits) (bs : list Z)
+ : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }.
+ Proof.
+ eexists.
+ cbv [mul_bi'_step].
+ opt_step.
+ { reflexivity. }
+ { cbv [crosscoef ext_base base].
+ change Z.div with Z_div_opt.
+ change Z.mul with Z_mul_opt at 2.
+ change @nth_default with @nth_default_opt.
+ reflexivity. }
+ Defined.
+
+ Definition mul_bi'_opt_step
+ (mul_bi' : nat -> digits -> list Z -> list Z)
+ (i : nat) (vsr : digits) (bs : list Z)
+ : list Z
+ := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in
+ proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs).
+
+ Fixpoint mul_bi'_opt
+ (i : nat) (vsr : digits) (bs : list Z) {struct vsr}
+ : list Z
+ := mul_bi'_opt_step mul_bi'_opt i vsr bs.
+
+ Definition mul_bi'_opt_correct
+ (i : nat) (vsr : digits) (bs : list Z)
+ : mul_bi'_opt i vsr bs = mul_bi' bs i vsr.
+ Proof.
+ revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
+ { reflexivity. }
+ { simpl mul_bi'.
+ rewrite <- IHvsr; clear IHvsr.
+ unfold mul_bi'_opt, mul_bi'_opt_step.
+ apply f_equal2; [ | reflexivity ].
+ cbv [crosscoef ext_base base].
+ change Z.div with Z_div_opt.
+ change Z.mul with Z_mul_opt at 2.
+ change @nth_default with @nth_default_opt.
+ reflexivity. }
+ Qed.
+
+ Definition mul'_step
+ (mul' : digits -> digits -> list Z -> digits)
+ (usr vs : digits) (bs : list Z)
+ : digits
+ := match usr with
+ | [] => []
+ | u :: usr' => add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs)
+ end.
+
+ Lemma map_zeros : forall a n l,
+ map (Z.mul a) (zeros n ++ l) = zeros n ++ map (Z.mul a) l.
+ Admitted.
+
+ Definition mul'_opt_step_sig
+ (mul' : digits -> digits -> list Z -> digits)
+ (usr vs : digits) (bs : list Z)
+ : { d : digits | d = mul'_step mul' usr vs bs }.
+ Proof.
+ eexists.
+ cbv [mul'_step].
+ match goal with
+ | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
+ => refine (_ : match e with nil => _ | _ => _ end = _);
+ destruct e
+ end.
+ { reflexivity. }
+ { cbv [mul_each mul_bi].
+ rewrite <- mul_bi'_opt_correct.
+ rewrite map_zeros.
+ change @map with @map_opt.
+ cbv [zeros].
+ reflexivity. }
+ Defined.
+
+ Definition mul'_opt_step
+ (mul' : digits -> digits -> list Z -> digits)
+ (usr vs : digits) (bs : list Z)
+ : digits
+ := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs).
+
+ Fixpoint mul'_opt
+ (usr vs : digits) (bs : list Z)
+ : digits
+ := mul'_opt_step mul'_opt usr vs bs.
+
+ Definition mul'_opt_correct
+ (usr vs : digits) (bs : list Z)
+ : mul'_opt usr vs bs = mul' bs usr vs.
+ Proof.
+ revert vs; induction usr as [|usr usrs IHusr]; intros.
+ { reflexivity. }
+ { simpl.
+ rewrite <- IHusr; clear IHusr.
+ apply f_equal2; [ | reflexivity ].
+ cbv [mul_each mul_bi].
+ rewrite map_zeros.
+ rewrite <- mul_bi'_opt_correct.
+ reflexivity. }
+ Qed.
+
+ Definition mul_opt_sig (us vs : T) : { b : digits | b = mul us vs }.
+ Proof.
+ eexists.
+ cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce].
+ rewrite <- mul'_opt_correct.
+ cbv [base PseudoMersenneBaseParams.limb_widths].
+ rewrite map_shiftl by apply k_nonneg.
+ rewrite c_subst.
+ rewrite k_subst.
+ change @map with @map_opt.
+ change base_from_limb_widths with base_from_limb_widths_opt.
+ change @Z_shiftl_by with @Z_shiftl_by_opt.
+ reflexivity.
+ Defined.
+
+ Definition mul_opt (us vs : T) : digits
+ := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).
+
+ Definition mul_opt_correct us vs
+ : mul_opt us vs = mul us vs
+ := proj2_sig (mul_opt_sig us vs).
+
+ Lemma mul_opt_rep:
+ forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
+ PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F.
+ Proof.
+ intros.
+ rewrite mul_opt_correct.
+ change mul with PseudoMersenneBaseRep.mul.
+ auto using PseudoMersenneBaseRep.mul_rep.
+ Qed.
+
+ Definition carry_mul_opt
+ (is : list nat)
+ (us vs : list Z)
+ : list Z
+ := carry_sequence_opt_cps c_ is (mul_opt us vs).
+
+ Lemma carry_mul_opt_correct
+ : forall (is : list nat) (us vs : list Z) (x y: F modulus),
+ PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y ->
+ (forall i : nat, In i is -> i < length base)%nat ->
+ length (mul_opt us vs) = length base ->
+ PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F.
+ Proof.
+ intros is us vs x y; intros.
+ change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)).
+ apply carry_sequence_opt_cps_rep, mul_opt_rep; auto.
+ Qed.
+End Multiplication. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
new file mode 100644
index 000000000..274acff5a
--- /dev/null
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -0,0 +1,1179 @@
+Require Import Zpower ZArith.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import VerdiTactics.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneProofs.
+ Context `{prm :PseudoMersenneBaseParams}.
+
+ Local Hint Unfold decode.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Local Notation "u '.+' x" := (add u x) (at level 70).
+ Local Notation "u '.*' x" := (ModularBaseSystem.mul u x) (at level 70).
+ Local Hint Unfold rep.
+
+ Lemma rep_decode : forall us x, us ~= x -> decode us = x.
+ Proof.
+ autounfold; intuition.
+ Qed.
+
+ Lemma encode_rep : forall x : F modulus, encode x ~= x.
+ Proof.
+ intros. unfold encode, rep.
+ split. {
+ unfold encode; simpl.
+ apply base_length_nonzero.
+ } {
+ unfold decode.
+ rewrite encode_rep.
+ apply ZToField_FieldToZ.
+ apply bv.
+ }
+ Qed.
+
+ Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F.
+ Proof.
+ autounfold; intuition. {
+ unfold add.
+ rewrite add_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold decode in *.
+ rewrite add_rep.
+ rewrite ZToField_add.
+ subst; auto.
+ Qed.
+
+ Lemma sub_rep : forall c c_0modq, (length c <= length base)%nat ->
+ forall u v x y, u ~= x -> v ~= y ->
+ ModularBaseSystem.sub c c_0modq u v ~= (x-y)%F.
+ Proof.
+ autounfold; unfold ModularBaseSystem.sub; intuition. {
+ rewrite sub_length_le_max.
+ case_max; try rewrite Max.max_r; try omega.
+ rewrite add_length_le_max.
+ case_max; try rewrite Max.max_r; omega.
+ }
+ unfold decode in *; unfold BaseSystem.decode in *.
+ rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep.
+ rewrite ZToField_sub, ZToField_add, ZToField_mod.
+ rewrite c_0modq, F_add_0_l.
+ subst; auto.
+ Qed.
+
+ Lemma decode_short : forall (us : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ BaseSystem.decode base us = BaseSystem.decode ext_base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode, BaseSystem.decode'.
+ rewrite combine_truncate_r.
+ rewrite (combine_truncate_r us ext_base).
+ f_equal; f_equal.
+ unfold ext_base.
+ rewrite firstn_app_inleft; auto; omega.
+ Qed.
+
+ Lemma mul_rep_extended : forall (us vs : BaseSystem.digits),
+ (length us <= length base)%nat ->
+ (length vs <= length base)%nat ->
+ (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs).
+ Proof.
+ intros.
+ rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega).
+ f_equal; rewrite decode_short; auto.
+ Qed.
+
+ Lemma modulus_nonzero : modulus <> 0.
+ pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
+ Qed.
+
+ (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
+ Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
+ Proof.
+ intros.
+ replace (2^k) with ((2^k - c) + c) by ring.
+ rewrite Z.mul_add_distr_r.
+ rewrite Zplus_mod.
+ unfold c.
+ rewrite Z.sub_sub_distr, Z.sub_diag.
+ simpl.
+ rewrite Z.mul_comm.
+ rewrite mod_mult_plus; auto using modulus_nonzero.
+ rewrite <- Zplus_mod; auto.
+ Qed.
+
+ Lemma extended_shiftadd: forall (us : BaseSystem.digits),
+ BaseSystem.decode ext_base us =
+ BaseSystem.decode base (firstn (length base) us)
+ + (2^k * BaseSystem.decode base (skipn (length base) us)).
+ Proof.
+ intros.
+ unfold BaseSystem.decode; rewrite <- mul_each_rep.
+ unfold ext_base.
+ replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
+ rewrite base_mul_app.
+ rewrite <- mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_rep : forall us,
+ BaseSystem.decode base (reduce us) mod modulus =
+ BaseSystem.decode ext_base us mod modulus.
+ Proof.
+ intros.
+ rewrite extended_shiftadd.
+ rewrite pseudomersenne_add.
+ unfold reduce.
+ remember (firstn (length base) us) as low.
+ remember (skipn (length base) us) as high.
+ unfold BaseSystem.decode.
+ rewrite BaseSystemProofs.add_rep.
+ replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto.
+ rewrite mul_each_rep; auto.
+ Qed.
+
+ Lemma reduce_length : forall us,
+ (length us <= length ext_base)%nat ->
+ (length (reduce us) <= length base)%nat.
+ Proof.
+ intros.
+ unfold reduce.
+ remember (map (Z.mul c) (skipn (length base) us)) as high.
+ remember (firstn (length base) us) as low.
+ assert (length low >= length high)%nat. {
+ subst. rewrite firstn_length.
+ rewrite map_length.
+ rewrite skipn_length.
+ destruct (le_dec (length base) (length us)). {
+ rewrite Min.min_l by omega.
+ rewrite extended_base_length in H. omega.
+ } {
+ rewrite Min.min_r; omega.
+ }
+ }
+ assert ((length low <= length base)%nat)
+ by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l).
+ assert (length high <= length base)%nat
+ by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length;
+ rewrite extended_base_length in H; omega).
+ rewrite add_trailing_zeros; auto.
+ rewrite (add_same_length _ _ (length low)); auto.
+ rewrite app_length.
+ rewrite length_zeros; intuition.
+ Qed.
+
+ Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F.
+ Proof.
+ autounfold; unfold ModularBaseSystem.mul; intuition.
+ {
+ apply reduce_length.
+ rewrite mul_length, extended_base_length.
+ omega.
+ } {
+ rewrite ZToField_mod, reduce_rep, <-ZToField_mod.
+ rewrite mul_rep by
+ (apply ExtBaseVector || rewrite extended_base_length; omega).
+ subst.
+ do 2 rewrite decode_short by auto.
+ apply ZToField_mul.
+ }
+ Qed.
+
+ Lemma set_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (set_nth n x us) =
+ (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ intros.
+ unfold BaseSystem.decode.
+ nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*)
+ unfold splice_nth.
+ rewrite <- (firstn_skipn n us) at 4.
+ do 2 rewrite decode'_splice.
+ remember (length (firstn n us)) as n0.
+ ring_simplify.
+ remember (BaseSystem.decode' (firstn n0 base) (firstn n us)).
+ rewrite (skipn_nth_default n us 0) by omega.
+ rewrite firstn_length in Heqn0.
+ rewrite Min.min_l in Heqn0 by omega; subst n0.
+ destruct (le_lt_dec (length base) n). {
+ rewrite nth_default_out_of_bounds by auto.
+ rewrite skipn_all by omega.
+ do 2 rewrite decode_base_nil.
+ ring_simplify; auto.
+ } {
+ rewrite (skipn_nth_default n base 0) by omega.
+ do 2 rewrite decode'_cons.
+ ring_simplify; ring.
+ }
+ Qed.
+
+ Lemma add_to_nth_sum : forall n x us, (n < length us)%nat ->
+ BaseSystem.decode base (add_to_nth n x us) =
+ x * nth_default 0 base n + BaseSystem.decode base us.
+ Proof.
+ unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto.
+ Qed.
+
+ Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
+ nth_default 0 base i > 0.
+ Proof.
+ intros.
+ pose proof (nth_error_length_exists_value _ _ H).
+ destruct H0.
+ pose proof (nth_error_value_In _ _ _ H0).
+ pose proof (BaseSystem.base_positive _ H1).
+ unfold nth_default.
+ rewrite H0; auto.
+ Qed.
+
+ Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) = nth_default 0 base i *
+ (nth_default 0 base (S i) / nth_default 0 base i).
+ Proof.
+ intros.
+ apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
+ apply base_succ; auto.
+ Qed.
+
+End PseudoMersenneProofs.
+
+Section CarryProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+ Local Notation "u '~=' x" := (rep u x) (at level 70).
+ Hint Unfold log_cap.
+
+ Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
+ Proof.
+ pose proof base_length_nonzero; omega.
+ Qed.
+ Hint Resolve base_length_lt_pred.
+
+ Lemma log_cap_nonneg : forall i, 0 <= log_cap i.
+ Proof.
+ unfold log_cap, nth_default; intros.
+ case_eq (nth_error limb_widths i); intros; try omega.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ Lemma nth_default_base_succ : forall i, (S i < length base)%nat ->
+ nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ destruct (NPeano.Nat.eq_dec i 0).
+ + subst; f_equal.
+ unfold sum_firstn, log_cap.
+ destruct limb_widths; auto.
+ + erewrite sum_firstn_succ; eauto.
+ unfold log_cap.
+ apply nth_error_Some_nth_default.
+ rewrite <- base_length; omega.
+ Qed.
+
+ Lemma carry_simple_decode_eq : forall i us,
+ (length us = length base) ->
+ (i < (pred (length base)))%nat ->
+ BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us.
+ Proof.
+ unfold carry_simple. intros.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ rewrite nth_default_base_succ by omega.
+ rewrite Z.mul_assoc.
+ rewrite (Z.mul_comm _ (2 ^ log_cap i)).
+ rewrite mul_div_eq; try ring.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg.
+ Qed.
+
+ Lemma carry_decode_eq_reduce : forall us,
+ (length us = length base) ->
+ BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus
+ = BaseSystem.decode base us mod modulus.
+ Proof.
+ unfold carry_and_reduce; intros ? length_eq.
+ pose proof base_length_nonzero.
+ rewrite add_to_nth_sum by (rewrite length_set_nth; omega).
+ rewrite set_nth_sum by omega.
+ rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1.
+ rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0.
+ f_equal.
+ rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r.
+ f_equal.
+ destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero].
+ + apply length0_nil in length_zero.
+ pose proof (base_length) as limbs_length.
+ rewrite length_zero in length_eq, limbs_length.
+ apply length0_nil in length_eq.
+ symmetry in limbs_length.
+ apply length0_nil in limbs_length.
+ unfold log_cap.
+ subst; rewrite length_zero, limbs_length, nth_default_nil.
+ reflexivity.
+ + rewrite nth_default_base by omega.
+ rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg).
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ rewrite Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite Z.mul_assoc.
+ rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg).
+ unfold k.
+ replace (length limb_widths) with (S (pred (length base))) by
+ (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega).
+ rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by
+ (unfold log_cap; apply nth_error_Some_nth_default; rewrite <- base_length; omega).
+ rewrite <- Zopp_mult_distr_r.
+ rewrite Z.mul_comm.
+ rewrite (Z.add_comm (log_cap (pred (length base)))).
+ ring.
+ Qed.
+
+ Lemma carry_length : forall i us,
+ (length us <= length base)%nat ->
+ (length (carry i us) <= length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+ Hint Resolve carry_length.
+
+ Lemma carry_rep : forall i us x,
+ (length us = length base) ->
+ (i < length base)%nat ->
+ us ~= x -> carry i us ~= x.
+ Proof.
+ pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq.
+ unfold rep, decode, carry in *; intros.
+ intuition; break_if; subst; eauto;
+ apply F_eq; simpl; intuition.
+ Qed.
+ Hint Resolve carry_rep.
+
+ Lemma carry_sequence_length: forall is us,
+ (length us <= length base)%nat ->
+ (length (carry_sequence is us) <= length base)%nat.
+ Proof.
+ induction is; boring.
+ Qed.
+ Hint Resolve carry_sequence_length.
+
+ Lemma carry_length_exact : forall i us,
+ (length us = length base)%nat ->
+ (length (carry i us) = length base)%nat.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce, add_to_nth.
+ intros; break_if; subst; repeat (rewrite length_set_nth); auto.
+ Qed.
+
+ Lemma carry_sequence_length_exact: forall is us,
+ (length us = length base)%nat ->
+ (length (carry_sequence is us) = length base)%nat.
+ Proof.
+ induction is; boring.
+ apply carry_length_exact; auto.
+ Qed.
+ Hint Resolve carry_sequence_length_exact.
+
+ Lemma carry_sequence_rep : forall is us x,
+ (forall i, In i is -> (i < length base)%nat) ->
+ (length us = length base) ->
+ us ~= x -> carry_sequence is us ~= x.
+ Proof.
+ induction is; boring.
+ Qed.
+
+End CarryProofs.
+
+Section CanonicalizationProofs.
+ Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) (c_pos : 0 < c) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B).
+
+ (* TODO : move *)
+ Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat ->
+ nth_default d (set_nth n x l) i =
+ if (eq_nat_dec i n) then x else nth_default d l i.
+ Proof.
+ induction n; (destruct l; [intros; simpl in *; omega | ]); simpl;
+ destruct i; break_if; try omega; intros; try apply nth_default_cons;
+ rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity.
+ Qed.
+
+ (* TODO : move *)
+ Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat ->
+ nth_default 0 (add_to_nth n x l) i =
+ if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i.
+ Proof.
+ intros.
+ unfold add_to_nth.
+ rewrite set_nth_nth_default by assumption.
+ break_if; subst; reflexivity.
+ Qed.
+
+ (* TODO : move *)
+ Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l.
+ Proof.
+ unfold add_to_nth; intros; apply length_set_nth.
+ Qed.
+
+ (* TODO : move *)
+ Lemma singleton_list : forall {A} (l : list A), length l = 1%nat -> exists x, l = x :: nil.
+ Proof.
+ intros; destruct l; simpl in *; try congruence.
+ eexists; f_equal.
+ apply length0_nil; omega.
+ Qed.
+
+ (* BEGIN groundwork proofs *)
+
+ Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i.
+ Proof.
+ intros; apply Z.pow_pos_nonneg; auto using log_cap_nonneg; omega.
+ Qed.
+ Local Hint Resolve pow_2_log_cap_pos.
+
+ Lemma max_bound_log_cap : forall i, Z.succ (max_bound i) = 2 ^ log_cap i.
+ Proof.
+ intros.
+ unfold max_bound, Z.ones.
+ rewrite Z.shiftl_1_l.
+ omega.
+ Qed.
+
+ Hint Resolve log_cap_nonneg.
+ Lemma pow2_mod_log_cap_range : forall a i, 0 <= pow2_mod a (log_cap i) <= max_bound i.
+ Proof.
+ intros.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ unfold max_bound, Z.ones.
+ rewrite Z.shiftl_1_l, <-Z.lt_le_pred.
+ apply Z_mod_lt.
+ pose proof (pow_2_log_cap_pos i).
+ omega.
+ Qed.
+
+ Lemma pow2_mod_log_cap_bounds_lower : forall a i, 0 <= pow2_mod a (log_cap i).
+ Proof.
+ intros.
+ pose proof (pow2_mod_log_cap_range a i); omega.
+ Qed.
+
+ Lemma pow2_mod_log_cap_bounds_upper : forall a i, pow2_mod a (log_cap i) <= max_bound i.
+ Proof.
+ intros.
+ pose proof (pow2_mod_log_cap_range a i); omega.
+ Qed.
+
+ Lemma pow2_mod_log_cap_small : forall a i, 0 <= a <= max_bound i ->
+ pow2_mod a (log_cap i) = a.
+ Proof.
+ intros.
+ unfold pow2_mod.
+ rewrite Z.land_ones by apply log_cap_nonneg.
+ apply Z.mod_small.
+ split; try omega.
+ rewrite <- max_bound_log_cap.
+ omega.
+ Qed.
+
+ Lemma max_bound_nonneg : forall i, 0 <= max_bound i.
+ Proof.
+ unfold max_bound; intros; auto using Z_ones_nonneg.
+ Qed.
+ Local Hint Resolve max_bound_nonneg.
+
+ Lemma pow2_mod_spec : forall a b, (0 <= b) -> pow2_mod a b = a mod (2 ^ b).
+ Proof.
+ intros.
+ unfold pow2_mod.
+ rewrite Z.land_ones; auto.
+ Qed.
+
+ Lemma pow2_mod_upper_bound : forall a b, (0 <= a) -> (0 <= b) -> pow2_mod a b <= a.
+ Proof.
+ intros.
+ unfold pow2_mod.
+ rewrite Z.land_ones; auto.
+ apply Z.mod_le; auto.
+ apply Z.pow_pos_nonneg; omega.
+ Qed.
+
+ Lemma shiftr_eq_0_max_bound : forall i a, Z.shiftr a (log_cap i) = 0 ->
+ a <= max_bound i.
+ Proof.
+ intros ? ? shiftr_0.
+ apply Z.shiftr_eq_0_iff in shiftr_0.
+ intuition; subst; try apply max_bound_nonneg.
+ match goal with H : Z.log2 _ < log_cap _ |- _ => apply Z.log2_lt_pow2 in H;
+ replace (2 ^ log_cap i) with (Z.succ (max_bound i)) in H by
+ (unfold max_bound, Z.ones; rewrite Z.shiftl_1_l; omega)
+ end; auto.
+ omega.
+ Qed.
+
+ Lemma B_compat_log_cap : forall i, 0 <= B - log_cap i.
+ Proof.
+ unfold log_cap; intros.
+ destruct (lt_dec i (length limb_widths)).
+ + apply Z.le_0_sub.
+ apply B_compat.
+ rewrite nth_default_eq.
+ apply nth_In; assumption.
+ + replace (nth_default 0 limb_widths i) with 0; try omega.
+ symmetry; apply nth_default_out_of_bounds.
+ omega.
+ Qed.
+ Local Hint Resolve B_compat_log_cap.
+
+ Lemma max_bound_shiftr_eq_0 : forall i a, 0 <= a -> a <= max_bound i ->
+ Z.shiftr a (log_cap i) = 0.
+ Proof.
+ intros ? ? ? le_max_bound.
+ apply Z.shiftr_eq_0_iff.
+ destruct (Z_eq_dec a 0); auto.
+ right.
+ split; try omega.
+ apply Z.log2_lt_pow2; try omega.
+ rewrite <-max_bound_log_cap.
+ omega.
+ Qed.
+
+ (* END groundwork proofs *)
+ Opaque pow2_mod log_cap max_bound.
+
+ (* automation *)
+ Ltac carry_length_conditions' := unfold carry_full, add_to_nth;
+ rewrite ?length_set_nth, ?carry_length_exact, ?carry_sequence_length_exact, ?carry_sequence_length_exact;
+ try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ].
+ Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'.
+
+ Ltac add_set_nth := rewrite ?add_to_nth_nth_default; try solve [carry_length_conditions];
+ try break_if; try omega; rewrite ?set_nth_nth_default; try solve [carry_length_conditions];
+ try break_if; try omega.
+
+ (* BEGIN defs *)
+
+ Definition c_carry_constraint : Prop :=
+ (c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1)
+ /\ (max_bound 0 + c < 2 ^ (log_cap 0 + 1))
+ /\ (c <= max_bound 0 - c).
+
+ Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i <
+ if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)).
+
+ Lemma pre_carry_bounds_nonzero : forall us, pre_carry_bounds us ->
+ (forall i, 0 <= nth_default 0 us i).
+ Proof.
+ unfold pre_carry_bounds.
+ intros ? PCB i.
+ specialize (PCB i).
+ omega.
+ Qed.
+ Hint Resolve pre_carry_bounds_nonzero.
+
+ Definition carry_done us := forall i, (i < length base)%nat -> Z.shiftr (nth_default 0 us i) (log_cap i) = 0.
+
+ Lemma carry_carry_done_done : forall i us,
+ (length us = length base)%nat ->
+ (i < length base)%nat ->
+ (forall i, 0 <= nth_default 0 us i) ->
+ carry_done us -> carry_done (carry i us).
+ Proof.
+ unfold carry_done; intros until 3. intros Hcarry_done ? ?.
+ unfold carry, carry_simple, carry_and_reduce; break_if; subst.
+ + rewrite Hcarry_done by omega.
+ rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound).
+ destruct i0; add_set_nth; rewrite ?Z.mul_0_r, ?Z.add_0_l; auto.
+ match goal with H : S _ = pred (length base) |- _ => rewrite H; auto end.
+ + rewrite Hcarry_done by omega.
+ rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound).
+ destruct i0; add_set_nth; subst; rewrite ?Z.add_0_l; auto.
+ Qed.
+
+ (* END defs *)
+
+ (* BEGIN proofs about first carry loop *)
+
+ Lemma nth_default_carry_bound_upper : forall i us, (length us = length base) ->
+ nth_default 0 (carry i us) i <= max_bound i.
+ Proof.
+ unfold carry; intros.
+ break_if.
+ + unfold carry_and_reduce.
+ add_set_nth.
+ apply pow2_mod_log_cap_bounds_upper.
+ + unfold carry_simple.
+ destruct (lt_dec i (length us)).
+ - add_set_nth.
+ apply pow2_mod_log_cap_bounds_upper.
+ - rewrite nth_default_out_of_bounds by carry_length_conditions; auto.
+ Qed.
+
+ Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) ->
+ 0 <= nth_default 0 (carry i us) i.
+ Proof.
+ unfold carry; intros.
+ break_if.
+ + unfold carry_and_reduce.
+ add_set_nth.
+ apply pow2_mod_log_cap_bounds_lower.
+ + unfold carry_simple.
+ destruct (lt_dec i (length us)).
+ - add_set_nth.
+ apply pow2_mod_log_cap_bounds_lower.
+ - rewrite nth_default_out_of_bounds by carry_length_conditions; omega.
+ Qed.
+
+ Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) ->
+ (length us = length base) ->
+ 0 <= nth_default 0 (carry i us) (S i).
+ Proof.
+ unfold carry; intros ? ? PCB ? .
+ break_if.
+ + subst. replace (S (pred (length base))) with (length base) by omega.
+ rewrite nth_default_out_of_bounds; carry_length_conditions.
+ unfold carry_and_reduce.
+ add_set_nth.
+ + unfold carry_simple.
+ destruct (lt_dec (S i) (length us)).
+ - add_set_nth.
+ apply Z.add_nonneg_nonneg; [ apply Z.shiftr_nonneg | ]; unfold pre_carry_bounds in PCB.
+ * specialize (PCB i). omega.
+ * specialize (PCB (S i)). omega.
+ - rewrite nth_default_out_of_bounds by carry_length_conditions; omega.
+ Qed.
+
+ Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)->
+ (length us = length base) ->
+ nth_default 0 (carry j us) i = nth_default 0 us i.
+ Proof.
+ intros.
+ unfold carry.
+ break_if.
+ + unfold carry_and_reduce.
+ add_set_nth.
+ + unfold carry_simple.
+ destruct (lt_dec i (length us)).
+ - add_set_nth.
+ - rewrite !nth_default_out_of_bounds by
+ (omega || rewrite length_add_to_nth; rewrite length_set_nth; pose proof base_length_nonzero; omega).
+ reflexivity.
+ Qed.
+
+ Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) ->
+ nth_default 0 (carry j us) i = nth_default 0 us i.
+ Proof.
+ intros.
+ destruct (lt_dec i (length us));
+ [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity].
+ unfold carry, carry_simple.
+ break_if; add_set_nth.
+ Qed.
+
+ Lemma carry_nothing : forall i j us, (i < length base)%nat ->
+ (length us = length base)%nat ->
+ 0 <= nth_default 0 us j <= max_bound j ->
+ nth_default 0 (carry j us) i = nth_default 0 us i.
+ Proof.
+ unfold carry, carry_simple, carry_and_reduce; intros.
+ break_if; (add_set_nth;
+ [ rewrite max_bound_shiftr_eq_0 by omega; ring
+ | subst; apply pow2_mod_log_cap_small; assumption ]).
+ Qed.
+
+ Lemma carry_bounds_0_upper : forall us j, (length us = length base) ->
+ (0 < j < length base)%nat ->
+ nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_bound 0.
+ Proof.
+ unfold carry_sequence; induction j; [simpl; intros; omega | ].
+ intros.
+ simpl in *.
+ destruct (eq_nat_dec 0 j).
+ + subst.
+ apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain 0) us); carry_length_conditions.
+ + rewrite carry_unaffected_low; try omega.
+ fold (carry_sequence (make_chain j) us); carry_length_conditions.
+ Qed.
+
+ Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) ->
+ nth_default 0 (carry_sequence (make_chain j) us) i <= max_bound i.
+ Proof.
+ unfold carry_sequence;
+ induction j; [simpl; intros; omega | ].
+ intros.
+ simpl in *.
+ assert (i = j \/ i < j)%nat as cases by omega.
+ destruct cases as [eq_j_i | lt_i_j]; subst.
+ + apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain j) us); carry_length_conditions.
+ + rewrite carry_unaffected_low; try omega.
+ fold (carry_sequence (make_chain j) us); carry_length_conditions.
+ Qed.
+
+ Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat ->
+ nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i.
+ Proof.
+ induction j; [simpl; intros; omega | ].
+ intros.
+ simpl in *.
+ rewrite carry_unaffected_high by carry_length_conditions.
+ apply IHj; omega.
+ Qed.
+
+ Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) ->
+ (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
+ Proof.
+ induction j; intros.
+ + simpl. auto.
+ + simpl.
+ destruct (lt_dec (S j) i).
+ - rewrite carry_unaffected_high by carry_length_conditions.
+ apply IHj; auto; omega.
+ - assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega.
+ destruct cases as [? | [? | ?]].
+ * subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions.
+ intros.
+ eapply IHj; auto; omega.
+ * subst. apply nth_default_carry_bound_lower; carry_length_conditions.
+ * destruct (eq_nat_dec j (pred (length base)));
+ [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ].
+ subst.
+ unfold carry, carry_and_reduce; break_if; try omega.
+ add_set_nth; [ | apply IHj; auto; omega ].
+ apply Z.add_nonneg_nonneg; [ | apply IHj; auto; omega ].
+ apply Z.mul_nonneg_nonneg; try omega.
+ apply Z.shiftr_nonneg.
+ apply IHj; auto; omega.
+ Qed.
+
+ Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) ->
+ (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain j) us) i.
+ Proof.
+ unfold carry_sequence;
+ induction j; [simpl; intros; omega | ].
+ intros.
+ simpl in *.
+ destruct (eq_nat_dec i (S j)).
+ + subst. apply nth_default_carry_bound_succ_lower; auto;
+ fold (carry_sequence (make_chain j) us); carry_length_conditions.
+ intros.
+ apply carry_sequence_bounds_lower; auto; omega.
+ + assert (i = j \/ i < j)%nat as cases by omega.
+ destruct cases as [eq_j_i | lt_i_j]; subst;
+ [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega;
+ fold (carry_sequence (make_chain j) us); carry_length_conditions.
+ apply carry_sequence_bounds_lower; auto; omega.
+ Qed.
+
+ Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) ->
+ (length us = length base)%nat ->
+ 0 <= nth_default 0 (carry_full us) i <= max_bound i.
+ Proof.
+ unfold carry_full, full_carry_chain; intros.
+ split; (destruct (lt_dec i (length limb_widths));
+ [ | rewrite nth_default_out_of_bounds by carry_length_conditions; omega || auto ]).
+ + apply carry_bounds_lower; carry_length_conditions.
+ + apply carry_bounds_upper; carry_length_conditions.
+ Qed.
+
+ Lemma carry_simple_no_overflow : forall us i, (i < pred (length base))%nat ->
+ length us = length base ->
+ 0 <= nth_default 0 us i < 2 ^ B ->
+ 0 <= nth_default 0 us (S i) < 2 ^ B - 2 ^ (B - log_cap i) ->
+ 0 <= nth_default 0 (carry i us) (S i) < 2 ^ B.
+ Proof.
+ intros.
+ unfold carry, carry_simple; break_if; try omega.
+ add_set_nth.
+ replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega.
+ split.
+ + apply Z.add_nonneg_nonneg; try omega.
+ apply Z.shiftr_nonneg; try omega.
+ + apply Z.add_lt_mono; try omega.
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos.
+ rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap).
+ replace (log_cap i + (B - log_cap i)) with B by ring.
+ omega.
+ Qed.
+
+
+ Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us ->
+ (length us = length base) ->
+ nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B.
+ Proof.
+ unfold pre_carry_bounds.
+ intros ? ? PCB ?.
+ induction i.
+ + simpl. specialize (PCB 0%nat).
+ intuition.
+ + simpl.
+ destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ].
+ - apply carry_simple_no_overflow; carry_length_conditions.
+ apply carry_sequence_bounds_lower; carry_length_conditions.
+ apply carry_sequence_bounds_lower; carry_length_conditions.
+ rewrite carry_sequence_unaffected; try omega.
+ specialize (PCB (S i)); rewrite Nat.pred_succ in PCB.
+ break_if; intuition.
+ - unfold carry; break_if; try omega.
+ rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
+ subst.
+ unfold carry_and_reduce.
+ carry_length_conditions.
+ - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ].
+ carry_length_conditions.
+ Qed.
+
+ Lemma carry_full_bounds_0 : forall us, pre_carry_bounds us ->
+ (length us = length base)%nat ->
+ 0 <= nth_default 0 (carry_full us) 0 <= max_bound 0 + c * (Z.ones (B - log_cap (pred (length base)))).
+ Proof.
+ unfold carry_full, full_carry_chain; intros.
+ rewrite <- base_length.
+ replace (length base) with (S (pred (length base))) at 1 2 by omega.
+ simpl.
+ unfold carry, carry_and_reduce; break_if; try omega.
+ add_set_nth.
+ split.
+ + apply Z.add_nonneg_nonneg.
+ - apply Z.mul_nonneg_nonneg; try omega.
+ apply Z.shiftr_nonneg.
+ apply carry_sequence_bounds_lower; auto; omega.
+ - apply carry_sequence_bounds_lower; auto; omega.
+ + rewrite Z.add_comm.
+ apply Z.add_le_mono.
+ - apply carry_bounds_0_upper; auto; omega.
+ - apply Z.mul_le_mono_pos_l; auto.
+ apply Z_shiftr_ones; auto;
+ [ | pose proof (B_compat_log_cap (pred (length base))); omega ].
+ split.
+ * apply carry_bounds_lower; auto; try omega.
+ * apply carry_sequence_no_overflow; auto.
+ Qed.
+
+ Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us ->
+ (length us = length base)%nat ->
+ 0 <= nth_default 0 (carry_full us) i.
+ Proof.
+ destruct i; intros.
+ + apply carry_full_bounds_0; auto.
+ + destruct (lt_dec (S i) (length base)).
+ - apply carry_bounds_lower; carry_length_conditions.
+ - rewrite nth_default_out_of_bounds; carry_length_conditions.
+ Qed.
+
+ (* END proofs about first carry loop *)
+
+ (* BEGIN proofs about second carry loop *)
+
+ Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i.
+ Proof.
+ induction i; intros; try omega.
+ simpl.
+ unfold carry, carry_simple; break_if; try omega.
+ add_set_nth.
+ split.
+ + apply Z.add_nonneg_nonneg.
+ - apply Z.shiftr_nonneg.
+ destruct (eq_nat_dec i 0); subst.
+ * simpl.
+ apply carry_full_bounds_0; auto.
+ * apply IHi; auto; omega.
+ - rewrite carry_sequence_unaffected by carry_length_conditions.
+ apply carry_full_bounds; auto; omega.
+ + rewrite <-max_bound_log_cap, <-Z.add_1_l.
+ apply Z.add_le_mono.
+ - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ apply Z_div_floor; auto.
+ destruct i.
+ * simpl.
+ eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ].
+ replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring.
+ rewrite <-max_bound_log_cap, <-Z.add_1_l.
+ apply Z.add_lt_le_mono; try omega.
+ unfold c_carry_constraint in *.
+ intuition.
+ * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ].
+ apply Z.lt_mul_diag_r; auto; omega.
+ - rewrite carry_sequence_unaffected by carry_length_conditions.
+ apply carry_full_bounds; auto; omega.
+ Qed.
+
+ Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (1 < length base)%nat ->
+ 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_bound 0 + c.
+ Proof.
+ intros.
+ unfold carry_full at 1 3, full_carry_chain.
+ rewrite <-base_length.
+ replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega).
+ simpl.
+ unfold carry, carry_and_reduce; break_if; try omega.
+ add_set_nth.
+ split.
+ + apply Z.add_nonneg_nonneg.
+ apply Z.mul_nonneg_nonneg; try omega.
+ apply Z.shiftr_nonneg.
+ apply carry_sequence_carry_full_bounds_same; auto; omega.
+ eapply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ intros.
+ eapply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ + rewrite Z.add_comm.
+ apply Z.add_le_mono.
+ - apply carry_bounds_0_upper; carry_length_conditions.
+ - replace c with (c * 1) at 2 by ring.
+ apply Z.mul_le_mono_pos_l; try omega.
+ rewrite Z.shiftr_div_pow2 by auto.
+ apply Z.div_le_upper_bound; auto.
+ ring_simplify.
+ apply carry_sequence_carry_full_bounds_same; auto.
+ omega.
+ Qed.
+
+ Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < pred (length base))%nat ->
+ ((0 < i < length base)%nat ->
+ 0 <= nth_default 0
+ (carry_sequence (make_chain i) (carry_full (carry_full us))) i <=
+ 2 ^ log_cap i) ->
+ 0 <= nth_default 0 (carry_simple i
+ (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i).
+ Proof.
+ unfold carry_simple; intros ? ? PCB CCC length_eq ? IH.
+ add_set_nth.
+ split.
+ + apply Z.add_nonneg_nonneg.
+ apply Z.shiftr_nonneg.
+ destruct i;
+ [ simpl; pose proof (carry_full_2_bounds_0 us PCB CCC length_eq); omega | ].
+ - assert (0 < S i < length base)%nat as IHpre by omega.
+ specialize (IH IHpre).
+ omega.
+ - rewrite carry_sequence_unaffected by carry_length_conditions.
+ apply carry_full_bounds; carry_length_conditions.
+ intros.
+ apply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ + rewrite <-max_bound_log_cap, <-Z.add_1_l.
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ apply Z.add_le_mono.
+ - apply Z.div_le_upper_bound; auto.
+ ring_simplify. apply IH. omega.
+ - rewrite carry_sequence_unaffected by carry_length_conditions.
+ apply carry_full_bounds; carry_length_conditions.
+ intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ Qed.
+
+ Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i.
+ Proof.
+ intros; induction i; try omega.
+ simpl; unfold carry.
+ break_if; try omega.
+ split; (destruct (eq_nat_dec i 0); subst;
+ [ cbv [make_chain carry_sequence fold_right carry_simple]; add_set_nth
+ | eapply carry_full_2_bounds_succ; eauto; omega]).
+ + apply Z.add_nonneg_nonneg.
+ apply Z.shiftr_nonneg.
+ eapply carry_full_2_bounds_0; eauto.
+ eapply carry_full_bounds; eauto; carry_length_conditions.
+ intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ + rewrite <-max_bound_log_cap, <-Z.add_1_l.
+ rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
+ apply Z.add_le_mono.
+ - apply Z_div_floor; auto.
+ eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ].
+ replace (Z.succ 1) with (2 ^ 1) by ring.
+ rewrite <-Z.pow_add_r by (omega || auto).
+ unfold c_carry_constraint in *.
+ intuition.
+ - apply carry_full_bounds; carry_length_conditions.
+ intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions.
+ Qed.
+
+ Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_bound i.
+ Proof.
+ induction j; intros; try omega.
+ split; (destruct j; [ rewrite Nat.add_1_r; simpl
+ | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]).
+ + apply nth_default_carry_bound_lower; carry_length_conditions.
+ + apply nth_default_carry_bound_upper; carry_length_conditions.
+ Qed.
+
+ Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat ->
+ 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_bound i.
+ Proof.
+ intros.
+ replace j with (i + (j - i))%nat by omega.
+ eapply carry_full_2_bounds'; eauto; omega.
+ Qed.
+
+ Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat ->
+ (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0).
+ Proof.
+ induction i; try omega.
+ intros ? ? length_eq ?; simpl.
+ destruct i.
+ + unfold carry.
+ break_if;
+ [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
+ simpl.
+ unfold carry_simple.
+ add_set_nth.
+ apply pow2_mod_log_cap_bounds_lower.
+ + rewrite carry_unaffected_low by carry_length_conditions.
+ assert (0 < S i < length base)%nat by omega.
+ intuition.
+ Qed.
+
+ Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat ->
+ 0 <= nth_default 0 (carry_full (carry_full us)) i.
+ Proof.
+ intros.
+ destruct i.
+ + apply carry_full_2_bounds_0; auto.
+ + apply carry_full_bounds; try solve [carry_length_conditions].
+ intro j.
+ destruct j.
+ - apply carry_full_bounds_0; auto.
+ - apply carry_full_bounds; carry_length_conditions.
+ Qed.
+
+ Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat -> (0 < i < length base)%nat ->
+ (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_bound 0 - c)
+ \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))).
+ Proof.
+ induction i; try omega.
+ intros ? ? length_eq ?; simpl.
+ destruct i.
+ + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_bound 0)).
+ - right.
+ unfold carry_done.
+ intros.
+ apply max_bound_shiftr_eq_0; simpl; rewrite carry_nothing; try solve [carry_length_conditions].
+ * apply carry_full_2_bounds_lower; auto.
+ * split; try apply carry_full_2_bounds_lower; auto.
+ * destruct i; auto.
+ apply carry_full_bounds; try solve [carry_length_conditions].
+ auto using carry_full_bounds_lower.
+ * split; auto.
+ apply carry_full_2_bounds_lower; auto.
+ - unfold carry.
+ break_if;
+ [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ].
+ simpl.
+ unfold carry_simple.
+ add_set_nth. left.
+ remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x.
+ apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)).
+ * replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring.
+ rewrite pow2_mod_spec by auto.
+ rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega).
+ rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small.
+ apply Z.sub_le_mono_r.
+ subst; apply carry_full_2_bounds_0; auto.
+ split; try omega.
+ pose proof carry_full_2_bounds_0.
+ apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0));
+ [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto;
+ ring_simplify; unfold c_carry_constraint in *; omega | ].
+ ring_simplify; unfold c_carry_constraint in *; omega.
+ * ring_simplify; unfold c_carry_constraint in *; omega.
+ + rewrite carry_unaffected_low by carry_length_conditions.
+ assert (0 < S i < length base)%nat by omega.
+ intuition.
+ right.
+ apply carry_carry_done_done; try solve [carry_length_conditions].
+ intro j.
+ destruct j.
+ - apply carry_carry_full_2_bounds_0_lower; auto.
+ - destruct (lt_eq_lt_dec j i) as [[? | ?] | ?].
+ * apply carry_full_2_bounds; auto; omega.
+ * subst. apply carry_full_2_bounds_same; auto; omega.
+ * rewrite carry_sequence_unaffected; try solve [carry_length_conditions].
+ apply carry_full_2_bounds_lower; auto; omega.
+ Qed.
+
+ (* END proofs about second carry loop *)
+
+ (* BEGIN proofs about third carry loop *)
+
+ Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> c_carry_constraint ->
+ (length us = length base)%nat ->(i < length base)%nat ->
+ 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_bound i.
+ Proof.
+ intros.
+ destruct i; [ | apply carry_full_bounds; carry_length_conditions;
+ do 2 (intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions) ].
+ unfold carry_full at 1 4, full_carry_chain.
+ case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ].
+ simpl.
+ intros ? ? limb_widths_eq.
+ replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto).
+ rewrite <- base_length.
+ unfold carry, carry_and_reduce; break_if; try omega; intros.
+ add_set_nth.
+ split.
+ + apply Z.add_nonneg_nonneg.
+ - apply Z.mul_nonneg_nonneg; auto; try omega.
+ apply Z.shiftr_nonneg.
+ eapply carry_full_2_bounds_same; eauto; omega.
+ - eapply carry_carry_full_2_bounds_0_lower; eauto; omega.
+ + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))).
+ assert (0 < pred (length base) < length base)%nat by omega.
+ intuition.
+ - replace (max_bound 0) with (c + (max_bound 0 - c)) by ring.
+ apply Z.add_le_mono; try assumption.
+ replace c with (c * 1) at 2 by ring.
+ apply Z.mul_le_mono_pos_l; try omega.
+ rewrite Z.shiftr_div_pow2 by auto.
+ apply Z.div_le_upper_bound; auto.
+ ring_simplify.
+ apply carry_full_2_bounds_same; auto.
+ - match goal with H : carry_done _ |- _ => unfold carry_done in H; rewrite H by omega end.
+ ring_simplify.
+ apply shiftr_eq_0_max_bound; auto; omega.
+ Qed.
+
+ Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x ->
+ nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x').
+ Admitted.
+
+ Lemma nth_error_range : forall {A} i (l : list A), (i < length l)%nat ->
+ nth_error (range (length l)) i = Some i.
+ Admitted.
+
+ (* END proofs about third carry loop *)
+ Opaque carry_full.
+
+ Lemma freeze_in_bounds : forall us i, (us <> nil)%nat ->
+ 0 <= nth_default 0 (freeze us) i < 2 ^ log_cap i.
+ Proof.
+ Admitted.
+
+ Lemma freeze_canonical : forall us vs x, rep us x -> rep vs x ->
+ freeze us = freeze vs.
+ Admitted.
+
+End CanonicalizationProofs. \ No newline at end of file
diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v
index 77d84c455..70a2c4a87 100644
--- a/src/ModularArithmetic/PrimeFieldTheorems.v
+++ b/src/ModularArithmetic/PrimeFieldTheorems.v
@@ -9,6 +9,7 @@ Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid.
Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.ZArith.ZArith Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *)
Require Import Coq.Logic.Eqdep_dec.
Require Import Crypto.Util.NumTheoryUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.
Existing Class prime.
@@ -67,7 +68,7 @@ Module FieldModulo (Import M : PrimeModulus).
postprocess [Fpostprocess],
constants [Fconstant],
div morph_div_theory_modulo,
- power_tac power_theory_modulo [Fexp_tac]).
+ power_tac power_theory_modulo [Fexp_tac]).
End FieldModulo.
Section VariousModPrime.
@@ -79,8 +80,8 @@ Section VariousModPrime.
postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
constants [Fconstant],
div (@Fmorph_div_theory q),
- power_tac (@Fpower_theory q) [Fexp_tac]).
-
+ power_tac (@Fpower_theory q) [Fexp_tac]).
+
Lemma Fq_mul_eq : forall x y z : F q, z <> 0 -> x * z = y * z -> x = y.
Proof.
intros ? ? ? z_nonzero mul_z_eq.
@@ -119,47 +120,26 @@ Section VariousModPrime.
eauto using Fq_inv_unique'.
Qed.
- Let inv_fermat_powmod (x:Z) : Z := powmod q x (Z.to_N (q-2)).
- Lemma FieldToZ_inv_efficient : 2 < q ->
- forall x : F q, FieldToZ (inv x) = inv_fermat_powmod x.
- Proof.
- intros.
- rewrite (fun pf => Fq_inv_unique (fun x : F q => ZToField (inv_fermat_powmod (FieldToZ x))) pf);
- subst inv_fermat_powmod; intuition; rewrite powmod_Zpow_mod;
- replace (Z.of_N (Z.to_N (q - 2))) with (q-2)%Z by (rewrite Z2N.id ; omega).
- - (* inv in range *) rewrite FieldToZ_ZToField, Zmod_mod; reflexivity.
- - (* inv 0 *) replace (FieldToZ 0) with 0%Z by auto.
- rewrite Z.pow_0_l by omega.
- rewrite Zmod_0_l; trivial.
- - (* inv nonzero *) rewrite <- (fermat_inv q _ x0) by
- (rewrite mod_FieldToZ; eauto using FieldToZ_nonzero).
- rewrite <-(ZToField_FieldToZ x0).
- rewrite <-ZToField_mul.
- rewrite ZToField_FieldToZ.
- apply ZToField_eqmod.
- demod; reflexivity.
- Qed.
-
Lemma Fq_mul_zero_why : forall a b : F q, a*b = 0 -> a = 0 \/ b = 0.
intros.
assert (Z := F_eq_dec a 0); destruct Z.
-
+
- left; intuition.
-
+
- assert (a * b / a = 0) by
( rewrite H; field; intuition ).
-
+
replace (a*b/a) with (b) in H0 by (field; trivial).
right; intuition.
Qed.
-
+
Lemma Fq_mul_nonzero_nonzero : forall a b : F q, a <> 0 -> b <> 0 -> a*b <> 0.
intros; intuition; subst.
apply Fq_mul_zero_why in H1.
destruct H1; subst; intuition.
Qed.
Hint Resolve Fq_mul_nonzero_nonzero.
-
+
Lemma Fq_pow_zero : forall (p: N), p <> 0%N -> (0^p = @ZToField q 0)%F.
induction p using N.peano_ind;
rewrite <-?N.add_1_l, ?(proj2 (@F_pow_spec q _) _), ?(proj1 (@F_pow_spec q _)).
@@ -193,6 +173,43 @@ Section VariousModPrime.
+ apply IHp; auto.
Qed.
+ Lemma F_inv_0 : inv 0 = (0 : F q).
+ Proof.
+ destruct (@F_inv_spec q); auto.
+ Qed.
+
+ Definition inv_fermat (x:F q) : F q := x ^ Z.to_N (q - 2)%Z.
+ Lemma Fq_inv_fermat: 2 < q -> forall x : F q, inv x = x ^ Z.to_N (q - 2)%Z.
+ Proof.
+ intros.
+ erewrite (Fq_inv_unique inv_fermat); trivial; split; intros; unfold inv_fermat.
+ { replace 2%N with (Z.to_N (Z.of_N 2))%N by auto.
+ rewrite Fq_pow_zero. reflexivity. intro.
+ assert (Z.of_N (Z.to_N (q-2)) = 0%Z) by (rewrite H0; eauto); rewrite Z2N.id in *; omega. }
+ { clear x. rename x0 into x. pose proof (fermat_inv _ _ x) as Hf; forward Hf.
+ { pose proof @ZToField_FieldToZ; pose proof @ZToField_mod; congruence. }
+ specialize (Hf H1); clear H1.
+ rewrite <-(ZToField_FieldToZ x).
+ rewrite ZToField_pow.
+ replace (Z.of_N (Z.to_N (q - 2))) with (q-2)%Z by (rewrite Z2N.id ; omega).
+ rewrite <-ZToField_mul.
+ apply ZToField_eqmod.
+ rewrite Hf, Zmod_small by omega; reflexivity.
+ }
+ Qed.
+ Lemma Fq_inv_fermat_correct : 2 < q -> forall x : F q, inv_fermat x = inv x.
+ Proof.
+ unfold inv_fermat; intros. rewrite Fq_inv_fermat; auto.
+ Qed.
+
+ Let inv_fermat_powmod (x:Z) : Z := powmod q x (Z.to_N (q-2)).
+ Lemma FieldToZ_inv_efficient : 2 < q ->
+ forall x : F q, FieldToZ (inv x) = inv_fermat_powmod x.
+ Proof.
+ unfold inv_fermat_powmod; intros.
+ rewrite Fq_inv_fermat, powmod_Zpow_mod, <-FieldToZ_pow_Zpow_mod; auto.
+ Qed.
+
Lemma F_minus_swap : forall x y : F q, x - y = 0 -> y - x = 0.
Proof.
intros ? ? eq_zero.
@@ -232,7 +249,7 @@ Section VariousModPrime.
left.
eapply Fq_square_mul; eauto.
instantiate (1 := x).
- assert (x ^ 2 = z * y ^ 2 - x ^ 2 + x ^ 2) as plus_minus_x2 by
+ assert (x ^ 2 = z * y ^ 2 - x ^ 2 + x ^ 2) as plus_minus_x2 by
(rewrite <- eq_zero_sub; ring).
rewrite plus_minus_x2; ring.
}
@@ -248,6 +265,11 @@ Section VariousModPrime.
intros; field. (* TODO: Warning: Collision between bound variables ... *)
Qed.
+ Lemma F_div_opp_1 : forall x y : F q, (opp x / y = opp (x / y))%F.
+ Proof.
+ intros; destruct (F_eq_dec y 0); [subst;unfold div;rewrite !F_inv_0|]; field.
+ Qed.
+
Lemma F_eq_opp_zero : forall x : F q, 2 < q -> (x = opp x <-> x = 0).
Proof.
split; intro A.
@@ -352,6 +374,22 @@ Section VariousModPrime.
Proof.
econstructor; eauto using Fq_mul_zero_why, Fq_1_neq_0.
Qed.
+
+ Lemma add_cancel_mul_r_nonzero {y : F q} (H : y <> 0) (x z : F q)
+ : x * y + z = (x + (z * (inv y))) * y.
+ Proof. field. Qed.
+
+ Lemma sub_cancel_mul_r_nonzero {y : F q} (H : y <> 0) (x z : F q)
+ : x * y - z = (x - (z * (inv y))) * y.
+ Proof. field. Qed.
+
+ Lemma add_cancel_l_nonzero {y : F q} (H : y <> 0) (z : F q)
+ : y + z = (1 + (z * (inv y))) * y.
+ Proof. field. Qed.
+
+ Lemma sub_cancel_l_nonzero {y : F q} (H : y <> 0) (z : F q)
+ : y - z = (1 - (z * (inv y))) * y.
+ Proof. field. Qed.
End VariousModPrime.
Section SquareRootsPrime5Mod8.
@@ -367,7 +405,7 @@ Section SquareRootsPrime5Mod8.
postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
constants [Fconstant],
div (@Fmorph_div_theory q),
- power_tac (@Fpower_theory q) [Fexp_tac]).
+ power_tac (@Fpower_theory q) [Fexp_tac]).
(* This is only the square root of -1 if q mod 8 is 3 or 5 *)
Definition sqrt_minus1 : F q := ZToField 2 ^ Z.to_N (q / 4).
@@ -382,7 +420,7 @@ Section SquareRootsPrime5Mod8.
Qed.
(* square root mod q relies on the fact that q is 5 mod 8 *)
- Definition sqrt_mod_q (a : F q) :=
+ Definition sqrt_mod_q (a : F q) :=
let b := a ^ Z.to_N (q / 8 + 1) in
(match (F_eq_dec (b ^ 2) a) with
| left A => b
@@ -445,4 +483,85 @@ Section SquareRootsPrime5Mod8.
field.
Qed.
-End SquareRootsPrime5Mod8. \ No newline at end of file
+ Lemma sqrt_mod_q_of_0 : sqrt_mod_q 0 = 0.
+ Proof.
+ unfold sqrt_mod_q.
+ rewrite !Fq_pow_zero.
+ break_if; ring.
+
+ congruence.
+ intro false_eq.
+ rewrite <-(N2Z.id 0) in false_eq.
+ rewrite N2Z.inj_0 in false_eq.
+ pose proof (prime_ge_2 q prime_q).
+ apply Z2N.inj in false_eq; zero_bounds.
+ assert (0 < q / 8 + 1)%Z.
+ apply Z.add_nonneg_pos; zero_bounds.
+ omega.
+ Qed.
+
+ Lemma sqrt_mod_q_root_0 : forall x : F q, sqrt_mod_q x = 0 -> x = 0.
+ Proof.
+ unfold sqrt_mod_q; intros.
+ break_if.
+ - match goal with [ H : ?sqrt_x ^ 2 = x, H' : ?sqrt_x = 0 |- _ ] => rewrite <-H, H' end.
+ ring.
+ - match goal with
+ | [H : sqrt_minus1 * _ = 0 |- _ ]=>
+ apply Fq_mul_zero_why in H; destruct H as [sqrt_minus1_zero | ? ];
+ [ | eapply Fq_root_zero; eauto ]
+ end.
+ unfold sqrt_minus1 in sqrt_minus1_zero.
+ rewrite sqrt_minus1_zero in sqrt_minus1_valid.
+ exfalso.
+ pose proof (@F_opp_spec q 1) as opp_spec_1.
+ rewrite <-sqrt_minus1_valid in opp_spec_1.
+ assert (((1 + 0 ^ 2) : F q) = (1 : F q)) as ring_subst by ring.
+ rewrite ring_subst in *.
+ apply Fq_1_neq_0; assumption.
+ Qed.
+
+End SquareRootsPrime5Mod8.
+
+Local Open Scope F_scope.
+(** Tactics for solving inequalities. *)
+Ltac solve_cancel_by_field y tnz :=
+ solve [ generalize dependent y; intros;
+ field; tnz ].
+
+Ltac cancel_nonzero_factors' tnz :=
+ idtac;
+ match goal with
+ | [ |- ?x = 0 -> False ]
+ => change (x <> 0)
+ | [ |- ?x * ?y <> 0 ]
+ => apply Fq_mul_nonzero_nonzero
+ | [ H : ?y <> 0 |- _ ]
+ => progress rewrite ?(add_cancel_mul_r_nonzero H), ?(sub_cancel_mul_r_nonzero H), ?(add_cancel_l_nonzero H), ?(sub_cancel_l_nonzero H);
+ apply Fq_mul_nonzero_nonzero; [ | assumption ]
+ | [ |- ?op (?y * (ZToField (m := ?q) ?n)) ?z <> 0 ]
+ => unique assert (ZToField (m := q) n <> 0) by tnz;
+ generalize dependent (ZToField (m := q) n); intros
+ | [ |- ?op (?x * (?y * ?z)) _ <> 0 ]
+ => rewrite F_mul_assoc
+ end.
+Ltac cancel_nonzero_factors tnz := repeat cancel_nonzero_factors' tnz.
+Ltac specialize_quantified_equalities :=
+ repeat match goal with
+ | [ H : forall _ _ _ _, _ = _ -> _, H' : _ = _ |- _ ]
+ => unique pose proof (fun x2 y2 => H _ _ x2 y2 H')
+ | [ H : forall _ _, _ = _ -> _, H' : _ = _ |- _ ]
+ => unique pose proof (H _ _ H')
+ end.
+Ltac finish_inequality tnz :=
+ idtac;
+ match goal with
+ | [ H : ?x <> 0 |- ?y <> 0 ]
+ => replace y with x by (field; tnz); exact H
+ end.
+Ltac field_nonzero tnz :=
+ cancel_nonzero_factors tnz;
+ try (specialize_quantified_equalities;
+ progress cancel_nonzero_factors tnz);
+ try solve [ specialize_quantified_equalities;
+ finish_inequality tnz ].
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
new file mode 100644
index 000000000..1a7b3316e
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -0,0 +1,246 @@
+Require Import Zpower ZArith.
+Require Import List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import VerdiTactics.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Section PseudoMersenneBaseParamProofs.
+ Context `{prm : PseudoMersenneBaseParams}.
+
+ Fixpoint base_from_limb_widths limb_widths :=
+ match limb_widths with
+ | nil => nil
+ | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw)
+ end.
+
+ Definition base := base_from_limb_widths limb_widths.
+
+ Lemma base_length : length base = length limb_widths.
+ Proof.
+ unfold base.
+ induction limb_widths; try reflexivity.
+ simpl; rewrite map_length; auto.
+ Qed.
+
+ Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b ->
+ a = b.
+ Proof.
+ intros; simpl in *.
+ unfold value in *.
+ congruence.
+ Qed.
+
+ Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat ->
+ nth_error base i = Some b ->
+ nth_error limb_widths i = Some w ->
+ nth_error base (S i) = Some (two_p w * b).
+ Proof.
+ unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b;
+ unfold base_from_limb_widths in *; fold base_from_limb_widths in *;
+ [rewrite (@nil_length0 Z) in *; omega | ].
+ simpl in *; rewrite map_length in *.
+ case_eq i; intros; subst.
+ + subst; apply nth_error_first in nth_err_w.
+ apply nth_error_first in nth_err_b; subst.
+ apply map_nth_error.
+ case_eq l; intros; subst; [simpl in *; omega | ].
+ unfold base_from_limb_widths; fold base_from_limb_widths.
+ reflexivity.
+ + simpl in nth_err_w.
+ apply nth_error_map in nth_err_w.
+ destruct nth_err_w as [x [A B]].
+ subst.
+ replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring.
+ apply map_nth_error.
+ apply IHl; auto; omega.
+ Qed.
+
+ Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x),
+ exists l', l = x :: l'.
+ Proof.
+ induction l; try discriminate; eexists.
+ apply nth_error_first in H.
+ subst; eauto.
+ Qed.
+
+ Lemma sum_firstn_succ : forall l i x,
+ nth_error l i = Some x ->
+ sum_firstn l (S i) = x + sum_firstn l i.
+ Proof.
+ unfold sum_firstn; induction l;
+ [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ].
+ intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0).
+ + subst; simpl in *; unfold value in *.
+ congruence.
+ + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto.
+ rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto.
+ simpl. simpl in nth_err_x.
+ specialize (IHl (pred i) x).
+ rewrite NPeano.Nat.succ_pred in IHl by auto.
+ destruct (NPeano.Nat.eq_dec (pred i) 0).
+ - replace i with 1%nat in * by omega.
+ simpl. replace (pred 1) with 0%nat in * by auto.
+ apply nth_error_exists_first in nth_err_x.
+ destruct nth_err_x as [l' ?].
+ subst; simpl; ring.
+ - rewrite IHl by auto; ring.
+ Qed.
+
+ Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n.
+ Proof.
+ unfold sum_firstn; intros.
+ apply fold_right_invariant; try omega.
+ intros y In_y_lw ? ?.
+ apply Z.add_nonneg_nonneg; try assumption.
+ apply limb_widths_nonneg.
+ eapply In_firstn; eauto.
+ Qed.
+
+ Lemma k_nonneg : 0 <= k.
+ Proof.
+ apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma nth_error_base : forall i, (i < length base)%nat ->
+ nth_error base i = Some (two_p (sum_firstn limb_widths i)).
+ Proof.
+ induction i; intros.
+ + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity.
+ intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega.
+ +
+ assert (i < length base)%nat as lt_i_length by omega.
+ specialize (IHi lt_i_length).
+ rewrite base_length in lt_i_length.
+ destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w].
+ erewrite base_from_limb_widths_step; eauto.
+ f_equal.
+ simpl.
+ destruct (NPeano.Nat.eq_dec i 0).
+ - subst; unfold sum_firstn; simpl.
+ apply nth_error_exists_first in nth_err_w.
+ destruct nth_err_w as [l' lw_destruct]; subst.
+ rewrite lw_destruct.
+ ring_simplify.
+ f_equal; simpl; ring.
+ - erewrite sum_firstn_succ; eauto.
+ symmetry.
+ apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg.
+ apply limb_widths_nonneg.
+ eapply nth_error_value_In; eauto.
+ Qed.
+
+ Lemma nth_default_base : forall d i, (i < length base)%nat ->
+ nth_default d base i = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x].
+ unfold nth_default.
+ rewrite nth_err_x.
+ rewrite nth_error_base in nth_err_x by assumption.
+ rewrite two_p_correct in nth_err_x.
+ congruence.
+ Qed.
+
+ Lemma base_matches_modulus: forall i j,
+ (i < length base)%nat ->
+ (j < length base)%nat ->
+ (i+j >= length base)%nat->
+ let b := nth_default 0 base in
+ let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
+ b i * b j = r * (2^k * b (i+j-length base)%nat).
+ Proof.
+ intros.
+ rewrite (Z.mul_comm r).
+ subst r.
+ assert (i + j - length base < length base)%nat by omega.
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos;
+ [ | subst b; rewrite nth_default_base; try assumption ];
+ apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg).
+ rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
+ f_equal.
+ subst b.
+ repeat rewrite nth_default_base by assumption.
+ do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg).
+ symmetry.
+ apply mod_same_pow.
+ split.
+ + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg.
+ + rewrite base_length in *; apply limb_widths_match_modulus; assumption.
+ Qed.
+
+ Lemma base_succ : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) mod nth_default 0 base i = 0.
+ Proof.
+ intros.
+ repeat rewrite nth_default_base by omega.
+ apply mod_same_pow.
+ split; [apply sum_firstn_limb_widths_nonneg | ].
+ destruct (NPeano.Nat.eq_dec i 0); subst.
+ + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq.
+ apply Z.add_nonneg_nonneg; try omega.
+ apply limb_widths_nonneg.
+ rewrite lw_eq.
+ apply in_eq.
+ + assert (i < length base)%nat as i_lt_length by omega.
+ rewrite base_length in *.
+ apply nth_error_length_exists_value in i_lt_length.
+ destruct i_lt_length as [x nth_err_x].
+ erewrite sum_firstn_succ; eauto.
+ apply nth_error_value_In in nth_err_x.
+ apply limb_widths_nonneg in nth_err_x.
+ omega.
+ Qed.
+
+ Lemma nth_error_subst : forall i b, nth_error base i = Some b ->
+ b = 2 ^ (sum_firstn limb_widths i).
+ Proof.
+ intros i b nth_err_b.
+ pose proof (nth_error_value_length _ _ _ _ nth_err_b).
+ rewrite nth_error_base in nth_err_b by assumption.
+ rewrite two_p_correct in nth_err_b.
+ congruence.
+ Qed.
+
+ Lemma base_positive : forall b : Z, In b base -> b > 0.
+ Proof.
+ intros b In_b_base.
+ apply In_nth_error_value in In_b_base.
+ destruct In_b_base as [i nth_err_b].
+ apply nth_error_subst in nth_err_b.
+ rewrite nth_err_b.
+ apply gt_lt_symmetry.
+ apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg.
+ Qed.
+
+ Lemma b0_1 : forall x : Z, nth_default x base 0 = 1.
+ Proof.
+ unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity].
+ Qed.
+
+ Lemma base_good : forall i j : nat,
+ (i + j < length base)%nat ->
+ let b := nth_default 0 base in
+ let r := b i * b j / b (i + j)%nat in
+ b i * b j = r * b (i + j)%nat.
+ Proof.
+ intros; subst b r.
+ repeat rewrite nth_default_base by omega.
+ rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
+ rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg).
+ rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg.
+ rewrite mod_same_pow; try ring.
+ split; [ apply sum_firstn_limb_widths_nonneg | ].
+ apply limb_widths_good.
+ rewrite <- base_length; assumption.
+ Qed.
+
+ Global Instance bv : BaseSystem.BaseVector base := {
+ base_positive := base_positive;
+ b0_1 := b0_1;
+ base_good := base_good
+ }.
+
+End PseudoMersenneBaseParamProofs.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v
new file mode 100644
index 000000000..3914d6219
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v
@@ -0,0 +1,24 @@
+Require Import ZArith.
+Require Import List.
+Require Crypto.BaseSystem.
+Local Open Scope Z_scope.
+
+Definition sum_firstn l n := fold_right Z.add 0 (firstn n l).
+
+Class PseudoMersenneBaseParams (modulus : Z) := {
+ limb_widths : list Z;
+ limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w;
+ limb_widths_nonnil : limb_widths <> nil;
+ limb_widths_good : forall i j, (i + j < length limb_widths)%nat ->
+ sum_firstn limb_widths (i + j) <=
+ sum_firstn limb_widths i + sum_firstn limb_widths j;
+ prime_modulus : Znumtheory.prime modulus;
+ k := sum_firstn limb_widths (length limb_widths);
+ c := 2 ^ k - modulus;
+ limb_widths_match_modulus : forall i j,
+ (i < length limb_widths)%nat ->
+ (j < length limb_widths)%nat ->
+ (i + j >= length limb_widths)%nat ->
+ let w_sum := sum_firstn limb_widths in
+ k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j
+}.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v
new file mode 100644
index 000000000..c16cc8d38
--- /dev/null
+++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v
@@ -0,0 +1,50 @@
+Require Import ZArith.
+Require Crypto.BaseSystem.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Local Open Scope Z_scope.
+
+Class RepZMod (modulus : Z) := {
+ T : Type;
+ encode : F modulus -> T;
+ decode : T -> F modulus;
+
+ rep : T -> F modulus -> Prop;
+ encode_rep : forall x, rep (encode x) x;
+ rep_decode : forall u x, rep u x -> decode u = x;
+
+ add : T -> T -> T;
+ add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F;
+
+ sub : T -> T -> T;
+ sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F;
+
+ mul : T -> T -> T;
+ mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F
+}.
+
+Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := {
+ coeff : BaseSystem.digits;
+ coeff_length : (length coeff <= length PseudoMersenneBaseParamProofs.base)%nat;
+ coeff_mod: (BaseSystem.decode PseudoMersenneBaseParamProofs.base coeff) mod m = 0
+}.
+
+Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) (sc : SubtractionCoefficient m prm)
+: RepZMod m := {
+ T := list Z;
+ encode := ModularBaseSystem.encode;
+ decode := ModularBaseSystem.decode;
+
+ rep := ModularBaseSystem.rep;
+ encode_rep := ModularBaseSystemProofs.encode_rep;
+ rep_decode := ModularBaseSystemProofs.rep_decode;
+
+ add := BaseSystem.add;
+ add_rep := ModularBaseSystemProofs.add_rep;
+
+ sub := ModularBaseSystem.sub coeff coeff_mod;
+ sub_rep := ModularBaseSystemProofs.sub_rep coeff coeff_mod coeff_length;
+
+ mul := ModularBaseSystem.mul;
+ mul_rep := ModularBaseSystemProofs.mul_rep
+}.
diff --git a/src/Rep.v b/src/Rep.v
new file mode 100644
index 000000000..b7e7f10c5
--- /dev/null
+++ b/src/Rep.v
@@ -0,0 +1,13 @@
+Class RepConversions (T:Type) (RT:Type) : Type :=
+ {
+ toRep : T -> RT;
+ unRep : RT -> T
+ }.
+
+Definition RepConversionsOK {T RT} (RC:RepConversions T RT) := forall x, unRep (toRep x) = x.
+
+Definition RepFunOK {T RT} `(RC:RepConversions T RT) (f:T->T) (rf : RT -> RT) :=
+ forall x, f (unRep x) = unRep (rf x).
+
+Definition RepBinOpOK {T RT} `(RC:RepConversions T RT) (op:T->T->T) (rop : RT -> RT -> RT) :=
+ forall x y, op (unRep x) (unRep y) = unRep (rop x y).
diff --git a/src/Spec/CompleteEdwardsCurve.v b/src/Spec/CompleteEdwardsCurve.v
index b7d2c0d8e..3348be1d9 100644
--- a/src/Spec/CompleteEdwardsCurve.v
+++ b/src/Spec/CompleteEdwardsCurve.v
@@ -16,41 +16,39 @@ Class TwistedEdwardsParams := {
nonsquare_d : forall x, x^2 <> d
}.
-Section TwistedEdwardsCurves.
- Context {prm:TwistedEdwardsParams}.
-
- (* Twisted Edwards curves with complete addition laws. References:
- * <https://eprint.iacr.org/2008/013.pdf>
- * <http://ed25519.cr.yp.to/ed25519-20110926.pdf>
- * <https://eprint.iacr.org/2015/677.pdf>
- *)
- Definition onCurve P := let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2.
- Definition point := { P | onCurve P}.
- Definition mkPoint (xy:F q * F q) (pf:onCurve xy) : point := exist onCurve xy pf.
-
- Definition zero : point := mkPoint (0, 1) (@Pre.zeroOnCurve _ _ _ prime_q).
+Module E.
+ Section TwistedEdwardsCurves.
+ Context {prm:TwistedEdwardsParams}.
+
+ (* Twisted Edwards curves with complete addition laws. References:
+ * <https://eprint.iacr.org/2008/013.pdf>
+ * <http://ed25519.cr.yp.to/ed25519-20110926.pdf>
+ * <https://eprint.iacr.org/2015/677.pdf>
+ *)
+ Definition onCurve P := let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2.
+ Definition point := { P | onCurve P}.
+
+ Definition zero : point := exist _ (0, 1) (@Pre.zeroOnCurve _ _ _ prime_q).
+
+ Definition add' P1' P2' :=
+ let '(x1, y1) := P1' in
+ let '(x2, y2) := P2' in
+ (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))).
+
+ Definition add (P1 P2 : point) : point :=
+ let 'exist P1' pf1 := P1 in
+ let 'exist P2' pf2 := P2 in
+ exist _ (add' P1' P2')
+ (@Pre.unifiedAdd'_onCurve _ _ _ prime_q two_lt_q nonzero_a square_a nonsquare_d _ _ pf1 pf2).
+
+ Fixpoint mul (n:nat) (P : point) : point :=
+ match n with
+ | O => zero
+ | S n' => add P (mul n' P)
+ end.
+ End TwistedEdwardsCurves.
+End E.
- Definition unifiedAdd' P1' P2' :=
- let '(x1, y1) := P1' in
- let '(x2, y2) := P2' in
- (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))).
-
- Definition unifiedAdd (P1 P2 : point) : point :=
- let 'exist P1' pf1 := P1 in
- let 'exist P2' pf2 := P2 in
- mkPoint (unifiedAdd' P1' P2')
- (@Pre.unifiedAdd'_onCurve _ _ _ prime_q two_lt_q nonzero_a square_a nonsquare_d _ _ pf1 pf2).
-
- Fixpoint scalarMult (n:nat) (P : point) : point :=
- match n with
- | O => zero
- | S n' => unifiedAdd P (scalarMult n' P)
- end.
-
- Axiom point_eq_dec : forall P Q : point, {P = Q} + {P <> Q}.
-End TwistedEdwardsCurves.
-
Delimit Scope E_scope with E.
-Infix "+" := unifiedAdd : E_scope.
-Infix "*" := scalarMult : E_scope.
-Infix "==" := point_eq_dec (no associativity, at level 70) : E_scope.
+Infix "+" := E.add : E_scope.
+Infix "*" := E.mul : E_scope. \ No newline at end of file
diff --git a/src/Spec/Ed25519.v b/src/Spec/Ed25519.v
index 6ab47e8e5..4876bb8d1 100644
--- a/src/Spec/Ed25519.v
+++ b/src/Spec/Ed25519.v
@@ -1,6 +1,7 @@
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.Numbers.Natural.Peano.NPeano Coq.NArith.NArith.
-Require Import Crypto.Spec.Encoding Crypto.Spec.PointEncoding.
+Require Import Crypto.Spec.PointEncoding Crypto.Spec.ModularWordEncoding.
+Require Import Crypto.Encoding.ModularWordEncodingTheorems.
Require Import Crypto.Spec.EdDSA.
Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems.
@@ -114,8 +115,10 @@ Proof.
compute; omega.
Qed.
+Require Import Crypto.Spec.Encoding.
+
Lemma q_pos : (0 < q)%Z. q_bound. Qed.
-Definition FqEncoding : encoding of (F q) as word (b-1) :=
+Definition FqEncoding : canonical encoding of (F q) as word (b-1) :=
@modular_word_encoding q (b - 1) q_pos b_valid.
Lemma l_pos : (0 < Z.of_nat l)%Z. pose proof prime_l; prime_bound. Qed.
@@ -127,7 +130,7 @@ Proof.
unfold l.
apply Z2Nat.inj_lt; compute; congruence.
Qed.
-Definition FlEncoding : encoding of F (Z.of_nat l) as word b :=
+Definition FlEncoding : canonical encoding of F (Z.of_nat l) as word b :=
@modular_word_encoding (Z.of_nat l) b l_pos l_bound.
Lemma q_5mod8 : (q mod 8 = 5)%Z. cbv; reflexivity. Qed.
@@ -140,12 +143,14 @@ Proof.
reflexivity.
Qed.
-Definition PointEncoding := @point_encoding curve25519params (b - 1) FqEncoding q_5mod8 sqrt_minus1_valid.
+Definition PointEncoding : canonical encoding of E.point as (word b) :=
+ (@point_encoding curve25519params (b - 1) q_5mod8 sqrt_minus1_valid FqEncoding sign_bit
+ (@sign_bit_zero _ prime_q two_lt_q _ b_valid) (@sign_bit_opp _ prime_q two_lt_q _ b_valid)).
Definition H : forall n : nat, word n -> word (b + b). Admitted.
-Definition B : point. Admitted. (* TODO: B = decodePoint (y=4/5, x="positive") *)
-Definition B_nonzero : B <> zero. Admitted.
-Definition l_order_B : scalarMult l B = zero. Admitted.
+Definition B : E.point. Admitted. (* TODO: B = decodePoint (y=4/5, x="positive") *)
+Definition B_nonzero : B <> E.zero. Admitted.
+Definition l_order_B : (l * B)%E = E.zero. Admitted.
Local Instance ed25519params : EdDSAParams := {
E := curve25519params;
diff --git a/src/Spec/EdDSA.v b/src/Spec/EdDSA.v
index 3decae6a7..99f0766e0 100644
--- a/src/Spec/EdDSA.v
+++ b/src/Spec/EdDSA.v
@@ -6,6 +6,7 @@ Require Import Crypto.Util.WordUtil.
Require Bedrock.Word.
Require Coq.ZArith.Znumtheory Coq.ZArith.BinInt.
Require Coq.Numbers.Natural.Peano.NPeano.
+Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
Coercion Word.wordToNat : Word.word >-> nat.
@@ -20,8 +21,8 @@ Section EdDSAParams.
b : nat; (* public keys are k bits, signatures are 2*k bits *)
b_valid : 2^(b - 1) > BinInt.Z.to_nat q;
- FqEncoding : encoding of F q as Word.word (b-1);
- PointEncoding : encoding of point as Word.word b;
+ FqEncoding : canonical encoding of F q as Word.word (b-1);
+ PointEncoding : canonical encoding of E.point as Word.word b;
H : forall {n}, Word.word n -> Word.word (b + b); (* main hash function *)
@@ -32,14 +33,14 @@ Section EdDSAParams.
n_ge_c : n >= c;
n_le_b : n <= b;
- B : point;
- B_not_identity : B <> zero;
+ B : E.point;
+ B_not_identity : B <> E.zero;
l : nat; (* order of the subgroup of E generated by B *)
l_prime : Znumtheory.prime (BinInt.Z.of_nat l);
l_odd : l > 2;
- l_order_B : (l*B)%E = zero;
- FlEncoding : encoding of F (BinInt.Z.of_nat l) as Word.word b
+ l_order_B : (l*B)%E = E.zero;
+ FlEncoding : canonical encoding of F (BinInt.Z.of_nat l) as Word.word b
}.
End EdDSAParams.
@@ -54,6 +55,7 @@ Section EdDSA.
Notation secretkey := (Word.word b) (only parsing).
Notation publickey := (Word.word b) (only parsing).
Notation signature := (Word.word (b + b)) (only parsing).
+ Local Infix "==" := CompleteEdwardsCurveTheorems.E.point_eq_dec (at level 70) : E_scope .
(* TODO: proofread curveKey and definition of n *)
Definition curveKey (sk:secretkey) : nat :=
@@ -65,7 +67,7 @@ Section EdDSA.
Definition sign (A_:publickey) sk {n} (M : Word.word n) :=
let r : nat := H (prngKey sk ++ M) in (* secret nonce *)
- let R : point := (r * B)%E in (* commitment to nonce *)
+ let R : E.point := (r * B)%E in (* commitment to nonce *)
let s : nat := curveKey sk in (* secret scalar *)
let S : F (BinInt.Z.of_nat l) := ZToField (BinInt.Z.of_nat
(r + H (enc R ++ public sk ++ M) * s)) in
@@ -75,11 +77,11 @@ Section EdDSA.
let R_ := Word.split1 b b sig in
let S_ := Word.split2 b b sig in
match dec S_ : option (F (BinInt.Z.of_nat l)) with None => false | Some S' =>
- match dec A_ : option point with None => false | Some A =>
- match dec R_ : option point with None => false | Some R =>
+ match dec A_ : option E.point with None => false | Some A =>
+ match dec R_ : option E.point with None => false | Some R =>
if BinInt.Z.to_nat (FieldToZ S') * B == R + (H (R_ ++ A_ ++ M)) * A
then true else false
end
end
end%E.
-End EdDSA.
+End EdDSA. \ No newline at end of file
diff --git a/src/Spec/Encoding.v b/src/Spec/Encoding.v
index 14cf9d9d9..b063b638f 100644
--- a/src/Spec/Encoding.v
+++ b/src/Spec/Encoding.v
@@ -1,61 +1,8 @@
-Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith.
-Require Import Coq.Numbers.Natural.Peano.NPeano.
-Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems.
-Require Import Bedrock.Word.
-Require Import Crypto.Tactics.VerdiTactics.
-Require Import Crypto.Util.NatUtil.
-Require Import Crypto.Util.WordUtil.
-
-Class Encoding (T B:Type) := {
+Class CanonicalEncoding (T B:Type) := {
enc : T -> B ;
dec : B -> option T ;
- encoding_valid : forall x, dec (enc x) = Some x
+ encoding_valid : forall x, dec (enc x) = Some x ;
+ encoding_canonical : forall x_enc x, dec x_enc = Some x -> enc x = x_enc
}.
-Notation "'encoding' 'of' T 'as' B" := (Encoding T B) (at level 50).
-
-Local Open Scope nat_scope.
-
-Section ModularWordEncoding.
- Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}.
-
- Definition Fm_enc (x : F m) : word sz := natToWord sz (Z.to_nat (FieldToZ x)).
-
- Definition Fm_dec (x_ : word sz) : option (F m) :=
- let z := Z.of_nat (wordToNat (x_)) in
- if Z_lt_dec z m
- then Some (ZToField z)
- else None
- .
-
- Lemma bound_check_N : forall x : F m, (N.of_nat (Z.to_nat x) < Npow2 sz)%N.
- Proof.
- intro.
- pose proof (FieldToZ_range x m_pos) as x_range.
- rewrite <- Nnat.N2Nat.id.
- rewrite Npow2_nat.
- apply (Nat2N_inj_lt (Z.to_nat x) (pow2 sz)).
- rewrite Zpow_pow2.
- destruct x_range as [x_low x_high].
- apply Z2Nat.inj_lt in x_high; try omega.
- rewrite <- ZUtil.pow_Z2N_Zpow by omega.
- replace (Z.to_nat 2) with 2%nat by auto.
- omega.
- Qed.
-
- Lemma Fm_encoding_valid : forall x, Fm_dec (Fm_enc x) = Some x.
- Proof.
- unfold Fm_dec, Fm_enc; intros.
- pose proof (FieldToZ_range x m_pos).
- rewrite wordToNat_natToWord_idempotent by apply bound_check_N.
- rewrite Z2Nat.id by omega.
- rewrite ZToField_idempotent.
- break_if; auto; omega.
- Qed.
-
- Instance modular_word_encoding : encoding of F m as word sz := {
- enc := Fm_enc;
- dec := Fm_dec;
- encoding_valid := Fm_encoding_valid
- }.
-End ModularWordEncoding.
+Notation "'canonical' 'encoding' 'of' T 'as' B" := (CanonicalEncoding T B) (at level 50). \ No newline at end of file
diff --git a/src/Spec/ModularWordEncoding.v b/src/Spec/ModularWordEncoding.v
new file mode 100644
index 000000000..d6f6bcb3c
--- /dev/null
+++ b/src/Spec/ModularWordEncoding.v
@@ -0,0 +1,40 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Numbers.Natural.Peano.NPeano.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Bedrock.Word.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Util.NatUtil.
+Require Import Crypto.Util.WordUtil.
+Require Import Crypto.Spec.Encoding.
+Require Crypto.Encoding.ModularWordEncodingPre.
+
+Local Open Scope nat_scope.
+
+Section ModularWordEncoding.
+ Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}.
+
+ Definition Fm_enc (x : F m) : word sz := NToWord sz (Z.to_N (FieldToZ x)).
+
+ Definition Fm_dec (x_ : word sz) : option (F m) :=
+ let z := Z.of_N (wordToN (x_)) in
+ if Z_lt_dec z m
+ then Some (ZToField z)
+ else None
+ .
+
+ Definition sign_bit (x : F m) :=
+ match (Fm_enc x) with
+ | Word.WO => false
+ | Word.WS b _ w' => b
+ end.
+
+ Instance modular_word_encoding : canonical encoding of F m as word sz := {
+ enc := Fm_enc;
+ dec := Fm_dec;
+ encoding_valid :=
+ @ModularWordEncodingPre.Fm_encoding_valid m sz m_pos bound_check;
+ encoding_canonical :=
+ @ModularWordEncodingPre.Fm_encoding_canonical m sz bound_check
+ }.
+
+End ModularWordEncoding.
diff --git a/src/Spec/PointEncoding.v b/src/Spec/PointEncoding.v
index 4823ef28f..f4634f52f 100644
--- a/src/Spec/PointEncoding.v
+++ b/src/Spec/PointEncoding.v
@@ -1,175 +1,47 @@
-Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
-Require Import Coq.Numbers.Natural.Peano.NPeano Coq.NArith.NArith.
-Require Import Crypto.Spec.Encoding Crypto.Encoding.EncodingTheorems.
-Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
-Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems.
-Require Import Crypto.Util.NatUtil Crypto.Util.ZUtil Crypto.Util.NumTheoryUtil.
-Require Import Bedrock.Word.
-Require Import Crypto.Tactics.VerdiTactics.
+Require Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Require Coq.Numbers.Natural.Peano.NPeano.
+Require Crypto.Encoding.EncodingTheorems.
+Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
+Require Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Bedrock.Word.
+Require Crypto.Tactics.VerdiTactics.
+Require Crypto.Encoding.PointEncodingPre.
+Obligation Tactic := eauto; exact PointEncodingPre.point_encoding_canonical.
+
+Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding.
+Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Spec.ModularArithmetic.
Local Open Scope F_scope.
Section PointEncoding.
- Context {prm:TwistedEdwardsParams} {sz : nat} {FqEncoding : encoding of F q as word sz} {q_5mod8 : q mod 8 = 5} {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1}.
+ Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat}
+ {bound_check : (BinInt.Z.to_nat q < NPeano.Nat.pow 2 sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z}
+ {sqrt_minus1_valid : (@ZToField q 2 ^ BinInt.Z.to_N (q / 4)) ^ 2 = opp 1}
+ {FqEncoding : canonical encoding of (F q) as (Word.word sz)}
+ {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false}
+ {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}.
Existing Instance prime_q.
- Add Field Ffield : (@Ffield_theory q _)
- (morphism (@Fring_morph q),
- preprocess [Fpreprocess],
- postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption],
- constants [Fconstant],
- div (@Fmorph_div_theory q),
- power_tac (@Fpower_theory q) [Fexp_tac]).
+ Definition point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in
+ Word.WS (sign_bit x) (enc y).
- Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F.
+ Program Definition point_dec_with_spec :
+ {point_dec : Word.word (S sz) -> option E.point
+ | forall w x, point_dec w = Some x -> (point_enc x = w)
+ } := @PointEncodingPre.point_dec _ _ _ sign_bit.
- Lemma solve_sqrt_valid : forall (p : point),
- sqrt_valid (solve_for_x2 (snd (proj1_sig p))).
- Proof.
- intros.
- destruct p as [[x y] onCurve_xy]; simpl.
- rewrite (solve_correct x y) in onCurve_xy.
- rewrite <- onCurve_xy.
- unfold sqrt_valid.
- eapply sqrt_mod_q_valid; eauto.
- unfold isSquare; eauto.
- Grab Existential Variables. eauto.
- Qed.
+ Definition point_dec := Eval hnf in (proj1_sig point_dec_with_spec).
- Lemma solve_onCurve: forall (y : F q), sqrt_valid (solve_for_x2 y) ->
- onCurve (sqrt_mod_q (solve_for_x2 y), y).
- Proof.
- intros.
- unfold sqrt_valid in *.
- apply solve_correct; auto.
- Qed.
+ Definition point_encoding_valid : forall p : E.point, point_dec (point_enc p) = Some p :=
+ @PointEncodingPre.point_encoding_valid _ _ q_5mod8 sqrt_minus1_valid _ _ sign_bit_zero sign_bit_opp.
- Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (solve_for_x2 y) ->
- onCurve (opp (sqrt_mod_q (solve_for_x2 y)), y).
- Proof.
- intros y sqrt_valid_x2.
- unfold sqrt_valid in *.
- apply solve_correct.
- rewrite <- sqrt_valid_x2 at 2.
- ring.
- Qed.
+ Definition point_encoding_canonical : forall x_enc x, point_dec x_enc = Some x -> point_enc x = x_enc :=
+ PointEncodingPre.point_encoding_canonical.
-Definition sign_bit (x : F q) := (wordToN (enc (opp x)) <? wordToN (enc x))%N.
-Definition point_enc (p : point) : word (S sz) := let '(x,y) := proj1_sig p in
- WS (sign_bit x) (enc y).
-Definition point_dec (w : word (S sz)) : option point :=
- match dec (wtl w) with
- | None => None
- | Some y => let x2 := solve_for_x2 y in
- let x := sqrt_mod_q x2 in
- match (F_eq_dec (x ^ 2) x2) with
- | right _ => None
- | left EQ => if Bool.eqb (whd w) (sign_bit x)
- then Some (mkPoint (x, y) (solve_onCurve y EQ))
- else Some (mkPoint (opp x, y) (solve_opp_onCurve y EQ))
- end
- end.
-
-Lemma y_decode : forall p, dec (wtl (point_enc p)) = Some (snd (proj1_sig p)).
-Proof.
- intros.
- destruct p as [[x y] onCurve_p]; simpl.
- exact (encoding_valid y).
-Qed.
-
-
-Lemma wordToN_enc_neq_opp : forall x, x <> 0 -> (wordToN (enc (opp x)) <> wordToN (enc x))%N.
-Proof.
- intros x x_nonzero.
- intro false_eq.
- apply x_nonzero.
- apply F_eq_opp_zero; try apply two_lt_q.
- apply wordToN_inj in false_eq.
- apply encoding_inj in false_eq.
- auto.
-Qed.
-
-Lemma sign_bit_opp_negb : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x).
-Proof.
- intros x x_nonzero.
- unfold sign_bit.
- rewrite <- N.leb_antisym.
- rewrite N.ltb_compare, N.leb_compare.
- rewrite F_opp_involutive.
- case_eq (wordToN (enc x) ?= wordToN (enc (opp x)))%N; auto.
- intro wordToN_enc_eq.
- pose proof (wordToN_enc_neq_opp x x_nonzero).
- apply N.compare_eq_iff in wordToN_enc_eq.
- congruence.
-Qed.
-
-Lemma sign_bit_opp : forall x y, y <> 0 ->
- (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)).
-Proof.
- split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y);
- try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp_negb in * by auto;
- rewrite y_sign, x_sign in *; reflexivity || discriminate.
-Qed.
-
-Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 ->
- sign_bit x = sign_bit y -> x = y.
-Proof.
- intros ? ? y_nonzero squares_eq sign_match.
- destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto.
- assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto).
- apply sign_bit_opp in sign_mismatch; auto.
- congruence.
-Qed.
-
-Lemma sign_bit_match : forall x x' y : F q, onCurve (x, y) -> onCurve (x', y) ->
- sign_bit x = sign_bit x' -> x = x'.
-Proof.
- intros ? ? ? onCurve_x onCurve_x' sign_match.
- apply solve_correct in onCurve_x.
- apply solve_correct in onCurve_x'.
- destruct (F_eq_dec x' 0).
- + subst.
- rewrite Fq_pow_zero in onCurve_x' by congruence.
- rewrite <- onCurve_x' in *.
- eapply Fq_root_zero; eauto.
- + apply sign_bit_squares; auto.
- rewrite onCurve_x, onCurve_x'.
- reflexivity.
-Qed.
-
-Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p.
-Proof.
- intros.
- unfold point_dec.
- rewrite y_decode.
- pose proof solve_sqrt_valid p as solve_sqrt_valid_p.
- unfold sqrt_valid in *.
- destruct p as [[x y] onCurve_p].
- simpl in *.
- destruct (F_eq_dec ((sqrt_mod_q (solve_for_x2 y)) ^ 2) (solve_for_x2 y)); intuition.
- break_if; f_equal; apply point_eq.
- + rewrite Bool.eqb_true_iff in Heqb.
- pose proof (solve_onCurve y solve_sqrt_valid_p).
- f_equal.
- apply (sign_bit_match _ _ y); auto.
- + rewrite Bool.eqb_false_iff in Heqb.
- pose proof (solve_opp_onCurve y solve_sqrt_valid_p).
- f_equal.
- apply sign_bit_opp in Heqb.
- apply (sign_bit_match _ _ y); auto.
- intro eq_zero.
- apply solve_correct in onCurve_p.
- rewrite eq_zero in *.
- rewrite Fq_pow_zero in solve_sqrt_valid_p by congruence.
- rewrite <- solve_sqrt_valid_p in onCurve_p.
- apply Fq_root_zero in onCurve_p.
- rewrite onCurve_p in Heqb; auto.
-Qed.
-
-Instance point_encoding : encoding of point as (word (S sz)) := {
- enc := point_enc;
- dec := point_dec;
- encoding_valid := point_encoding_valid
-}.
-
-End PointEncoding.
+ Instance point_encoding : canonical encoding of E.point as (Word.word (S sz)) := {
+ enc := point_enc;
+ dec := point_dec;
+ encoding_valid := point_encoding_valid;
+ encoding_canonical := point_encoding_canonical
+ }.
+End PointEncoding. \ No newline at end of file
diff --git a/src/Specific/Ed25519.v b/src/Specific/Ed25519.v
index a705ceb90..3b90b5cdf 100644
--- a/src/Specific/Ed25519.v
+++ b/src/Specific/Ed25519.v
@@ -2,29 +2,64 @@ Require Import Bedrock.Word.
Require Import Crypto.Spec.Ed25519.
Require Import Crypto.Tactics.VerdiTactics.
Require Import BinNat BinInt NArith Crypto.Spec.ModularArithmetic.
+Require Import ModularArithmetic.ModularArithmeticTheorems.
+Require Import ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.Spec.CompleteEdwardsCurve.
-Require Import Crypto.Spec.Encoding Crypto.Spec.PointEncoding.
+Require Import Crypto.Encoding.PointEncodingPre.
+Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.PointEncoding.
Require Import Crypto.CompleteEdwardsCurve.ExtendedCoordinates.
-Require Import Crypto.Util.IterAssocOp.
+Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems.
+Require Import Crypto.Util.IterAssocOp Crypto.Util.WordUtil Crypto.Rep.
Local Infix "++" := Word.combine.
Local Notation " a '[:' i ']' " := (Word.split1 i _ a) (at level 40).
Local Notation " a '[' i ':]' " := (Word.split2 i _ a) (at level 40).
Local Arguments H {_} _.
-Local Arguments scalarMultM1 {_} {_} _ _.
+Local Arguments scalarMultM1 {_} {_} _ _ _.
Local Arguments unifiedAddM1 {_} {_} _ _.
-Ltac set_evars :=
+Local Ltac set_evars :=
repeat match goal with
| [ |- appcontext[?E] ] => is_evar E; let e := fresh "e" in set (e := E)
end.
-Ltac subst_evars :=
+Local Ltac subst_evars :=
repeat match goal with
| [ e := ?E |- _ ] => is_evar E; subst e
end.
-Axiom point_eqb : forall {prm : TwistedEdwardsParams}, point -> point -> bool.
-Axiom point_eqb_correct : forall P Q, point_eqb P Q = if point_eq_dec P Q then true else false.
+Lemma funexp_proj {T T'} (proj : T -> T') (f : T -> T) (f' : T' -> T') x n
+ (f_proj : forall a, proj (f a) = f' (proj a))
+ : proj (funexp f x n) = funexp f' (proj x) n.
+Proof.
+ revert x; induction n as [|n IHn]; simpl; congruence.
+Qed.
+
+Lemma iter_op_proj {T T' S} (proj : T -> T') (op : T -> T -> T) (op' : T' -> T' -> T') x y z
+ (testbit : S -> nat -> bool) (bound : nat)
+ (op_proj : forall a b, proj (op a b) = op' (proj a) (proj b))
+ : proj (iter_op op x testbit y z bound) = iter_op op' (proj x) testbit y (proj z) bound.
+Proof.
+ unfold iter_op.
+ simpl.
+ lazymatch goal with
+ | [ |- ?proj (snd (funexp ?f ?x ?n)) = snd (funexp ?f' _ ?n) ]
+ => pose proof (fun x0 x1 => funexp_proj (fun x => (fst x, proj (snd x))) f f' (x0, x1)) as H'
+ end.
+ simpl in H'.
+ rewrite <- H'.
+ { reflexivity. }
+ { intros [??]; simpl.
+ repeat match goal with
+ | [ |- context[match ?n with _ => _ end] ]
+ => destruct n eqn:?
+ | _ => progress simpl
+ | _ => progress subst
+ | _ => reflexivity
+ | _ => rewrite op_proj
+ end. }
+Qed.
+
+Lemma B_proj : proj1_sig B = (fst(proj1_sig B), snd(proj1_sig B)). destruct B as [[]]; reflexivity. Qed.
Require Import Coq.Setoids.Setoid.
Require Import Coq.Classes.Morphisms.
@@ -52,21 +87,43 @@ Axiom decode_scalar : word b -> option N.
Local Existing Instance Ed25519.FlEncoding.
Axiom decode_scalar_correct : forall x, decode_scalar x = option_map (fun x : F (Z.of_nat Ed25519.l) => Z.to_N x) (dec x).
-Local Infix "==?" := point_eqb (at level 70) : E_scope.
+Local Infix "==?" := E.point_eqb (at level 70) : E_scope.
+Local Infix "==?" := ModularArithmeticTheorems.F_eq_dec (at level 70) : F_scope.
-Axiom negate : point -> point.
-Definition point_sub P Q := (P + negate Q)%E.
-Infix "-" := point_sub : E_scope.
-Axiom solve_for_R : forall A B C, (A ==? B + C)%E = (B ==? A - C)%E.
+Lemma solve_for_R_eq : forall A B C, (A = B + C <-> B = A - C)%E.
+Proof.
+ intros; split; intros; subst; unfold E.sub;
+ rewrite <-E.add_assoc, ?E.add_opp_r, ?E.add_opp_l, E.add_0_r; reflexivity.
+Qed.
-Axiom negateExtended : extendedPoint -> extendedPoint.
-Axiom negateExtended_correct : forall P, negate (unExtendedPoint P) = unExtendedPoint (negateExtended P).
+Lemma solve_for_R : forall A B C, (A ==? B + C)%E = (B ==? A - C)%E.
+Proof.
+ intros.
+ repeat match goal with |- context [(?P ==? ?Q)%E] =>
+ let H := fresh "H" in
+ destruct (E.point_eq_dec P Q) as [H|H];
+ (rewrite (E.point_eqb_complete _ _ H) || rewrite (E.point_eqb_neq_complete _ _ H))
+ end; rewrite solve_for_R_eq in H; congruence.
+Qed.
+
+Local Notation "'(' X ',' Y ',' Z ',' T ')'" := (mkExtended X Y Z T).
+Local Notation "2" := (ZToField 2) : F_scope.
Local Existing Instance PointEncoding.
-Axiom decode_point_eq : forall (P_ Q_ : word (S (b-1))) (P Q:point), dec P_ = Some P -> dec Q_ = Some Q -> weqb P_ Q_ = (P ==? Q)%E.
+Lemma decode_point_eq : forall (P_ Q_ : word (S (b-1))) (P Q:E.point),
+ dec P_ = Some P ->
+ dec Q_ = Some Q ->
+ weqb P_ Q_ = (P ==? Q)%E.
+Proof.
+ intros.
+ replace P_ with (enc P) in * by (auto using encoding_canonical).
+ replace Q_ with (enc Q) in * by (auto using encoding_canonical).
+ rewrite E.point_eqb_correct.
+ edestruct E.point_eq_dec; (apply weqb_true_iff || apply weqb_false_iff); congruence.
+Qed.
-Lemma decode_test_encode_test : forall S_ X, option_rect (fun _ : option point => bool)
- (fun S : point => (S ==? X)%E) false (dec S_) = weqb S_ (enc X).
+Lemma decode_test_encode_test : forall S_ X, option_rect (fun _ : option E.point => bool)
+ (fun S : E.point => (S ==? X)%E) false (dec S_) = weqb S_ (enc X).
Proof.
intros.
destruct (dec S_) eqn:H.
@@ -76,80 +133,449 @@ Proof.
apply weqb_true_iff in Heqb. subst. rewrite encoding_valid in H; discriminate. }
Qed.
-Lemma sharper_verify : forall pk l msg sig, { verify | verify = ed25519_verify pk l msg sig}.
+Definition enc' : F q * F q -> word b.
Proof.
- eexists; intros.
- cbv [ed25519_verify EdDSA.verify
- ed25519params curve25519params
- EdDSA.E EdDSA.B EdDSA.b EdDSA.l EdDSA.H
- EdDSA.PointEncoding EdDSA.FlEncoding EdDSA.FqEncoding].
-
- etransitivity.
- Focus 2.
- { repeat match goal with
- | [ |- ?x = ?x ] => reflexivity
- | [ |- _ = match ?a with None => ?N | Some x => @?S x end :> ?T ]
- => etransitivity;
- [
- | refine (_ : option_rect (fun _ => T) _ N a = _);
- let S' := match goal with |- option_rect _ ?S' _ _ = _ => S' end in
- refine (option_rect (fun a' => option_rect (fun _ => T) S' N a' = match a' with None => N | Some x => S x end)
- (fun x => _) _ a); intros; simpl @option_rect ];
- [ reflexivity | .. ]
- end.
- set_evars.
- rewrite<- point_eqb_correct.
- rename x0 into R. rename x1 into S. rename x into A.
- rewrite solve_for_R.
- let p1 := constr:(scalarMultM1_rep eq_refl) in
- let p2 := constr:(unifiedAddM1_rep eq_refl) in
- repeat match goal with
- | |- context [(_ * ?P)%E] =>
- rewrite <-(unExtendedPoint_mkExtendedPoint P);
- rewrite <-p1
- | |- context [(?P + unExtendedPoint _)%E] =>
- rewrite <-(unExtendedPoint_mkExtendedPoint P);
- rewrite p2
- end;
- rewrite ?Znat.Z_nat_N, <-?Word.wordToN_nat;
- subst_evars;
- reflexivity.
- } Unfocus.
+ intro x.
+ let enc' := (eval hnf in (@enc (@E.point curve25519params) _ _)) in
+ match (eval cbv [proj1_sig] in (fun pf => enc' (exist _ x pf))) with
+ | (fun _ => ?enc') => exact enc'
+ end.
+Defined.
- etransitivity.
- Focus 2.
- { lazymatch goal with |- _ = option_rect _ _ ?false ?dec =>
- symmetry; etransitivity; [|eapply (option_rect_option_map (fun (x:F _) => Z.to_N x) _ false dec)]
- end.
- eapply option_rect_Proper_nd; [intro|reflexivity..].
+Definition enc'_correct : @enc (@E.point curve25519params) _ _ = (fun x => enc' (proj1_sig x))
+ := eq_refl.
+
+Definition Let_In {A P} (x : A) (f : forall a : A, P a) : P x := let y := x in f y.
+Global Instance Let_In_Proper_nd {A P}
+ : Proper (eq ==> pointwise_relation _ eq ==> eq) (@Let_In A (fun _ => P)).
+Proof.
+ lazy; intros; congruence.
+Qed.
+Lemma option_rect_function {A B C S' N' v} f
+ : f (option_rect (fun _ : option A => option B) S' N' v)
+ = option_rect (fun _ : option A => C) (fun x => f (S' x)) (f N') v.
+Proof. destruct v; reflexivity. Qed.
+Local Ltac commute_option_rect_Let_In := (* pull let binders out side of option_rect pattern matching *)
+ idtac;
+ lazymatch goal with
+ | [ |- ?LHS = option_rect ?P ?S ?N (Let_In ?x ?f) ]
+ => (* we want to just do a [change] here, but unification is stupid, so we have to tell it what to unfold in what order *)
+ cut (LHS = Let_In x (fun y => option_rect P S N (f y))); cbv beta;
+ [ set_evars;
+ let H := fresh in
+ intro H;
+ rewrite H;
+ clear;
+ abstract (cbv [Let_In]; reflexivity)
+ | ]
+ end.
+Local Ltac replace_let_in_with_Let_In :=
+ repeat match goal with
+ | [ |- context G[let x := ?y in @?z x] ]
+ => let G' := context G[Let_In y z] in change G'
+ | [ |- _ = Let_In _ _ ]
+ => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ]
+ end.
+Local Ltac simpl_option_rect := (* deal with [option_rect _ _ _ None] and [option_rect _ _ _ (Some _)] *)
+ repeat match goal with
+ | [ |- context[option_rect ?P ?S ?N None] ]
+ => change (option_rect P S N None) with N
+ | [ |- context[option_rect ?P ?S ?N (Some ?x) ] ]
+ => change (option_rect P S N (Some x)) with (S x); cbv beta
+ end.
+
+Section Ed25519Frep.
+ Generalizable All Variables.
+ Context `(rcS:RepConversions N SRep) (rcSOK:RepConversionsOK rcS).
+ Context `(rcF:RepConversions (F (Ed25519.q)) FRep) (rcFOK:RepConversionsOK rcF).
+ Context (FRepAdd FRepSub FRepMul:FRep->FRep->FRep) (FRepAdd_correct:RepBinOpOK rcF add FRepMul).
+ Context (FRepSub_correct:RepBinOpOK rcF sub FRepSub) (FRepMul_correct:RepBinOpOK rcF mul FRepMul).
+ Local Notation rep2F := (unRep : FRep -> F (Ed25519.q)).
+ Local Notation F2Rep := (toRep : F (Ed25519.q) -> FRep).
+ Local Notation rep2S := (unRep : SRep -> N).
+ Local Notation S2Rep := (toRep : N -> SRep).
+
+ Axiom FRepOpp : FRep -> FRep.
+ Axiom FRepOpp_correct : forall x, opp (rep2F x) = rep2F (FRepOpp x).
+
+ Axiom wltu : forall {b}, word b -> word b -> bool.
+ Axiom wltu_correct : forall {b} (x y:word b), wltu x y = (wordToN x <? wordToN y)%N.
+
+ Axiom compare_enc : forall x y, F_eqb x y = weqb (@enc _ _ FqEncoding x) (@enc _ _ FqEncoding y).
+
+ Axiom wire2FRep : word (b-1) -> option FRep.
+ Axiom wire2FRep_correct : forall x, Fm_dec x = option_map rep2F (wire2FRep x).
+
+ Axiom FRep2wire : FRep -> word (b-1).
+ Axiom FRep2wire_correct : forall x, FRep2wire x = @enc _ _ FqEncoding (rep2F x).
+
+ Axiom SRep_testbit : SRep -> nat -> bool.
+ Axiom SRep_testbit_correct : forall (x0 : SRep) (i : nat), SRep_testbit x0 i = N.testbit_nat (unRep x0) i.
+
+ Definition FSRepPow x n := iter_op FRepMul (toRep 1%F) SRep_testbit n x 255.
+ Lemma FSRepPow_correct : forall x n, (N.size_nat (unRep n) <= 255)%nat -> (unRep x ^ unRep n)%F = unRep (FSRepPow x n).
+ Proof. (* this proof derives the required formula, which I copy-pasted above to be able to reference it without the length precondition *)
+ unfold FSRepPow; intros.
+ erewrite <-pow_nat_iter_op_correct by auto.
+ erewrite <-(fun x => iter_op_spec (scalar := SRep) (mul (m:=Ed25519.q)) F_mul_assoc _ F_mul_1_l _ unRep SRep_testbit_correct n x 255%nat) by auto.
+ rewrite <-(rcFOK 1%F) at 1.
+ erewrite <-iter_op_proj by auto.
+ reflexivity.
+ Qed.
+
+ Definition FRepInv x : FRep := FSRepPow x (S2Rep (Z.to_N (Ed25519.q - 2))).
+ Lemma FRepInv_correct : forall x, inv (rep2F x)%F = rep2F (FRepInv x).
+ unfold FRepInv; intros.
+ rewrite <-FSRepPow_correct; rewrite rcSOK; try reflexivity.
+ pose proof @Fq_inv_fermat_correct as H; unfold inv_fermat in H; rewrite H by
+ auto using Ed25519.prime_q, Ed25519.two_lt_q.
+ reflexivity.
+ Qed.
+
+ Lemma unfoldDiv : forall {m} (x y:F m), (x/y = x * inv y)%F. Proof. unfold div. congruence. Qed.
+
+ Definition rep2E (r:FRep * FRep * FRep * FRep) : extended :=
+ match r with (((x, y), z), t) => mkExtended (rep2F x) (rep2F y) (rep2F z) (rep2F t) end.
+
+ Lemma if_map : forall {T U} (f:T->U) (b:bool) (x y:T), (if b then f x else f y) = f (if b then x else y).
+ Proof.
+ destruct b; trivial.
+ Qed.
+
+ Local Ltac Let_In_unRep :=
match goal with
- | [ |- ?RHS = ?e ?v ]
- => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in
- unify e RHS'
+ | [ |- appcontext G[Let_In (unRep ?x) ?f] ]
+ => change (Let_In (unRep x) f) with (Let_In x (fun y => f (unRep y))); cbv beta
end.
+
+
+ (** TODO: Move me *)
+ Lemma pull_Let_In {B C} (f : B -> C) A (v : A) (b : A -> B)
+ : Let_In v (fun v' => f (b v')) = f (Let_In v b).
+ Proof.
reflexivity.
- } Unfocus.
- rewrite <-decode_scalar_correct.
-
- etransitivity.
- Focus 2.
- { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
- symmetry; apply decode_test_encode_test.
- } Unfocus.
-
- etransitivity.
- Focus 2.
- { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
- unfold point_sub. rewrite negateExtended_correct.
- let p := constr:(unifiedAddM1_rep eq_refl) in rewrite p.
+ Qed.
+
+ Lemma Let_app_In {A B T} (g:A->B) (f:B->T) (x:A) :
+ @Let_In _ (fun _ => T) (g x) f =
+ @Let_In _ (fun _ => T) x (fun p => f (g x)).
+ Proof. reflexivity. Qed.
+
+ Lemma Let_app2_In {A B C D T} (g1:A->C) (g2:B->D) (f:C*D->T) (x:A) (y:B) :
+ @Let_In _ (fun _ => T) (g1 x, g2 y) f =
+ @Let_In _ (fun _ => T) (x, y) (fun p => f ((g1 (fst p), g2 (snd p)))).
+ Proof. reflexivity. Qed.
+
+ Create HintDb FRepOperations discriminated.
+ Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : FRepOperations.
+
+ Create HintDb EdDSA_opts discriminated.
+ Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : EdDSA_opts.
+
+ Lemma unifiedAddM1Rep_sig : forall a b : FRep * FRep * FRep * FRep, { unifiedAddM1Rep | rep2E unifiedAddM1Rep = unifiedAddM1' (rep2E a) (rep2E b) }.
+ Proof.
+ destruct a as [[[]]]; destruct b as [[[]]].
+ eexists.
+ lazymatch goal with |- ?LHS = ?RHS :> ?T =>
+ evar (e:T); replace LHS with e; [subst e|]
+ end.
+ unfold rep2E. cbv beta delta [unifiedAddM1'].
+ pose proof (rcFOK twice_d) as H; rewrite <-H; clear H. (* XXX: this is a hack -- rewrite misresolves typeclasses? *)
+
+ { etransitivity; [|replace_let_in_with_Let_In; reflexivity].
+ repeat (
+ autorewrite with FRepOperations;
+ Let_In_unRep;
+ eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [Proper respectful pointwise_relation]; intro]).
+ lazymatch goal with |- ?LHS = (unRep ?x, unRep ?y, unRep ?z, unRep ?t) =>
+ change (LHS = (rep2E (((x, y), z), t)))
+ end.
+ reflexivity. }
+
+ subst e.
+ Local Opaque Let_In.
+ repeat setoid_rewrite (pull_Let_In rep2E).
+ Local Transparent Let_In.
reflexivity.
- } Unfocus.
+ Defined.
+
+ Definition unifiedAddM1Rep (a b:FRep * FRep * FRep * FRep) : FRep * FRep * FRep * FRep := Eval hnf in proj1_sig (unifiedAddM1Rep_sig a b).
+ Definition unifiedAddM1Rep_correct a b : rep2E (unifiedAddM1Rep a b) = unifiedAddM1' (rep2E a) (rep2E b) := Eval hnf in proj2_sig (unifiedAddM1Rep_sig a b).
+
+ Definition rep2T (P:FRep * FRep) := (rep2F (fst P), rep2F (snd P)).
+ Definition erep2trep (P:FRep * FRep * FRep * FRep) := Let_In P (fun P => Let_In (FRepInv (snd (fst P))) (fun iZ => (FRepMul (fst (fst (fst P))) iZ, FRepMul (snd (fst (fst P))) iZ))).
+ Lemma erep2trep_correct : forall P, rep2T (erep2trep P) = extendedToTwisted (rep2E P).
+ Proof.
+ unfold rep2T, rep2E, erep2trep, extendedToTwisted; destruct P as [[[]]]; simpl.
+ rewrite !unfoldDiv, <-!FRepMul_correct, <-FRepInv_correct. reflexivity.
+ Qed.
+
+ (** TODO: possibly move me, remove local *)
+ Local Ltac replace_option_match_with_option_rect :=
+ idtac;
+ lazymatch goal with
+ | [ |- _ = ?RHS :> ?T ]
+ => lazymatch RHS with
+ | match ?a with None => ?N | Some x => @?S x end
+ => replace RHS with (option_rect (fun _ => T) S N a) by (destruct a; reflexivity)
+ end
+ end.
+
+ (** TODO: Move me, remove Local *)
+ Definition proj1_sig_unmatched {A P} := @proj1_sig A P.
+ Definition proj1_sig_nounfold {A P} := @proj1_sig A P.
+ Definition proj1_sig_unfold {A P} := Eval cbv [proj1_sig] in @proj1_sig A P.
+ Local Ltac unfold_proj1_sig_exist :=
+ (** Change the first [proj1_sig] into [proj1_sig_unmatched]; if it's applied to [exist], mark it as unfoldable, otherwise mark it as not unfoldable. Then repeat. Finally, unfold. *)
+ repeat (change @proj1_sig with @proj1_sig_unmatched at 1;
+ match goal with
+ | [ |- context[proj1_sig_unmatched (exist _ _ _)] ]
+ => change @proj1_sig_unmatched with @proj1_sig_unfold
+ | _ => change @proj1_sig_unmatched with @proj1_sig_nounfold
+ end);
+ (* [proj1_sig_nounfold] is a thin wrapper around [proj1_sig]; unfolding it restores [proj1_sig]. Unfolding [proj1_sig_nounfold] exposes the pattern match, which is reduced by ι. *)
+ cbv [proj1_sig_nounfold proj1_sig_unfold].
+
+ (** TODO: possibly move me, remove Local *)
+ Local Ltac reflexivity_when_unification_is_stupid_about_evars
+ := repeat first [ reflexivity
+ | apply f_equal ].
+
+
+ Local Existing Instance eq_Reflexive. (* To get some of the [setoid_rewrite]s below to work, we need to infer [Reflexive eq] before [Reflexive Equivalence.equiv] *)
+
+ (* TODO: move me *)
+ Lemma fold_rep2E x y z t
+ : (rep2F x, rep2F y, rep2F z, rep2F t) = rep2E (((x, y), z), t).
+ Proof. reflexivity. Qed.
+ Lemma commute_negateExtended'_rep2E x y z t
+ : negateExtended' (rep2E (((x, y), z), t))
+ = rep2E (((FRepOpp x, y), z), FRepOpp t).
+ Proof. simpl; autorewrite with FRepOperations; reflexivity. Qed.
+ Lemma fold_rep2E_ffff x y z t
+ : (x, y, z, t) = rep2E (((toRep x, toRep y), toRep z), toRep t).
+ Proof. simpl; rewrite !rcFOK; reflexivity. Qed.
+ Lemma fold_rep2E_rrfr x y z t
+ : (rep2F x, rep2F y, z, rep2F t) = rep2E (((x, y), toRep z), t).
+ Proof. simpl; rewrite !rcFOK; reflexivity. Qed.
+ Lemma fold_rep2E_0fff y z t
+ : (0%F, y, z, t) = rep2E (((toRep 0%F, toRep y), toRep z), toRep t).
+ Proof. apply fold_rep2E_ffff. Qed.
+ Lemma fold_rep2E_ff1f x y t
+ : (x, y, 1%F, t) = rep2E (((toRep x, toRep y), toRep 1%F), toRep t).
+ Proof. apply fold_rep2E_ffff. Qed.
+ Lemma commute_negateExtended'_rep2E_rrfr x y z t
+ : negateExtended' (unRep x, unRep y, z, unRep t)
+ = rep2E (((FRepOpp x, y), toRep z), FRepOpp t).
+ Proof. rewrite <- commute_negateExtended'_rep2E; simpl; rewrite !rcFOK; reflexivity. Qed.
+
+ Hint Rewrite @F_mul_0_l commute_negateExtended'_rep2E_rrfr fold_rep2E_0fff (@fold_rep2E_ff1f (fst (proj1_sig B))) @if_F_eq_dec_if_F_eqb compare_enc (if_map unRep) (fun T => Let_app2_In (T := T) unRep unRep) @F_pow_2_r @unfoldDiv : EdDSA_opts.
+ Hint Rewrite <- unifiedAddM1Rep_correct erep2trep_correct (fun x y z bound => iter_op_proj rep2E unifiedAddM1Rep unifiedAddM1' x y z N.testbit_nat bound unifiedAddM1Rep_correct) FRep2wire_correct: EdDSA_opts.
+
+ Lemma sharper_verify : forall pk l msg sig, { verify | verify = ed25519_verify pk l msg sig}.
+ Proof.
+ eexists; intros.
+ cbv [ed25519_verify EdDSA.verify
+ ed25519params curve25519params
+ EdDSA.E EdDSA.B EdDSA.b EdDSA.l EdDSA.H
+ EdDSA.PointEncoding EdDSA.FlEncoding EdDSA.FqEncoding].
+
+ etransitivity.
+ Focus 2.
+ { repeat match goal with
+ | [ |- ?x = ?x ] => reflexivity
+ | _ => replace_option_match_with_option_rect
+ | [ |- _ = option_rect _ _ _ _ ]
+ => eapply option_rect_Proper_nd; [ intro | reflexivity.. ]
+ end.
+ set_evars.
+ rewrite<- E.point_eqb_correct.
+ rewrite solve_for_R; unfold E.sub.
+ rewrite E.opp_mul.
+ let p1 := constr:(scalarMultM1_rep eq_refl) in
+ let p2 := constr:(unifiedAddM1_rep eq_refl) in
+ repeat match goal with
+ | |- context [(_ * E.opp ?P)%E] =>
+ rewrite <-(unExtendedPoint_mkExtendedPoint P);
+ rewrite negateExtended_correct;
+ rewrite <-p1
+ | |- context [(_ * ?P)%E] =>
+ rewrite <-(unExtendedPoint_mkExtendedPoint P);
+ rewrite <-p1
+ | _ => rewrite p2
+ end;
+ rewrite ?Znat.Z_nat_N, <-?Word.wordToN_nat;
+ subst_evars;
+ reflexivity.
+ } Unfocus.
+
+ etransitivity.
+ Focus 2.
+ { lazymatch goal with |- _ = option_rect _ _ ?false ?dec =>
+ symmetry; etransitivity; [|eapply (option_rect_option_map (fun (x:F _) => Z.to_N x) _ false dec)]
+ end.
+ eapply option_rect_Proper_nd; [intro|reflexivity..].
+ match goal with
+ | [ |- ?RHS = ?e ?v ]
+ => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in
+ unify e RHS'
+ end.
+ reflexivity.
+ } Unfocus.
+ rewrite <-decode_scalar_correct.
+
+ etransitivity.
+ Focus 2.
+ { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
+ symmetry; apply decode_test_encode_test.
+ } Unfocus.
+
+ rewrite enc'_correct.
+ cbv [unExtendedPoint unifiedAddM1 negateExtended scalarMultM1].
+ unfold_proj1_sig_exist.
+
+ etransitivity.
+ Focus 2.
+ { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
+ set_evars.
+ repeat match goal with
+ | [ |- appcontext[@proj1_sig ?A ?P (@iter_op ?T ?f ?neutral ?T' ?testbit ?exp ?base ?bound)] ]
+ => erewrite (@iter_op_proj T _ _ (@proj1_sig _ _)) by reflexivity
+ end.
+ subst_evars.
+ reflexivity. }
+ Unfocus.
+
+ cbv [mkExtendedPoint E.zero].
+ unfold_proj1_sig_exist.
+ rewrite B_proj.
+
+ etransitivity.
+ Focus 2.
+ { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
+ set_evars.
+ lazymatch goal with |- _ = option_rect _ _ ?false ?dec =>
+ symmetry; etransitivity; [|eapply (option_rect_option_map (@proj1_sig _ _) _ false dec)]
+ end.
+ eapply option_rect_Proper_nd; [intro|reflexivity..].
+ match goal with
+ | [ |- ?RHS = ?e ?v ]
+ => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in
+ unify e RHS'
+ end.
+ reflexivity.
+ } Unfocus.
+
+ cbv [dec PointEncoding point_encoding].
+ etransitivity.
+ Focus 2.
+ { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]).
+ etransitivity.
+ Focus 2.
+ { apply f_equal.
+ symmetry.
+ apply point_dec_coordinates_correct. }
+ Unfocus.
+ reflexivity. }
+ Unfocus.
+
+ cbv iota beta delta [point_dec_coordinates sign_bit dec FqEncoding modular_word_encoding E.solve_for_x2 sqrt_mod_q].
- cbv [scalarMultM1 iter_op].
- cbv iota zeta delta [test_and_op].
- Local Arguments funexp {_} _ {_} {_}. (* do not display the initializer and iteration bound for now *)
+ etransitivity.
+ Focus 2. {
+ do 1 (eapply option_rect_Proper_nd; [|reflexivity..]). cbv beta delta [pointwise_relation]. intro.
+ etransitivity.
+ Focus 2.
+ { apply f_equal.
+ lazymatch goal with
+ | [ |- _ = ?term :> ?T ]
+ => lazymatch term with (match ?a with None => ?N | Some x => @?S x end)
+ => let term' := constr:((option_rect (fun _ => T) S N) a) in
+ replace term with term' by reflexivity
+ end
+ end.
+ reflexivity. } Unfocus. reflexivity. } Unfocus.
+
+ etransitivity.
+ Focus 2. {
+ do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]).
+ do 1 (eapply option_rect_Proper_nd; [ intro; reflexivity | reflexivity | ]).
+ eapply option_rect_Proper_nd; [ cbv beta delta [pointwise_relation]; intro | reflexivity.. ].
+ replace_let_in_with_Let_In.
+ reflexivity.
+ } Unfocus.
- Axiom rep : forall {m}, list Z -> F m -> Prop.
- Axiom decode_point_limbs : word (S (b-1)) -> option (list Z * list Z).
- Axiom point_dec_rep : forall P_ P lx ly, dec P_ = Some P -> decode_point_limbs P_ = Some (lx, ly) -> rep lx (fst (proj1_sig P)) /\ rep ly (fst (proj1_sig P)).
-Admitted. \ No newline at end of file
+ etransitivity.
+ Focus 2. {
+ do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]).
+ set_evars.
+ rewrite option_rect_function. (* turn the two option_rects into one *)
+ subst_evars.
+ simpl_option_rect.
+ do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]).
+ (* push the [option_rect] inside until it hits a [Some] or a [None] *)
+ repeat match goal with
+ | _ => commute_option_rect_Let_In
+ | [ |- _ = Let_In _ _ ]
+ => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ]
+ | [ |- ?LHS = option_rect ?P ?S ?N (if ?b then ?t else ?f) ]
+ => transitivity (if b then option_rect P S N t else option_rect P S N f);
+ [
+ | destruct b; reflexivity ]
+ | [ |- _ = if ?b then ?t else ?f ]
+ => apply (f_equal2 (fun x y => if b then x else y))
+ | [ |- _ = false ] => reflexivity
+ | _ => progress simpl_option_rect
+ end.
+ reflexivity.
+ } Unfocus.
+
+ cbv iota beta delta [q d a].
+
+ rewrite wire2FRep_correct.
+
+ etransitivity.
+ Focus 2. {
+ eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro.
+ rewrite <-!(option_rect_option_map rep2F).
+ eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro.
+ autorewrite with EdDSA_opts.
+ rewrite <-(rcFOK 1%F).
+ pattern Ed25519.d at 1. rewrite <-(rcFOK Ed25519.d) at 1.
+ pattern Ed25519.a at 1. rewrite <-(rcFOK Ed25519.a) at 1.
+ rewrite <- (rcSOK (Z.to_N (Ed25519.q / 8 + 1))).
+ autorewrite with EdDSA_opts.
+ (Let_In_unRep).
+ eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro].
+ etransitivity. Focus 2. eapply Let_In_Proper_nd; [|cbv beta delta [pointwise_relation]; intro;reflexivity]. {
+ rewrite FSRepPow_correct by (rewrite rcSOK; cbv; omega).
+ (Let_In_unRep).
+ etransitivity. Focus 2. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. {
+ set_evars.
+ rewrite <-(rcFOK sqrt_minus1).
+ autorewrite with EdDSA_opts.
+ subst_evars.
+ reflexivity. } Unfocus.
+ rewrite pull_Let_In.
+ reflexivity. } Unfocus.
+ set_evars.
+ (Let_In_unRep).
+
+ subst_evars. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. set_evars.
+
+ autorewrite with EdDSA_opts.
+
+ subst_evars.
+ lazymatch goal with |- _ = if ?b then ?t else ?f => apply (f_equal2 (fun x y => if b then x else y)) end; [|reflexivity].
+ eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro].
+ set_evars.
+
+ unfold twistedToExtended.
+ autorewrite with EdDSA_opts.
+ progress cbv beta delta [erep2trep].
+
+ subst_evars.
+ reflexivity. } Unfocus.
+ reflexivity.
+ Defined.
+End Ed25519Frep. \ No newline at end of file
diff --git a/src/Specific/GF1305.v b/src/Specific/GF1305.v
new file mode 100644
index 000000000..b004a60d1
--- /dev/null
+++ b/src/Specific/GF1305.v
@@ -0,0 +1,74 @@
+Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep.
+Require Import Coq.Lists.List Crypto.Util.ListUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.BaseSystem.
+Import ListNotations.
+Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
+Local Open Scope Z.
+
+(* BEGIN PseudoMersenneBaseParams instance construction. *)
+
+Definition modulus : Z := 2^130 - 5.
+Lemma prime_modulus : prime modulus. Admitted.
+
+Instance params1305 : PseudoMersenneBaseParams modulus.
+ construct_params prime_modulus 5%nat 130.
+Defined.
+
+Definition mul2modulus := Eval compute in (construct_mul2modulus params1305).
+
+Instance subCoeff : SubtractionCoefficient modulus params1305.
+ apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto.
+Defined.
+
+(* END PseudoMersenneBaseParams instance construction. *)
+
+(* Precompute k and c *)
+Definition k_ := Eval compute in k.
+Definition c_ := Eval compute in c.
+
+(* Makes Qed not take forever *)
+Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land
+ Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul
+ Let_In digits Z.add Pos.add Z.pos_sub.
+
+Local Open Scope nat_scope.
+Lemma GF1305Base26_mul_reduce_formula :
+ forall f0 f1 f2 f3 f4 g0 g1 g2 g3 g4,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4] f
+ -> rep [g0;g1;g2;g3;g4] g
+ -> rep ls (f*g)%F}.
+Proof.
+ eexists; intros ? ? Hf Hg.
+ pose proof (carry_mul_opt_correct k_ c_ (eq_refl k) (eq_refl c_) [0;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
+Defined.
+
+Lemma GF1305Base26_add_formula :
+ forall f0 f1 f2 f3 f4 g0 g1 g2 g3 g4,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4] f
+ -> rep [g0;g1;g2;g3;g4] g
+ -> rep ls (f + g)%F}.
+Proof.
+ eexists; intros ? ? Hf Hg.
+ pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
+Defined.
+
+Lemma GF25519Base25Point5_sub_formula :
+ forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f
+ -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g
+ -> rep ls (f - g)%F}.
+Proof.
+ eexists.
+ intros f g Hf Hg.
+ pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
+Defined. \ No newline at end of file
diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v
index 0d3923945..8aaf8caf6 100644
--- a/src/Specific/GF25519.v
+++ b/src/Specific/GF25519.v
@@ -1,529 +1,180 @@
-Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
-Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystem.
+Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
+Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep.
Require Import Coq.Lists.List Crypto.Util.ListUtil.
+Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
+Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.BaseSystem.
+Require Import Crypto.Rep.
Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
-Require Import Coq.QArith.QArith Coq.QArith.Qround.
-Require Import Crypto.Tactics.VerdiTactics.
-Close Scope Q.
-
-Ltac twoIndices i j base :=
- intros;
- assert (In i (seq 0 (length base))) by nth_tac;
- assert (In j (seq 0 (length base))) by nth_tac;
- repeat match goal with [ x := _ |- _ ] => subst x end;
- simpl in *; repeat break_or_hyp; try omega; vm_compute; reflexivity.
-
-Module Base25Point5_10limbs <: BaseCoefs.
- Local Open Scope Z_scope.
- Definition log_base := Eval compute in map (fun i => (Qceiling (Z_of_nat i *255 # 10))) (seq 0 10).
- Definition base := map (fun x => 2 ^ x) log_base.
-
- Lemma base_positive : forall b, In b base -> b > 0.
- Proof.
- compute; intuition; subst; intuition.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x base 0 = 1.
- Proof.
- auto.
- Qed.
-
- Lemma base_good :
- forall i j, (i+j < length base)%nat ->
- let b := nth_default 0 base in
- let r := (b i * b j) / b (i+j)%nat in
- b i * b j = r * b (i+j)%nat.
- Proof.
- twoIndices i j base.
- Qed.
-End Base25Point5_10limbs.
-
-Module Modulus25519 <: PrimeModulus.
- Local Open Scope Z_scope.
- Definition modulus : Z := 2^255 - 19.
- Lemma prime_modulus : prime modulus. Admitted.
-End Modulus25519.
-
-Module F25519Base25Point5Params <: PseudoMersenneBaseParams Base25Point5_10limbs Modulus25519.
- Import Base25Point5_10limbs.
- Import Modulus25519.
- Local Open Scope Z_scope.
- (* TODO: do we actually want B and M "up there" in the type parameters? I was
- * imagining writing something like "Paramter Module M : Modulus". *)
+Local Open Scope Z.
- Definition k := 255.
- Definition c := 19.
- Lemma modulus_pseudomersenne :
- modulus = 2^k - c.
- Proof.
- auto.
- Qed.
+(* BEGIN PseudoMersenneBaseParams instance construction. *)
- Lemma base_matches_modulus :
- forall i j,
- (i < length base)%nat ->
- (j < length base)%nat ->
- (i+j >= length base)%nat ->
- let b := nth_default 0 base in
- let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in
- b i * b j = r * (2^k * b (i+j-length base)%nat).
- Proof.
- twoIndices i j base.
- Qed.
+Definition modulus : Z := 2^255 - 19.
+Lemma prime_modulus : prime modulus. Admitted.
- Lemma base_succ : forall i, ((S i) < length base)%nat ->
- let b := nth_default 0 base in
- b (S i) mod b i = 0.
- Proof.
- intros; twoIndices i (S i) base.
- Qed.
-
- Lemma base_tail_matches_modulus:
- 2^k mod nth_default 0 base (pred (length base)) = 0.
- Proof.
- auto.
- Qed.
-
- Lemma b0_1 : forall x, nth_default x base 0 = 1.
- Proof.
- auto.
- Qed.
-
- Lemma k_nonneg : 0 <= k.
- Proof.
- rewrite Zle_is_le_bool; auto.
- Qed.
-
- Lemma base_range : forall i, 0 <= nth_default 0 log_base i <= k.
- Proof.
- intros i.
- destruct (lt_dec i (length log_base)) as [H|H].
- { repeat (destruct i as [|i]; [cbv; intuition; discriminate|simpl in H; try omega]). }
- { rewrite nth_default_eq, nth_overflow by omega. cbv. intuition; discriminate. }
- Qed.
-
- Lemma base_monotonic: forall i : nat, (i < pred (length log_base))%nat ->
- (0 <= nth_default 0 log_base i <= nth_default 0 log_base (S i)).
- Proof.
- intros i H.
- repeat (destruct i; [cbv; intuition; congruence|]);
- contradict H; cbv; firstorder.
- Qed.
-End F25519Base25Point5Params.
+Instance params25519 : PseudoMersenneBaseParams modulus.
+ construct_params prime_modulus 10%nat 255.
+Defined.
-Module F25519Base25Point5 := PseudoMersenneBase Base25Point5_10limbs Modulus25519 F25519Base25Point5Params.
+Definition mul2modulus := Eval compute in (construct_mul2modulus params25519).
-Section F25519Base25Point5Formula.
- Import F25519Base25Point5 Base25Point5_10limbs F25519Base25Point5 F25519Base25Point5Params.
+Instance subCoeff : SubtractionCoefficient modulus params25519.
+ apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto.
+Defined.
-Definition Z_add_opt := Eval compute in Z.add.
-Definition Z_sub_opt := Eval compute in Z.sub.
-Definition Z_mul_opt := Eval compute in Z.mul.
-Definition Z_div_opt := Eval compute in Z.div.
-Definition Z_pow_opt := Eval compute in Z.pow.
+(* END PseudoMersenneBaseParams instance construction. *)
-Definition nth_default_opt {A} := Eval compute in @nth_default A.
-Definition map_opt {A B} := Eval compute in @map A B.
+(* Precompute k and c *)
+Definition k_ := Eval compute in k.
+Definition c_ := Eval compute in c.
-Ltac opt_step :=
- match goal with
- | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
- => refine (_ : match e with nil => _ | _ => _ end = _);
- destruct e
- end.
+(* Makes Qed not take forever *)
+Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land
+ Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul
+ Let_In digits Z.add Pos.add Z.pos_sub.
-Definition E_mul_bi'_step
- (mul_bi' : nat -> E.digits -> list Z)
- (i : nat) (vsr : E.digits)
- : list Z
- := match vsr with
- | [] => []
- | v :: vsr' => (v * E.crosscoef i (length vsr'))%Z :: mul_bi' i vsr'
- end.
-
-Definition E_mul_bi'_opt_step_sig
- (mul_bi' : nat -> E.digits -> list Z)
- (i : nat) (vsr : E.digits)
- : { l : list Z | l = E_mul_bi'_step mul_bi' i vsr }.
+Local Open Scope nat_scope.
+Lemma GF25519Base25Point5_mul_reduce_formula :
+ forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f
+ -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g
+ -> rep ls (f*g)%F}.
Proof.
- eexists.
- cbv [E_mul_bi'_step].
- opt_step.
- { reflexivity. }
- { cbv [E.crosscoef EC.base Base25Point5_10limbs.base].
- change Z.div with Z_div_opt.
- change Z.pow with Z_pow_opt.
- change Z.mul with Z_mul_opt at 2 3 4 5.
- change @nth_default with @nth_default_opt.
- change @map with @map_opt.
- reflexivity. }
-Defined.
-
-Definition E_mul_bi'_opt_step
- (mul_bi' : nat -> E.digits -> list Z)
- (i : nat) (vsr : E.digits)
- : list Z
- := Eval cbv [proj1_sig E_mul_bi'_opt_step_sig] in
- proj1_sig (E_mul_bi'_opt_step_sig mul_bi' i vsr).
-
-Fixpoint E_mul_bi'_opt
- (i : nat) (vsr : E.digits) {struct vsr}
- : list Z
- := E_mul_bi'_opt_step E_mul_bi'_opt i vsr.
+ eexists; intros ? ? Hf Hg.
+ pose proof (carry_mul_opt_correct k_ c_ (eq_refl k_) (eq_refl c_) [0;9;8;7;6;5;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
+Time Defined.
-Definition E_mul_bi'_opt_correct
- (i : nat) (vsr : E.digits)
- : E_mul_bi'_opt i vsr = E.mul_bi' i vsr.
-Proof.
- revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
- { reflexivity. }
- { simpl E.mul_bi'.
- rewrite <- IHvsr; clear IHvsr.
- unfold E_mul_bi'_opt, E_mul_bi'_opt_step.
- apply f_equal2; [ | reflexivity ].
- cbv [E.crosscoef EC.base Base25Point5_10limbs.base].
- change Z.div with Z_div_opt.
- change Z.pow with Z_pow_opt.
- change Z.mul with Z_mul_opt at 2.
- change @nth_default with @nth_default_opt.
- change @map with @map_opt.
- reflexivity. }
-Qed.
+Extraction "/tmp/test.ml" GF25519Base25Point5_mul_reduce_formula.
+(* It's easy enough to use extraction to get the proper nice-looking formula.
+ * More Ltac acrobatics will be needed to get out that formula for further use in Coq.
+ * The easiest fix will be to make the proof script above fully automated,
+ * using [abstract] to contain the proof part. *)
-Definition E_mul'_step
- (mul' : E.digits -> E.digits -> E.digits)
- (usr vs : E.digits)
- : E.digits
- := match usr with
- | [] => []
- | u :: usr' => E.add (E.mul_each u (E.mul_bi (length usr') vs)) (mul' usr' vs)
- end.
-Definition E_mul'_opt_step_sig
- (mul' : E.digits -> E.digits -> E.digits)
- (usr vs : E.digits)
- : { d : E.digits | d = E_mul'_step mul' usr vs }.
+Lemma GF25519Base25Point5_add_formula :
+ forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f
+ -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g
+ -> rep ls (f + g)%F}.
Proof.
eexists.
- cbv [E_mul'_step].
- match goal with
- | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
- => refine (_ : match e with nil => _ | _ => _ end = _);
- destruct e
- end.
- { reflexivity. }
- { cbv [E.mul_each E.mul_bi].
- rewrite <- E_mul_bi'_opt_correct.
- reflexivity. }
+ intros f g Hf Hg.
+ pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
Defined.
-Definition E_mul'_opt_step
- (mul' : E.digits -> E.digits -> E.digits)
- (usr vs : E.digits)
- : E.digits
- := Eval cbv [proj1_sig E_mul'_opt_step_sig] in proj1_sig (E_mul'_opt_step_sig mul' usr vs).
-
-Fixpoint E_mul'_opt
- (usr vs : E.digits)
- : E.digits
- := E_mul'_opt_step E_mul'_opt usr vs.
-
-Definition E_mul'_opt_correct
- (usr vs : E.digits)
- : E_mul'_opt usr vs = E.mul' usr vs.
-Proof.
- revert vs; induction usr as [|usr usrs IHusr]; intros.
- { reflexivity. }
- { simpl.
- rewrite <- IHusr; clear IHusr.
- apply f_equal2; [ | reflexivity ].
- cbv [E.mul_each E.mul_bi].
- rewrite <- E_mul_bi'_opt_correct.
- reflexivity. }
-Qed.
-
-Definition mul_opt_sig (us vs : T) : { b : B.digits | b = mul us vs }.
+Lemma GF25519Base25Point5_sub_formula :
+ forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9,
+ {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f
+ -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g
+ -> rep ls (f - g)%F}.
Proof.
eexists.
- cbv [mul E.mul E.mul_each E.mul_bi E.mul_bi' E.zeros EC.base reduce].
- rewrite <- E_mul'_opt_correct.
- reflexivity.
+ intros f g Hf Hg.
+ pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg.
+ compute_formula.
Defined.
-Definition mul_opt (us vs : T) : B.digits
- := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).
+Definition F25519Rep := (Z * Z * Z * Z * Z * Z * Z * Z * Z * Z)%type.
-Definition mul_opt_correct us vs
- : mul_opt us vs = mul us vs
- := proj2_sig (mul_opt_sig us vs).
+Definition F25519toRep (x:F (2^255 - 19)) : F25519Rep := (0, 0, 0, 0, 0, 0, 0, 0, 0, FieldToZ x)%Z.
+Definition F25519unRep (rx:F25519Rep) :=
+ let '(x9, x8, x7, x6, x5, x4, x3, x2, x1, x0) := rx in
+ ModularBaseSystem.decode [x0;x1;x2;x3;x4;x5;x6;x7;x8;x9].
-Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R)
- : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b).
-Proof.
- destruct (eq_nat_dec x y) as [H|H];
- [ rewrite (proj2 (@beq_nat_true_iff _ _) H); reflexivity
- | rewrite (proj2 (@beq_nat_false_iff _ _) H); reflexivity ].
-Qed.
-
-Lemma pull_app_if_sumbool {A B X Y} (b : sumbool X Y) (f g : A -> B) (x : A)
- : (if b then f x else g x) = (if b then f else g) x.
-Proof.
- destruct b; reflexivity.
-Qed.
+Global Instance F25519RepConversions : RepConversions (F (2^255 - 19)) F25519Rep :=
+ {
+ toRep := F25519toRep;
+ unRep := F25519unRep
+ }.
-Lemma map_nth_default_always {A B} (f : A -> B) (n : nat) (x : A) (l : list A)
- : nth_default (f x) (map f l) n = f (nth_default x l n).
+Lemma F25519RepConversionsOK : RepConversionsOK F25519RepConversions.
Proof.
- revert n; induction l; simpl; intro n; destruct n; [ try reflexivity.. ].
- nth_tac.
+ unfold F25519RepConversions, RepConversionsOK, unRep, toRep, F25519toRep, F25519unRep; intros.
+ change (ModularBaseSystem.decode (ModularBaseSystem.encode x) = x).
+ eauto using ModularBaseSystemProofs.rep_decode, ModularBaseSystemProofs.encode_rep.
Qed.
-Definition log_cap_opt_sig
- (i : nat)
- : { z : Z | i < length (Base25Point5_10limbs.log_base) -> (2^z)%Z = cap i }.
-Proof.
- eexists.
- cbv [cap Base25Point5_10limbs.base].
- intros.
- rewrite map_length in *.
- erewrite map_nth_default; [|assumption].
- instantiate (2 := 0%Z).
- (** For the division of maps of (2 ^ _) over lists, replace it with 2 ^ (_ - _) *)
-
- lazymatch goal with
- | [ |- _ = (if eq_nat_dec ?a ?b then (2^?x/2^?y)%Z else (nth_default 0 (map (fun x => (2^x)%Z) ?ls) ?si / 2^?d)%Z) ]
- => transitivity (2^if eq_nat_dec a b then (x-y)%Z else nth_default 0 ls si - d)%Z;
- [ apply f_equal | break_if ]
+Definition F25519Rep_mul (f g:F25519Rep) : F25519Rep.
+ refine (
+ let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in
+ let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _).
+ (* FIXME: the r should not be present in generated code *)
+ pose (r := proj1_sig (GF25519Base25Point5_mul_reduce_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9)).
+ simpl in r.
+ unfold F25519Rep.
+ repeat let t' := (eval cbv beta delta [r] in r) in
+ lazymatch t' with Let_In ?arg ?f =>
+ let x := fresh "x" in
+ refine (let x := arg in _);
+ let t'' := (eval cbv beta in (f x)) in
+ change (Let_In arg f) with t'' in r
+ end.
+ let t' := (eval cbv beta delta [r] in r) in
+ lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] =>
+ clear r;
+ exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0)
end.
-
- Focus 2.
- apply Z.pow_sub_r; [clear;firstorder|apply base_range].
- Focus 2.
- erewrite map_nth_default by (omega); instantiate (1 := 0%Z).
- rewrite <- Z.pow_sub_r; [ reflexivity | .. ].
- { clear; abstract firstorder. }
- { apply base_monotonic. omega. }
- Unfocus.
- rewrite <-beq_nat_eq_nat_dec.
- change Z.sub with Z_sub_opt.
- change @nth_default with @nth_default_opt.
- change @map with @map_opt.
- reflexivity.
-Defined.
-
-Definition log_cap_opt (i : nat)
- := Eval cbv [proj1_sig log_cap_opt_sig] in proj1_sig (log_cap_opt_sig i).
-
-Definition log_cap_opt_correct (i : nat)
- : i < length Base25Point5_10limbs.log_base -> (2^log_cap_opt i = cap i)%Z
- := proj2_sig (log_cap_opt_sig i).
+Time Defined.
+
+Lemma F25519_mul_OK : RepBinOpOK F25519RepConversions ModularArithmetic.mul F25519Rep_mul.
+ cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_mul toRep unRep F25519toRep F25519unRep].
+ destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0].
+ destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0].
+ let E := constr:(GF25519Base25Point5_mul_reduce_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in
+ transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]];
+ destruct E as [? r]; cbv [proj1_sig].
+ cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto.
+Qed.
-Definition carry_opt_sig
- (i : nat) (b : B.digits)
- : { d : B.digits | i < length Base25Point5_10limbs.log_base -> d = carry i b }.
-Proof.
- eexists ; intros.
- cbv [carry].
- rewrite <- pull_app_if_sumbool.
- cbv beta delta [carry_and_reduce carry_simple add_to_nth Base25Point5_10limbs.base].
- rewrite map_length.
- repeat lazymatch goal with
- | [ |- context[cap ?i] ]
- => replace (cap i) with (2^log_cap_opt i)%Z by (apply log_cap_opt_correct; assumption)
- end.
- lazymatch goal with
- | [ |- _ = (if ?br then ?c else ?d) ]
- => let x := fresh "x" in let y := fresh "y" in evar (x:B.digits); evar (y:B.digits); transitivity (if br then x else y); subst x; subst y
+Definition F25519Rep_add (f g:F25519Rep) : F25519Rep.
+ refine (
+ let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in
+ let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _).
+ let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_add_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in
+ lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] =>
+ exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0)
end.
- Focus 2.
- cbv zeta.
- break_if;
- rewrite <- Z.land_ones, <- Z.shiftr_div_pow2 by (
- pose proof (base_range i); pose proof (base_monotonic i);
- change @nth_default with @nth_default_opt in *;
- cbv beta delta [log_cap_opt]; rewrite beq_nat_eq_nat_dec; break_if; change Z_sub_opt with Z.sub; omega);
- reflexivity.
- change @nth_default with @nth_default_opt.
- change @map with @map_opt.
- rewrite <- @beq_nat_eq_nat_dec.
- reflexivity.
Defined.
-Definition carry_opt i b
- := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b).
-
-Definition carry_opt_correct i b : i < length Base25Point5_10limbs.log_base -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b).
-
-Definition carry_sequence_opt_sig (is : list nat) (us : B.digits)
- : { b : B.digits | (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> b = carry_sequence is us }.
-Proof.
- eexists. intros H.
- cbv [carry_sequence].
- transitivity (fold_right carry_opt us is).
- Focus 2.
- { induction is; [ reflexivity | ].
- simpl; rewrite IHis, carry_opt_correct.
- - reflexivity.
- - apply H; apply in_eq.
- - intros. apply H. right. auto.
- }
- Unfocus.
- reflexivity.
-Defined.
-
-Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
- proj1_sig (carry_sequence_opt_sig is us).
-
-Definition carry_sequence_opt_correct is us
- : (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> carry_sequence_opt is us = carry_sequence is us
- := proj2_sig (carry_sequence_opt_sig is us).
-
-Definition Let_In {A P} (x : A) (f : forall y : A, P y)
- := let y := x in f y.
-
-Definition carry_opt_cps_sig
- {T}
- (i : nat)
- (f : B.digits -> T)
- (b : B.digits)
- : { d : T | i < length Base25Point5_10limbs.log_base -> d = f (carry i b) }.
-Proof.
- eexists. intros H.
- rewrite <- carry_opt_correct by assumption.
- cbv beta iota delta [carry_opt].
- (* TODO: how to match the goal here? Alternatively, rewrite under let binders in carry_opt_sig and remove cbv zeta and restore original match from jgross's commit *)
- lazymatch goal with [ |- ?LHS = _ ] =>
- change (LHS = Let_In (nth_default_opt 0%Z b i) (fun di => (f (if beq_nat i (pred (length Base25Point5_10limbs.log_base))
- then
- set_nth 0
- (c *
- Z.shiftr (di) (log_cap_opt i) +
- nth_default_opt 0
- (set_nth i (Z.land di (Z.ones (log_cap_opt i)))
- b) 0)%Z
- (set_nth i (Z.land (nth_default_opt 0%Z b i) (Z.ones (log_cap_opt i))) b)
- else
- set_nth (S i)
- (Z.shiftr (di) (log_cap_opt i) +
- nth_default_opt 0
- (set_nth i (Z.land (di) (Z.ones (log_cap_opt i)))
- b) (S i))%Z
- (set_nth i (Z.land (nth_default_opt 0%Z b i) (Z.ones (log_cap_opt i))) b)))))
+Definition F25519Rep_sub (f g:F25519Rep) : F25519Rep.
+ refine (
+ let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in
+ let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _).
+ let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_sub_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
+ g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in
+ lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] =>
+ exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0)
end.
- reflexivity.
Defined.
-Definition carry_opt_cps {T} i f b
- := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).
-
-Definition carry_opt_cps_correct {T} i f b :
- i < length Base25Point5_10limbs.log_base ->
- @carry_opt_cps T i f b = f (carry i b)
- := proj2_sig (carry_opt_cps_sig i f b).
-
-Definition carry_sequence_opt_cps_sig (is : list nat) (us : B.digits)
- : { b : B.digits | (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> b = carry_sequence is us }.
-Proof.
- eexists.
- cbv [carry_sequence].
- transitivity (fold_right carry_opt_cps id (List.rev is) us).
- Focus 2.
- {
- assert (forall i, In i (rev is) -> i < length Base25Point5_10limbs.log_base) as Hr. {
- subst. intros. rewrite <- in_rev in *. auto. }
- remember (rev is) as ris eqn:Heq.
- rewrite <- (rev_involutive is), <- Heq.
- clear H Heq is.
- rewrite fold_left_rev_right.
- revert us; induction ris; [ reflexivity | ]; intros.
- { simpl.
- rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
- rewrite carry_opt_cps_correct; [reflexivity|].
- apply Hr; left; reflexivity.
- } }
- Unfocus.
- reflexivity.
-Defined.
-
-Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
- proj1_sig (carry_sequence_opt_cps_sig is us).
-
-Definition carry_sequence_opt_cps_correct is us
- : (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> carry_sequence_opt_cps is us = carry_sequence is us
- := proj2_sig (carry_sequence_opt_cps_sig is us).
-
-Lemma mul_opt_rep:
- forall (u v : T) (x y : F Modulus25519.modulus), rep u x -> rep v y -> rep (mul_opt u v) (x * y)%F.
-Proof.
- intros.
- rewrite mul_opt_correct.
- auto using mul_rep.
-Qed.
-
-Lemma carry_sequence_opt_cps_rep
- : forall (is : list nat) (us : list Z) (x : F Modulus25519.modulus),
- (forall i : nat, In i is -> i < length Base25Point5_10limbs.base) ->
- length us = length Base25Point5_10limbs.base ->
- rep us x -> rep (carry_sequence_opt_cps is us) x.
-Proof.
- intros.
- rewrite carry_sequence_opt_cps_correct by assumption.
- apply carry_sequence_rep; assumption.
-Qed.
-
-Definition carry_mul_opt
- (is : list nat)
- (us vs : list Z)
- : list Z
- := Eval cbv [B.add
- E.add E.mul E.mul' E.mul_bi E.mul_bi' E.mul_each E.zeros EC.base E_mul'_opt
- E_mul'_opt_step E_mul_bi'_opt E_mul_bi'_opt_step
- List.app List.rev Z_div_opt Z_mul_opt Z_pow_opt
- Z_sub_opt app beq_nat log_cap_opt carry_opt_cps carry_sequence_opt_cps error firstn
- fold_left fold_right id length map map_opt mul mul_opt nth_default nth_default_opt
- nth_error plus pred reduce rev seq set_nth skipn value base] in
- carry_sequence_opt_cps is (mul_opt us vs).
-
-Lemma carry_mul_opt_correct
- : forall (is : list nat) (us vs : list Z) (x y: F Modulus25519.modulus),
- rep us x -> rep vs y ->
- (forall i : nat, In i is -> i < length Base25Point5_10limbs.base) ->
- length (mul_opt us vs) = length base ->
- rep (carry_mul_opt is us vs) (x*y)%F.
-Proof.
- intros is us vs x y; intros.
- change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps is (mul_opt us vs)).
- apply carry_sequence_opt_cps_rep, mul_opt_rep; auto.
+Lemma F25519_add_OK : RepBinOpOK F25519RepConversions ModularArithmetic.add F25519Rep_add.
+ cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_add toRep unRep F25519toRep F25519unRep].
+ destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0].
+ destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0].
+ let E := constr:(GF25519Base25Point5_add_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in
+ transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]];
+ destruct E as [? r]; cbv [proj1_sig].
+ cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto.
Qed.
-
-
- Lemma GF25519Base25Point5_mul_reduce_formula :
- forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9
- g0 g1 g2 g3 g4 g5 g6 g7 g8 g9,
- {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f
- -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g
- -> rep ls (f*g)%F}.
- Proof.
- eexists.
- intros f g Hf Hg.
-
- pose proof (carry_mul_opt_correct [0;9;8;7;6;5;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg.
- forward Hfg; [abstract (clear; cbv; intros; repeat break_or_hyp; intuition)|].
- specialize (Hfg H); clear H.
- forward Hfg; [exact eq_refl|].
- specialize (Hfg H); clear H.
-
- cbv [log_base map k c carry_mul_opt] in Hfg.
- cbv beta iota delta [Let_In] in Hfg.
- rewrite ?Z.mul_0_l, ?Z.mul_0_r, ?Z.mul_1_l, ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_0_r in Hfg.
- rewrite ?Z.mul_assoc, ?Z.add_assoc in Hfg.
- exact Hfg.
- Time Defined.
-End F25519Base25Point5Formula.
-Extraction "/tmp/test.ml" GF25519Base25Point5_mul_reduce_formula.
-(* It's easy enough to use extraction to get the proper nice-looking formula.
- * More Ltac acrobatics will be needed to get out that formula for further use in Coq.
- * The easiest fix will be to make the proof script above fully automated,
- * using [abstract] to contain the proof part. *)
+Lemma F25519_sub_OK : RepBinOpOK F25519RepConversions ModularArithmetic.sub F25519Rep_sub.
+ cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_sub toRep unRep F25519toRep F25519unRep].
+ destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0].
+ destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0].
+ let E := constr:(GF25519Base25Point5_sub_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in
+ transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]];
+ destruct E as [? r]; cbv [proj1_sig].
+ cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto.
+Qed. \ No newline at end of file
diff --git a/src/Testbit.v b/src/Testbit.v
new file mode 100644
index 000000000..264069587
--- /dev/null
+++ b/src/Testbit.v
@@ -0,0 +1,212 @@
+Require Import Coq.Lists.List.
+Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.BaseSystem Crypto.BaseSystemProofs.
+Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv.
+Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith.
+
+Local Open Scope Z.
+
+Definition testbit (limb_width n : nat) (us : list Z) :=
+ Z.testbit (nth_default 0 us (n / limb_width)) (Z.of_nat (n mod limb_width)).
+
+(* identical limb widths *)
+Definition uniform_base (l : list Z) r := forall n d, (n < length l)%nat ->
+ nth n l d = r ^ (Z.of_nat n).
+
+Definition successive_powers (l : list Z) r := forall n d, (S n < length l)%nat ->
+ nth (S n) l d = r * nth n l d.
+
+Fixpoint unfold_bits (limb_width : Z) (us : list Z) :=
+ match us with
+ | nil => 0
+ | u0 :: us' => Z.land u0 (Z.ones limb_width) + Z.shiftl (unfold_bits limb_width us') limb_width
+ end.
+
+Lemma uniform_base_first : forall b0 bs r,
+ uniform_base (b0 :: bs) r -> b0 = 1.
+Proof.
+ boring.
+ match goal with
+ | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H;
+ specialize (H 0%nat 0); simpl in H; eapply H; omega
+ end.
+Qed.
+
+Lemma uniform_base_second : forall b0 b1 bs r,
+ uniform_base (b0 :: b1 :: bs) r -> b1 = r.
+Proof.
+ boring.
+ match goal with
+ | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H;
+ specialize (H 1%nat 0); cbv [nth length] in H;
+ rewrite Z.pow_1_r in H; apply H; omega
+ end.
+Qed.
+
+
+Lemma successive_powers_second : forall b0 b1 bs r,
+ successive_powers (b0 :: b1 :: bs) r -> b1 = r * b0.
+Proof.
+ boring.
+ match goal with
+ | [ H : successive_powers _ _ |- _ ] => unfold uniform_base in H;
+ specialize (H 0%nat 0); cbv [nth length] in H; apply H; omega
+ end.
+Qed.
+
+Ltac uniform_base_subst :=
+ match goal with
+ | [H : uniform_base (?b0 :: ?b1 :: _) _ |- _ ]=>
+ erewrite (uniform_base_first b0); eauto;
+ erewrite (uniform_base_second b0 b1); eauto
+ end.
+
+Lemma successive_powers_tail : forall x0 xs r, successive_powers (x0 :: xs) r ->
+ successive_powers xs r.
+Proof.
+ unfold successive_powers; boring.
+Qed.
+
+Lemma decode_uniform_shift : forall us base limb_width, (S (length us) <= length base)%nat ->
+ successive_powers base (2 ^ limb_width) ->
+ decode base (mul_each (2 ^ limb_width) us) = decode base (0 :: us).
+Proof.
+ unfold decode, decode', accumulate, mul_each;
+ induction us; induction base; try solve [boring].
+ intros; simpl in *.
+ destruct base; [ boring; omega | ].
+ simpl; f_equal.
+ + erewrite (successive_powers_second _ z); eauto.
+ ring.
+ + apply IHus; [ omega | ].
+ eapply successive_powers_tail; eassumption.
+Qed.
+
+Lemma uniform_base_successive_powers : forall xs r, uniform_base xs r ->
+ successive_powers xs r.
+Proof.
+ unfold uniform_base, successive_powers; intros ? ? G n ? ?.
+ do 2 rewrite G by omega.
+ rewrite Nat2Z.inj_succ.
+ apply Z.pow_succ_r.
+ apply Nat2Z.is_nonneg.
+Qed.
+
+Lemma uniform_base_BaseVector : forall base r, (r > 0) -> (0 < length base)%nat ->
+ uniform_base base r -> BaseVector base.
+Proof.
+ unfold uniform_base.
+ intros ? ? r_gt_0 base_nonempty uniform.
+ constructor.
+ + intros b In_b_base.
+ apply In_nth_error_value in In_b_base.
+ destruct In_b_base as [x nth_error_x].
+ pose proof (nth_error_value_length _ _ _ _ nth_error_x) as index_bound.
+ specialize (uniform x 0 index_bound).
+ rewrite <- nth_default_eq in uniform.
+ erewrite nth_error_value_eq_nth_default in uniform; eauto.
+ subst.
+ destruct r; [ | apply pos_pow_nat_pos | pose proof (Zlt_neg_0 p) ] ; omega.
+ + intros.
+ rewrite nth_default_eq.
+ rewrite uniform; auto.
+ + intros.
+ subst b.
+ subst r0.
+ repeat rewrite nth_default_eq.
+ repeat rewrite uniform by omega; auto.
+ rewrite <- Z.pow_add_r by apply Nat2Z.is_nonneg.
+ rewrite Nat2Z.inj_add.
+ rewrite <- Z.pow_sub_r, <- Z.pow_add_r by omega.
+ f_equal.
+ omega.
+Qed.
+
+Definition no_overflow us limb_width := forall n,
+ Z.land (nth_default 0 us n) (Z.ones limb_width) = nth_default 0 us n.
+
+Lemma no_overflow_cons : forall u0 us limb_width,
+ no_overflow (u0 :: us) limb_width -> Z.land u0 (Z.ones limb_width) = u0.
+Proof.
+ unfold no_overflow; intros ? ? ? no_overflow_u0_us.
+ specialize (no_overflow_u0_us 0%nat).
+ rewrite nth_default_cons in no_overflow_u0_us.
+ assumption.
+Qed.
+
+Lemma no_overflow_tail : forall u0 us limb_width,
+ no_overflow (u0 :: us) limb_width -> no_overflow us limb_width.
+Proof.
+ unfold no_overflow; intros.
+ erewrite <- nth_default_cons_S; eauto.
+Qed.
+
+Lemma unfold_bits_decode : forall limb_width us base, (0 <= limb_width) ->
+ (length us <= length base)%nat -> (0 < length base)%nat ->
+ no_overflow us limb_width ->
+ uniform_base base (2 ^ limb_width) ->
+ BaseSystem.decode base us = unfold_bits limb_width us.
+Proof.
+ induction us; boring.
+ rewrite <- (IHus base) by (omega || eauto using no_overflow_tail).
+ rewrite decode_cons by (eapply uniform_base_BaseVector; eauto;
+ rewrite gt_lt_symmetry; apply Z_pow_gt0; omega).
+ simpl.
+ f_equal.
+ + symmetry. eapply no_overflow_cons; eauto.
+ + rewrite Z.shiftl_mul_pow2 by assumption.
+ erewrite <- decode_uniform_shift; eauto using uniform_base_successive_powers.
+ rewrite mul_each_rep.
+ unfold decode.
+ apply Z.mul_comm.
+Qed.
+
+
+Lemma unfold_bits_indexing : forall us i limb_width, (0 < limb_width)%nat ->
+ no_overflow us (Z.of_nat limb_width) ->
+ nth_default 0 us i =
+ Z.land (Z.shiftr (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat (i * limb_width))) (Z.ones (Z.of_nat limb_width)).
+Proof.
+ induction us; intros.
+ + rewrite nth_default_nil.
+ rewrite Z.shiftr_0_l.
+ auto using Z.land_0_l.
+ + destruct i; simpl.
+ - rewrite nth_default_cons.
+ rewrite Z.shiftr_0_r, Z_land_add_land by omega.
+ symmetry; eapply no_overflow_cons; eauto.
+ - rewrite nth_default_cons_S.
+ erewrite IHus; eauto using no_overflow_tail.
+ remember (i * limb_width)%nat as k.
+ rewrite Z_shiftr_add_land by omega.
+ replace (limb_width + k - limb_width)%nat with k by omega.
+ reflexivity.
+Qed.
+
+Lemma unfold_bits_testbit : forall limb_width us n, (0 < limb_width)%nat ->
+ no_overflow us (Z.of_nat limb_width) ->
+ testbit limb_width n us = Z.testbit (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat n).
+Proof.
+ unfold testbit; intros.
+ erewrite unfold_bits_indexing; eauto.
+ rewrite <- Z_testbit_low by
+ (split; try apply Nat2Z.inj_lt; pose proof (mod_bound_pos n limb_width); omega).
+ rewrite Z.shiftr_spec by apply Nat2Z.is_nonneg.
+ f_equal.
+ rewrite <- Nat2Z.inj_add.
+ apply Z2Nat.inj; try apply Nat2Z.is_nonneg.
+ rewrite !Nat2Z.id.
+ symmetry.
+ rewrite Nat.add_comm, Nat.mul_comm.
+ apply Nat.div_mod; omega.
+Qed.
+
+Lemma testbit_spec : forall n us base limb_width, (0 < limb_width)%nat ->
+ (0 < length base)%nat -> (length us <= length base)%nat ->
+ no_overflow us (Z.of_nat limb_width) ->
+ uniform_base base (2 ^ (Z.of_nat limb_width)) ->
+ testbit limb_width n us = Z.testbit (BaseSystem.decode base us) (Z.of_nat n).
+Proof.
+ intros.
+ erewrite unfold_bits_testbit, unfold_bits_decode; eauto; omega.
+Qed. \ No newline at end of file
diff --git a/src/Util/CaseUtil.v b/src/Util/CaseUtil.v
index af04a1e49..cf3ebf29c 100644
--- a/src/Util/CaseUtil.v
+++ b/src/Util/CaseUtil.v
@@ -10,3 +10,9 @@ Ltac case_max :=
(exact (le_Sn_le _ _ (not_le _ _ H)))
end
end.
+
+Lemma pull_app_if_sumbool {A B X Y} (b : sumbool X Y) (f g : A -> B) (x : A)
+ : (if b then f x else g x) = (if b then f else g) x.
+Proof.
+ destruct b; reflexivity.
+Qed.
diff --git a/src/Util/IterAssocOp.v b/src/Util/IterAssocOp.v
index 016a4f7bd..6116312e1 100644
--- a/src/Util/IterAssocOp.v
+++ b/src/Util/IterAssocOp.v
@@ -5,11 +5,15 @@ Local Open Scope equiv_scope.
Generalizable All Variables.
Section IterAssocOp.
Context `{Equivalence T}
+ {scalar : Type}
(op:T->T->T) {op_proper:Proper (equiv==>equiv==>equiv) op}
(assoc: forall a b c, op a (op b c) === op (op a b) c)
(neutral:T)
(neutral_l : forall a, op neutral a === a)
- (neutral_r : forall a, op a neutral === a).
+ (neutral_r : forall a, op a neutral === a)
+ (testbit : scalar -> nat -> bool)
+ (scToN : scalar -> N)
+ (testbit_spec : forall x i, testbit x i = N.testbit_nat (scToN x) i).
Existing Instance op_proper.
Fixpoint nat_iter_op n a :=
@@ -51,19 +55,19 @@ Section IterAssocOp.
| S exp' => f (funexp f a exp')
end.
- Definition test_and_op n a (state : nat * T) :=
+ Definition test_and_op sc a (state : nat * T) :=
let '(i, acc) := state in
let acc2 := op acc acc in
match i with
| O => (0, acc)
- | S i' => (i', if N.testbit_nat n i' then op a acc2 else acc2)
+ | S i' => (i', if testbit sc i' then op a acc2 else acc2)
end.
- Definition iter_op n a : T :=
- snd (funexp (test_and_op n a) (N.size_nat n, neutral) (N.size_nat n)).
+ Definition iter_op sc a bound : T :=
+ snd (funexp (test_and_op sc a) (bound, neutral) bound).
- Definition test_and_op_inv n a (s : nat * T) :=
- snd s === nat_iter_op (N.to_nat (N.shiftr_nat n (fst s))) a.
+ Definition test_and_op_inv sc a (s : nat * T) :=
+ snd s === nat_iter_op (N.to_nat (N.shiftr_nat (scToN sc) (fst s))) a.
Hint Rewrite
N.succ_double_spec
@@ -91,7 +95,7 @@ Section IterAssocOp.
reflexivity.
Qed.
- Lemma shiftr_succ : forall n i,
+ Lemma Nshiftr_succ : forall n i,
N.to_nat (N.shiftr_nat n i) =
if N.testbit_nat n i
then S (2 * N.to_nat (N.shiftr_nat n (S i)))
@@ -115,21 +119,23 @@ Section IterAssocOp.
apply N2Nat.id.
Qed.
- Lemma test_and_op_inv_step : forall n a s,
- test_and_op_inv n a s ->
- test_and_op_inv n a (test_and_op n a s).
+ Lemma test_and_op_inv_step : forall sc a s,
+ test_and_op_inv sc a s ->
+ test_and_op_inv sc a (test_and_op sc a s).
Proof.
destruct s as [i acc].
unfold test_and_op_inv, test_and_op; simpl; intro Hpre.
destruct i; [ apply Hpre | ].
simpl.
- rewrite shiftr_succ.
- case_eq (N.testbit_nat n i); intro; simpl; rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity.
+ rewrite Nshiftr_succ.
+ case_eq (testbit sc i); intro testbit_eq; simpl;
+ rewrite testbit_spec in testbit_eq; rewrite testbit_eq;
+ rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity.
Qed.
- Lemma test_and_op_inv_holds : forall n a i s,
- test_and_op_inv n a s ->
- test_and_op_inv n a (funexp (test_and_op n a) s i).
+ Lemma test_and_op_inv_holds : forall sc a i s,
+ test_and_op_inv sc a s ->
+ test_and_op_inv sc a (funexp (test_and_op sc a) s i).
Proof.
induction i; intros; auto; simpl; apply test_and_op_inv_step; auto.
Qed.
@@ -144,14 +150,14 @@ Section IterAssocOp.
destruct i; rewrite NPeano.Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity.
Qed.
- Lemma iter_op_termination : forall n a,
- test_and_op_inv n a
- (funexp (test_and_op n a) (N.size_nat n, neutral) (N.size_nat n)) ->
- iter_op n a === nat_iter_op (N.to_nat n) a.
+ Lemma iter_op_termination : forall sc a bound,
+ N.size_nat (scToN sc) <= bound ->
+ test_and_op_inv sc a
+ (funexp (test_and_op sc a) (bound, neutral) bound) ->
+ iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a.
Proof.
- unfold test_and_op_inv, iter_op; simpl; intros ? ? Hinv.
+ unfold test_and_op_inv, iter_op; simpl; intros ? ? ? ? Hinv.
rewrite Hinv, funexp_test_and_op_index, Minus.minus_diag.
- replace (N.shiftr_nat n 0) with n by auto.
reflexivity.
Qed.
@@ -160,25 +166,33 @@ Section IterAssocOp.
destruct n; auto; simpl; induction p; simpl; auto; rewrite IHp, Pnat.Pos2Nat.inj_succ; reflexivity.
Qed.
- Lemma Nshiftr_size : forall n, N.shiftr_nat n (N.size_nat n) = 0%N.
+ Lemma Nshiftr_size : forall n bound, N.size_nat n <= bound ->
+ N.shiftr_nat n bound = 0%N.
Proof.
intros.
- rewrite Nsize_nat_equiv.
+ rewrite <- (Nat2N.id bound).
rewrite Nshiftr_nat_equiv.
- destruct (N.eq_dec n 0); subst; auto.
+ destruct (N.eq_dec n 0); subst; [apply N.shiftr_0_l|].
apply N.shiftr_eq_0.
- rewrite N.size_log2 by auto.
- apply N.lt_succ_diag_r.
+ rewrite Nsize_nat_equiv in *.
+ rewrite N.size_log2 in * by auto.
+ apply N.le_succ_l.
+ rewrite <- N.compare_le_iff.
+ rewrite N2Nat.inj_compare.
+ rewrite <- Compare_dec.nat_compare_le.
+ rewrite Nat2N.id.
+ auto.
Qed.
-
- Lemma iter_op_spec : forall n a, iter_op n a === nat_iter_op (N.to_nat n) a.
+
+ Lemma iter_op_spec : forall sc a bound, N.size_nat (scToN sc) <= bound ->
+ iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a.
Proof.
intros.
- apply iter_op_termination.
+ apply iter_op_termination; auto.
apply test_and_op_inv_holds.
unfold test_and_op_inv.
simpl.
- rewrite Nshiftr_size.
+ rewrite Nshiftr_size by auto.
reflexivity.
Qed.
diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v
index 1f9a62457..36d8a3ad3 100644
--- a/src/Util/ListUtil.v
+++ b/src/Util/ListUtil.v
@@ -433,6 +433,22 @@ Proof.
auto.
Qed.
+Lemma nth_default_cons_S : forall {A} us (u0 : A) n d,
+ nth_default d (u0 :: us) (S n) = nth_default d us n.
+Proof.
+ boring.
+Qed.
+
+Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat ->
+ nth_error l i = Some (nth_default x l i).
+Proof.
+ intros ? ? ? ? i_lt_length.
+ destruct (nth_error_length_exists_value _ _ i_lt_length) as [k nth_err_k].
+ unfold nth_default.
+ rewrite nth_err_k.
+ reflexivity.
+Qed.
+
Lemma set_nth_cons : forall {T} (x u0 : T) us, set_nth 0 x (u0 :: us) = x :: us.
Proof.
auto.
@@ -533,3 +549,36 @@ Lemma cons_eq_tail : forall {T} (x y:T) xs ys, x::xs = y::ys -> xs=ys.
Proof.
intros; solve_by_inversion.
Qed.
+
+Lemma map_nth_default_always {A B} (f : A -> B) (n : nat) (x : A) (l : list A)
+ : nth_default (f x) (map f l) n = f (nth_default x l n).
+Proof.
+ revert n; induction l; simpl; intro n; destruct n; [ try reflexivity.. ].
+ nth_tac.
+Qed.
+
+Lemma fold_right_and_True_forall_In_iff : forall {T} (l : list T) (P : T -> Prop),
+ (forall x, In x l -> P x) <-> fold_right and True (map P l).
+Proof.
+ induction l; intros; simpl; try tauto.
+ rewrite <- IHl.
+ intuition (subst; auto).
+Qed.
+
+Lemma fold_right_invariant : forall {A} P (f: A -> A -> A) l x,
+ P x -> (forall y, In y l -> forall z, P z -> P (f y z)) ->
+ P (fold_right f x l).
+Proof.
+ induction l; intros ? ? step; auto.
+ simpl.
+ apply step; try apply in_eq.
+ apply IHl; auto.
+ intros y in_y_l.
+ apply (in_cons a) in in_y_l.
+ auto.
+Qed.
+
+Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l.
+Proof.
+ induction n; destruct l; boring.
+Qed.
diff --git a/src/Util/NatUtil.v b/src/Util/NatUtil.v
index 6a62d6c22..1f69b04d2 100644
--- a/src/Util/NatUtil.v
+++ b/src/Util/NatUtil.v
@@ -56,3 +56,14 @@ Proof.
rewrite <- nat_compare_lt; auto.
}
Qed.
+
+
+
+Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R)
+ : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b).
+Proof.
+ destruct (eq_nat_dec x y) as [H|H];
+ [ rewrite (proj2 (@beq_nat_true_iff _ _) H); reflexivity
+ | rewrite (proj2 (@beq_nat_false_iff _ _) H); reflexivity ].
+Qed.
+
diff --git a/src/Util/Tactics.v b/src/Util/Tactics.v
new file mode 100644
index 000000000..08ebfe13e
--- /dev/null
+++ b/src/Util/Tactics.v
@@ -0,0 +1,25 @@
+(** * Generic Tactics *)
+
+(* [pose proof defn], but only if no hypothesis of the same type exists.
+ most useful for proofs of a proposition *)
+Tactic Notation "unique" "pose" "proof" constr(defn) :=
+ let T := type of defn in
+ match goal with
+ | [ H : T |- _ ] => fail 1
+ | _ => pose proof defn
+ end.
+(* [assert T], but only if no hypothesis of the same type exists.
+ most useful for proofs of a proposition *)
+Tactic Notation "unique" "assert" constr(T) :=
+ match goal with
+ | [ H : T |- _ ] => fail 1
+ | _ => assert T
+ end.
+
+(* [assert T], but only if no hypothesis of the same type exists.
+ most useful for proofs of a proposition *)
+Tactic Notation "unique" "assert" constr(T) "by" tactic3(tac) :=
+ match goal with
+ | [ H : T |- _ ] => fail 1
+ | _ => assert T by tac
+ end.
diff --git a/src/Util/WordUtil.v b/src/Util/WordUtil.v
index 17d04c60a..6a8831b14 100644
--- a/src/Util/WordUtil.v
+++ b/src/Util/WordUtil.v
@@ -1,5 +1,6 @@
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.NatUtil.
Require Import Bedrock.Word.
Local Open Scope nat_scope.
@@ -21,6 +22,39 @@ Proof.
auto.
Qed.
+Lemma wordToN_NToWord_idempotent : forall sz n, (n < Npow2 sz)%N ->
+ wordToN (NToWord sz n) = n.
+Proof.
+ intros.
+ rewrite wordToN_nat, NToWord_nat.
+ rewrite wordToNat_natToWord_idempotent; rewrite Nnat.N2Nat.id; auto.
+Qed.
+
+Lemma NToWord_wordToN : forall sz w, NToWord sz (wordToN w) = w.
+Proof.
+ intros.
+ rewrite wordToN_nat, NToWord_nat, Nnat.Nat2N.id.
+ apply natToWord_wordToNat.
+Qed.
+
+Lemma bound_check_nat_N : forall x n, (Z.to_nat x < 2 ^ n)%nat -> (Z.to_N x < Npow2 n)%N.
+Proof.
+ intros x n bound_nat.
+ rewrite <- Nnat.N2Nat.id, Npow2_nat.
+ replace (Z.to_N x) with (N.of_nat (Z.to_nat x)) by apply Z_nat_N.
+ apply (Nat2N_inj_lt _ (pow2 n)).
+ rewrite pow2_id; assumption.
+Qed.
+
+Lemma weqb_false_iff : forall sz (x y : word sz), weqb x y = false <-> x <> y.
+Proof.
+ split; intros.
+ + intro eq_xy; apply weqb_true_iff in eq_xy; congruence.
+ + case_eq (weqb x y); intros weqb_xy; auto.
+ apply weqb_true_iff in weqb_xy.
+ congruence.
+Qed.
+
Definition wfirstn n {m} (w : Word.word m) {H : n <= m} : Word.word n.
refine (Word.split1 n (m - n) (match _ in _ = N return Word.word N with
| eq_refl => w
diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v
index db3d84b2d..1b7cfafdc 100644
--- a/src/Util/ZUtil.v
+++ b/src/Util/ZUtil.v
@@ -1,6 +1,7 @@
Require Import Coq.ZArith.Zpower Coq.ZArith.Znumtheory Coq.ZArith.ZArith Coq.ZArith.Zdiv.
Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith.
Require Import Crypto.Util.NatUtil.
+Require Import Coq.Lists.List.
Local Open Scope Z.
Lemma gt_lt_symmetry: forall n m, n > m <-> m < n.
@@ -168,6 +169,222 @@ Proof.
intros; omega.
Qed.
+
+Lemma Z_testbit_low : forall n x i, (0 <= i < n) ->
+ Z.testbit x i = Z.testbit (Z.land x (Z.ones n)) i.
+Proof.
+ intros.
+ rewrite Z.land_ones by omega.
+ symmetry.
+ apply Z.mod_pow2_bits_low.
+ omega.
+Qed.
+
+
+Lemma Z_testbit_shiftl : forall i, (0 <= i) -> forall a b n, (i < n) ->
+ Z.testbit (a + Z.shiftl b n) i = Z.testbit a i.
+Proof.
+ intros.
+ erewrite Z_testbit_low; eauto.
+ rewrite Z.land_ones, Z.shiftl_mul_pow2 by omega.
+ rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 n); omega).
+ auto using Z.mod_pow2_bits_low.
+Qed.
+
+Lemma Z_mod_div_eq0 : forall a b, 0 < b -> (a mod b) / b = 0.
+Proof.
+ intros.
+ apply Z.div_small.
+ auto using Z.mod_pos_bound.
+Qed.
+
+Lemma Z_shiftr_add_land : forall n m a b, (n <= m)%nat ->
+ Z.shiftr ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.of_nat m) = Z.shiftr b (Z.of_nat (m - n)).
+Proof.
+ intros.
+ rewrite Z.land_ones by apply Nat2Z.is_nonneg.
+ rewrite !Z.shiftr_div_pow2 by apply Nat2Z.is_nonneg.
+ rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg.
+ rewrite (le_plus_minus n m) at 1 by assumption.
+ rewrite Nat2Z.inj_add.
+ rewrite Z.pow_add_r by apply Nat2Z.is_nonneg.
+ rewrite <- Z.div_div by first
+ [ pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega
+ | apply Z.pow_pos_nonneg; omega ].
+ rewrite Z.div_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega).
+ rewrite Z_mod_div_eq0 by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega); auto.
+Qed.
+
+Lemma Z_land_add_land : forall n m a b, (m <= n)%nat ->
+ Z.land ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.ones (Z.of_nat m)) = Z.land a (Z.ones (Z.of_nat m)).
+Proof.
+ intros.
+ rewrite !Z.land_ones by apply Nat2Z.is_nonneg.
+ rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg.
+ replace (b * 2 ^ Z.of_nat n) with
+ ((b * 2 ^ Z.of_nat (n - m)) * 2 ^ Z.of_nat m) by
+ (rewrite (le_plus_minus m n) at 2; try assumption;
+ rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg; ring).
+ rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat m)); omega).
+ symmetry. apply Znumtheory.Zmod_div_mod; try (apply Z.pow_pos_nonneg; omega).
+ rewrite (le_plus_minus m n) by assumption.
+ rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg.
+ apply Z.divide_factor_l.
+Qed.
+
+Lemma Z_pow_gt0 : forall a, 0 < a -> forall b, 0 <= b -> 0 < a ^ b.
+Proof.
+ intros until 1.
+ apply natlike_ind; try (simpl; omega).
+ intros.
+ rewrite Z.pow_succ_r by assumption.
+ apply Z.mul_pos_pos; assumption.
+Qed.
+
+Lemma div_pow2succ : forall n x, (0 <= x) ->
+ n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x).
+Proof.
+ intros.
+ rewrite Z.pow_succ_r, Z.mul_comm by auto.
+ rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega).
+ rewrite Zdiv2_div.
+ reflexivity.
+Qed.
+
+Lemma shiftr_succ : forall n x,
+ Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1.
+Proof.
+ intros.
+ rewrite Z.shiftr_shiftr by omega.
+ reflexivity.
+Qed.
+
+
+Definition Z_shiftl_by n a := Z.shiftl a n.
+
+Lemma Z_shiftl_by_mul_pow2 : forall n a, 0 <= n -> Z.mul (2 ^ n) a = Z_shiftl_by n a.
+Proof.
+ intros.
+ unfold Z_shiftl_by.
+ rewrite Z.shiftl_mul_pow2 by assumption.
+ apply Z.mul_comm.
+Qed.
+
+Lemma map_shiftl : forall n l, 0 <= n -> map (Z.mul (2 ^ n)) l = map (Z_shiftl_by n) l.
+Proof.
+ intros; induction l; auto using Z_shiftl_by_mul_pow2.
+ simpl.
+ rewrite IHl.
+ f_equal.
+ apply Z_shiftl_by_mul_pow2.
+ assumption.
+Qed.
+
+Lemma Z_odd_mod : forall a b, (b <> 0)%Z ->
+ Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a.
+Proof.
+ intros.
+ rewrite Zmod_eq_full by assumption.
+ rewrite <-Z.add_opp_r, Z.odd_add, Z.odd_opp, Z.odd_mul.
+ case_eq (Z.odd b); intros; rewrite ?Bool.andb_true_r, ?Bool.andb_false_r; auto using Bool.xorb_false_r.
+Qed.
+
+Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0.
+Proof.
+ intros.
+ replace b with (b - c + c) by ring.
+ rewrite Z.pow_add_r by omega.
+ apply Z_mod_mult.
+Qed.
+
+ Lemma Z_ones_succ : forall x, (0 <= x) ->
+ Z.ones (Z.succ x) = 2 ^ x + Z.ones x.
+ Proof.
+ unfold Z.ones; intros.
+ rewrite !Z.shiftl_1_l.
+ rewrite Z.add_pred_r.
+ apply Z.succ_inj.
+ rewrite !Z.succ_pred.
+ rewrite Z.pow_succ_r; omega.
+ Qed.
+
+ Lemma Z_div_floor : forall a b c, 0 < b -> a < b * (Z.succ c) -> a / b <= c.
+ Proof.
+ intros.
+ apply Z.lt_succ_r.
+ apply Z.div_lt_upper_bound; try omega.
+ Qed.
+
+ Lemma Z_shiftr_1_r_le : forall a b, a <= b ->
+ Z.shiftr a 1 <= Z.shiftr b 1.
+ Proof.
+ intros.
+ rewrite !Z.shiftr_div_pow2, Z.pow_1_r by omega.
+ apply Z.div_le_mono; omega.
+ Qed.
+
+ Lemma Z_ones_pred : forall i, 0 < i -> Z.ones (Z.pred i) = Z.shiftr (Z.ones i) 1.
+ Proof.
+ induction i; [ | | pose proof (Pos2Z.neg_is_neg p) ]; try omega.
+ intros.
+ unfold Z.ones.
+ rewrite !Z.shiftl_1_l, Z.shiftr_div_pow2, <-!Z.sub_1_r, Z.pow_1_r, <-!Z.add_opp_r by omega.
+ replace (2 ^ (Z.pos p)) with (2 ^ (Z.pos p - 1)* 2).
+ rewrite Z.div_add_l by omega.
+ reflexivity.
+ replace 2 with (2 ^ 1) at 2 by auto.
+ rewrite <-Z.pow_add_r by (pose proof (Pos2Z.is_pos p); omega).
+ f_equal. omega.
+ Qed.
+
+ Lemma Z_shiftr_ones' : forall a n, 0 <= a < 2 ^ n -> forall i, (0 <= i) ->
+ Z.shiftr a i <= Z.ones (n - i) \/ n <= i.
+ Proof.
+ intros until 1.
+ apply natlike_ind.
+ + unfold Z.ones.
+ rewrite Z.shiftr_0_r, Z.shiftl_1_l, Z.sub_0_r.
+ omega.
+ + intros.
+ destruct (Z_lt_le_dec x n); try omega.
+ intuition.
+ left.
+ rewrite shiftr_succ.
+ replace (n - Z.succ x) with (Z.pred (n - x)) by omega.
+ rewrite Z_ones_pred by omega.
+ apply Z_shiftr_1_r_le.
+ assumption.
+ Qed.
+
+ Lemma Z_shiftr_ones : forall a n i, 0 <= a < 2 ^ n -> (0 <= i) -> (i <= n) ->
+ Z.shiftr a i <= Z.ones (n - i) .
+ Proof.
+ intros a n i G G0 G1.
+ destruct (Z_le_lt_eq_dec i n G1).
+ + destruct (Z_shiftr_ones' a n G i G0); omega.
+ + subst; rewrite Z.sub_diag.
+ destruct (Z_eq_dec a 0).
+ - subst; rewrite Z.shiftr_0_l; reflexivity.
+ - rewrite Z.shiftr_eq_0; try omega; try reflexivity.
+ apply Z.log2_lt_pow2; omega.
+ Qed.
+
+ Lemma Z_shiftr_upper_bound : forall a n, 0 <= n -> 0 <= a <= 2 ^ n -> Z.shiftr a n <= 1.
+ Proof.
+ intros a ? ? [a_nonneg a_upper_bound].
+ apply Z_le_lt_eq_dec in a_upper_bound.
+ destruct a_upper_bound.
+ + destruct (Z_eq_dec 0 a).
+ - subst; rewrite Z.shiftr_0_l; omega.
+ - rewrite Z.shiftr_eq_0; auto; try omega.
+ apply Z.log2_lt_pow2; auto; omega.
+ + subst.
+ rewrite Z.shiftr_div_pow2 by assumption.
+ rewrite Z.div_same; try omega.
+ assert (0 < 2 ^ n) by (apply Z.pow_pos_nonneg; omega).
+ omega.
+ Qed.
+
(* prove that known nonnegative numbers are nonnegative *)
Ltac zero_bounds' :=
repeat match goal with
@@ -188,3 +405,14 @@ Ltac zero_bounds' :=
end; try omega; try prime_bound; auto.
Ltac zero_bounds := try omega; try prime_bound; zero_bounds'.
+
+ Lemma Z_ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i.
+ Proof.
+ apply natlike_ind.
+ + unfold Z.ones. simpl; omega.
+ + intros.
+ rewrite Z_ones_succ by assumption.
+ zero_bounds.
+ apply Z.pow_nonneg; omega.
+ Qed.
+
diff --git a/to_gallina.md b/to_gallina.md
new file mode 100644
index 000000000..1ac5075ef
--- /dev/null
+++ b/to_gallina.md
@@ -0,0 +1,7 @@
+Remaining work needed for Gallina verification code
+---------------------------------------------------
++ efficient GF exponentiation
++ efficient GF inverse
++ make EdDSA point addition use ModularBaseSystem
++ represent scalars (Fl) in ModularBaseSystem (large c)
++ canonical representations of field elements