diff options
42 files changed, 5414 insertions, 2366 deletions
diff --git a/_CoqProject b/_CoqProject index d4a8857a1..bb92c783a 100644 --- a/_CoqProject +++ b/_CoqProject @@ -1,26 +1,40 @@ -R src Crypto src/BaseSystem.v -src/BoundedIterOp.v +src/BaseSystemProofs.v src/EdDSAProofs.v +src/Rep.v +src/Testbit.v src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v src/CompleteEdwardsCurve/DoubleAndAdd.v src/CompleteEdwardsCurve/ExtendedCoordinates.v src/CompleteEdwardsCurve/Pre.v src/Encoding/EncodingTheorems.v +src/Encoding/ModularWordEncodingPre.v +src/Encoding/ModularWordEncodingTheorems.v +src/Encoding/PointEncodingPre.v +src/Encoding/PointEncodingTheorems.v +src/ModularArithmetic/ExtendedBaseVector.v src/ModularArithmetic/FField.v src/ModularArithmetic/FNsatz.v src/ModularArithmetic/ModularArithmeticTheorems.v src/ModularArithmetic/ModularBaseSystem.v +src/ModularArithmetic/ModularBaseSystemOpt.v +src/ModularArithmetic/ModularBaseSystemProofs.v src/ModularArithmetic/Pre.v src/ModularArithmetic/PrimeFieldTheorems.v +src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +src/ModularArithmetic/PseudoMersenneBaseParams.v +src/ModularArithmetic/PseudoMersenneBaseRep.v src/ModularArithmetic/Tutorial.v src/Spec/CompleteEdwardsCurve.v src/Spec/Ed25519.v src/Spec/EdDSA.v src/Spec/Encoding.v src/Spec/ModularArithmetic.v +src/Spec/ModularWordEncoding.v src/Spec/PointEncoding.v src/Specific/Ed25519.v +src/Specific/GF1305.v src/Specific/GF25519.v src/Tactics/VerdiTactics.v src/Util/CaseUtil.v @@ -28,6 +42,7 @@ src/Util/IterAssocOp.v src/Util/ListUtil.v src/Util/NatUtil.v src/Util/NumTheoryUtil.v +src/Util/Tactics.v src/Util/WordUtil.v src/Util/ZUtil.v src/Assembly/QhasmCommon.v diff --git a/etc/freshen-bedrock-files.sh b/etc/freshen-bedrock-files.sh index 08e0435d7..c2daf7e87 100755 --- a/etc/freshen-bedrock-files.sh +++ b/etc/freshen-bedrock-files.sh @@ -42,10 +42,10 @@ for VFILE in $FILES; do break fi done - ) if [ -z "$FOUND" ]; then echo "WARNING: Could not find $VOFILE, which $VFILE depends on" fi + ) done done diff --git a/src/BaseSystem.v b/src/BaseSystem.v index 4e07c4564..e6ad55f18 100644 --- a/src/BaseSystem.v +++ b/src/BaseSystem.v @@ -5,19 +5,18 @@ Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. Local Open Scope Z. -Module Type BaseCoefs. - (** [BaseCoefs] represent the weights of each digit in a positional number system, with the weight of least significant digit presented first. The following requirements on the base are preconditions for using it with BaseSystem. *) - Parameter base : list Z. - Axiom base_positive : forall b, In b base -> b > 0. (* nonzero would probably work too... *) - Axiom b0_1 : forall x, nth_default x base 0 = 1. - Axiom base_good : +Class BaseVector (base : list Z):= { + base_positive : forall b, In b base -> b > 0; (* nonzero would probably work too... *) + b0_1 : forall x, nth_default x base 0 = 1; + base_good : forall i j, (i+j < length base)%nat -> let b := nth_default 0 base in let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. -End BaseCoefs. + b i * b j = r * b (i+j)%nat +}. -Module BaseSystem (Import B:BaseCoefs). +Section BaseSystem. + Context (base : list Z). (** [BaseSystem] implements an constrained positional number system. A wide variety of bases are supported: the base coefficients are not required to be powers of 2, and it is NOT necessarily the case that @@ -33,7 +32,6 @@ Module BaseSystem (Import B:BaseCoefs). Definition accumulate p acc := fst p * snd p + acc. Definition decode' bs u := fold_right accumulate 0 (combine u bs). Definition decode := decode' base. - Hint Unfold accumulate. (* Does not carry; z becomes the lowest and only digit. *) Definition encode (z : Z) := z :: nil. @@ -52,50 +50,7 @@ Module BaseSystem (Import B:BaseCoefs). Hint Extern 1 (@eq Z _ _) => ring. - Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs. - Proof. - unfold decode, decode'; induction bs; destruct us; destruct vs; boring. - Qed. - - Lemma decode_nil : forall bs, decode' bs nil = 0. - auto. - Qed. - Hint Rewrite decode_nil. - - Lemma decode_base_nil : forall us, decode' nil us = 0. - Proof. - intros; rewrite decode'_truncate; auto. - Qed. - Hint Rewrite decode_base_nil. - Definition mul_each u := map (Z.mul u). - Lemma mul_each_rep : forall bs u vs, - decode' bs (mul_each u vs) = u * decode' bs vs. - Proof. - unfold decode'; induction bs; destruct vs; boring. - Qed. - - Lemma base_eq_1cons: base = 1 :: skipn 1 base. - Proof. - pose proof (b0_1 0) as H. - destruct base; compute in H; try discriminate; boring. - Qed. - - Lemma decode'_cons : forall x1 x2 xs1 xs2, - decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2. - Proof. - unfold decode'; boring. - Qed. - Hint Rewrite decode'_cons. - - Lemma decode_cons : forall x us, - decode (x :: us) = x + decode (0 :: us). - Proof. - unfold decode; intros. - rewrite base_eq_1cons. - autorewrite with core; ring_simplify; auto. - Qed. - Fixpoint sub (us vs:digits) : digits := match us,vs with | u::us', v::vs' => u-v :: sub us' vs' @@ -103,76 +58,13 @@ Module BaseSystem (Import B:BaseCoefs). | nil, v::vs' => (0-v)::sub nil vs' end. - Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs. - Proof. - induction bs; destruct us; destruct vs; boring. - Qed. - - Lemma encode_rep : forall z, decode (encode z) = z. - Proof. - pose proof base_eq_1cons. - unfold decode, encode; destruct z; boring. - Qed. - - Lemma mul_each_base : forall us bs c, - decode' bs (mul_each c us) = decode' (mul_each c bs) us. - Proof. - induction us; destruct bs; boring. - Qed. - - Hint Rewrite (@nth_default_nil Z). - Hint Rewrite (@firstn_nil Z). - Hint Rewrite (@skipn_nil Z). - - Lemma base_app : forall us low high, - decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us). - Proof. - induction us; destruct low; boring. - Qed. - - Lemma base_mul_app : forall low c us, - decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + - c * decode' low (skipn (length low) us). - Proof. - intros. - rewrite base_app; f_equal. - rewrite <- mul_each_rep. - rewrite mul_each_base. - reflexivity. - Qed. - Definition crosscoef i j : Z := let b := nth_default 0 base in (b(i) * b(j)) / b(i+j)%nat. Hint Unfold crosscoef. Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end. - Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0. - induction bs; destruct n; boring. - Qed. - Lemma length_zeros : forall n, length (zeros n) = n. - induction n; boring. - Qed. - Hint Rewrite length_zeros. - - Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m). - Proof. - induction n; boring. - Qed. - Hint Rewrite app_zeros_zeros. - - Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m). - Proof. - induction m; boring. - Qed. - Hint Rewrite zeros_app0. - - Lemma rev_zeros : forall n, rev (zeros n) = zeros n. - Proof. - induction n; boring. - Qed. - Hint Rewrite rev_zeros. - + (* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *) Fixpoint mul_bi' (i:nat) (vsr:digits) := match vsr with @@ -181,311 +73,7 @@ Module BaseSystem (Import B:BaseCoefs). end. Definition mul_bi (i:nat) (vs:digits) : digits := zeros i ++ rev (mul_bi' i (rev vs)). - - Hint Unfold nth_default. - - Lemma decode_single : forall n bs x, - decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. - Proof. - induction n; destruct bs; boring. - Qed. - Hint Rewrite decode_single. - - Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. - Proof. - boring. - Qed. - Hint Rewrite zeros_rep peel_decode. - - Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs. - Proof. - induction xs; destruct bs; boring. - Qed. - - Lemma mul_bi'_zeros : forall n m, mul_bi' n (zeros m) = zeros m. - induction m; boring. - Qed. - Hint Rewrite mul_bi'_zeros. - - Lemma nth_error_base_nonzero : forall n x, - nth_error base n = Some x -> x <> 0. - Proof. - eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive. - Qed. - - Hint Rewrite plus_0_r. - - Lemma mul_bi_single : forall m n x, - (n + m < length base)%nat -> - decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. - Proof. - unfold mul_bi, decode. - destruct m; simpl; simpl_list; simpl; intros. { - pose proof nth_error_base_nonzero. - boring; destruct base; nth_tac. - rewrite Z_div_mul'; eauto. - } { - ssimpl_list. - autorewrite with core. - rewrite app_assoc. - autorewrite with core. - unfold crosscoef; simpl; ring_simplify. - rewrite Nat.add_1_r. - rewrite base_good by auto. - rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto). - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_comm. - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_assoc. - destruct (Z.eq_dec x 0); subst; try ring. - rewrite Z.mul_cancel_l by auto. - rewrite <- base_good by auto. - ring. - } - Qed. - - Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil). - induction vs; boring; f_equal; ring. - Qed. - - Lemma set_higher : forall bs vs x, - decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. - Proof. - intros. - rewrite set_higher'. - rewrite add_rep. - f_equal. - apply decode_single. - Qed. - - Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n. - induction n; auto. - simpl; f_equal; auto. - Qed. - - Lemma mul_bi'_n_nil : forall n, mul_bi' n nil = nil. - Proof. - unfold mul_bi; auto. - Qed. - Hint Rewrite mul_bi'_n_nil. - - Lemma add_nil_l : forall us, nil .+ us = us. - induction us; auto. - Qed. - Hint Rewrite add_nil_l. - - Lemma add_nil_r : forall us, us .+ nil = us. - induction us; auto. - Qed. - Hint Rewrite add_nil_r. - Lemma add_first_terms : forall us vs a b, - (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs). - auto. - Qed. - Hint Rewrite add_first_terms. - - Lemma mul_bi'_cons : forall n x us, - mul_bi' n (x :: us) = x * crosscoef n (length us) :: mul_bi' n us. - Proof. - unfold mul_bi'; auto. - Qed. - - Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) -> - length (us .+ vs) = l. - Proof. - induction us, vs; boring. - erewrite (IHus vs (pred l)); boring. - Qed. - - Hint Rewrite app_nil_l. - Hint Rewrite app_nil_r. - - Lemma add_snoc_same_length : forall l us vs a b, - (length us = l) -> (length vs = l) -> - (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil. - Proof. - induction l, us, vs; boring; discriminate. - Qed. - - Lemma mul_bi'_add : forall us n vs l - (Hlus: length us = l) - (Hlvs: length vs = l), - mul_bi' n (rev (us .+ vs)) = - mul_bi' n (rev us) .+ mul_bi' n (rev vs). - Proof. - (* TODO(adamc): please help prettify this *) - induction us using rev_ind; - try solve [destruct vs; boring; congruence]. - destruct vs using rev_ind; boring; clear IHvs; simpl_list. - erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list. - repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list. - rewrite (IHus n vs (pred l)). - replace (length us) with (pred l). - replace (length vs) with (pred l). - rewrite (add_same_length us vs (pred l)). - f_equal; ring. - - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - erewrite length_snoc; eauto. - Qed. - - Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n). - auto. - Qed. - - Lemma add_leading_zeros : forall n us vs, - (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs). - Proof. - induction n; boring. - Qed. - - Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) -> - (rev us) .+ (rev vs) = rev (us .+ vs). - Proof. - induction us, vs; boring; try solve [subst; discriminate]. - rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega). - rewrite (IHus vs (pred l)) by omega; auto. - Qed. - Hint Rewrite rev_add_rev. - - Lemma mul_bi'_length : forall us n, length (mul_bi' n us) = length us. - Proof. - induction us, n; boring. - Qed. - Hint Rewrite mul_bi'_length. - - Lemma add_comm : forall us vs, us .+ vs = vs .+ us. - Proof. - induction us, vs; boring; f_equal; auto. - Qed. - - Hint Rewrite rev_length. - - Lemma mul_bi_add_same_length : forall n us vs l, - (length us = l) -> (length vs = l) -> - mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs. - Proof. - unfold mul_bi; boring. - rewrite add_leading_zeros. - erewrite mul_bi'_add; boring. - erewrite rev_add_rev; boring. - Qed. - - Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us. - Proof. - induction us; boring; f_equal; omega. - Qed. - - Hint Rewrite add_zeros_same_length. - Hint Rewrite minus_diag. - - Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat -> - us .+ vs = us .+ (vs ++ (zeros (length us - length vs))). - Proof. - induction us, vs; boring; f_equal; boring. - Qed. - - Lemma length_add_ge : forall us vs, (length us >= length vs)%nat -> - (length (us .+ vs) <= length us)%nat. - Proof. - intros. - rewrite add_trailing_zeros by trivial. - erewrite add_same_length by (pose proof app_length; boring); omega. - Qed. - - Lemma add_length_le_max : forall us vs, - (length (us .+ vs) <= max (length us) (length vs))%nat. - Proof. - intros; case_max; (rewrite add_comm; apply length_add_ge; omega) || - (apply length_add_ge; omega) . - Qed. - - Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us. - Proof. - induction us; boring. - Qed. - - Lemma sub_length_le_max : forall us vs, - (length (sub us vs) <= max (length us) (length vs))%nat. - Proof. - induction us, vs; boring. - rewrite sub_nil_length; auto. - Qed. - - Lemma mul_bi_length : forall us n, length (mul_bi n us) = (length us + n)%nat. - Proof. - pose proof mul_bi'_length; unfold mul_bi. - destruct us; repeat progress (simpl_list; boring). - Qed. - Hint Rewrite mul_bi_length. - - Lemma mul_bi_trailing_zeros : forall m n us, - mul_bi n us ++ zeros m = mul_bi n (us ++ zeros m). - Proof. - unfold mul_bi. - induction m; intros; try solve [boring]. - rewrite <- zeros_app0. - rewrite app_assoc. - repeat progress (boring; rewrite rev_app_distr). - Qed. - - Lemma mul_bi_add_longer : forall n us vs, - (length us >= length vs)%nat -> - mul_bi n (us .+ vs) = mul_bi n us .+ mul_bi n vs. - Proof. - boring. - rewrite add_trailing_zeros by auto. - rewrite (add_trailing_zeros (mul_bi n us) (mul_bi n vs)) - by (repeat (rewrite mul_bi_length); omega). - erewrite mul_bi_add_same_length by - (eauto; simpl_list; rewrite length_zeros; omega). - rewrite mul_bi_trailing_zeros. - repeat (f_equal; boring). - Qed. - - Lemma mul_bi_add : forall n us vs, - mul_bi n (us .+ vs) = (mul_bi n us) .+ (mul_bi n vs). - Proof. - intros; pose proof mul_bi_add_longer. - destruct (le_ge_dec (length us) (length vs)). { - replace (mul_bi n us .+ mul_bi n vs) - with (mul_bi n vs .+ mul_bi n us) - by (apply add_comm). - replace (us .+ vs) - with (vs .+ us) - by (apply add_comm). - boring. - } { - boring. - } - Qed. - - Lemma mul_bi_rep : forall i vs, - (i + length vs < length base)%nat -> - decode (mul_bi i vs) = decode vs * nth_default 0 base i. - Proof. - unfold decode. - induction vs using rev_ind; intros; try solve [unfold mul_bi; boring]. - assert (i + length vs < length base)%nat by - (rewrite app_length in *; boring). - - rewrite set_higher. - ring_simplify. - rewrite <- IHvs by auto; clear IHvs. - rewrite <- mul_bi_single by auto. - rewrite <- add_rep. - rewrite <- mul_bi_add. - rewrite set_higher'. - auto. - Qed. - (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) Fixpoint mul' (usr vs:digits) : digits := match usr with @@ -495,66 +83,17 @@ Module BaseSystem (Import B:BaseCoefs). end. Definition mul us := mul' (rev us). - Lemma mul'_rep : forall us vs, - (length us + length vs <= length base)%nat -> - decode (mul' (rev us) vs) = decode us * decode vs. - Proof. - unfold decode. - induction us using rev_ind; boring. - - assert (length us + length vs < length base)%nat by - (rewrite app_length in *; boring). - - ssimpl_list. - rewrite add_rep. - boring. - rewrite set_higher. - rewrite mul_each_rep. - rewrite mul_bi_rep by auto. - unfold decode; ring. - Qed. - - Lemma mul_rep : forall us vs, - (length us + length vs <= length base)%nat -> - decode (mul us vs) = decode us * decode vs. - Proof. - exact mul'_rep. - Qed. - - Lemma mul'_length: forall us vs, - (length (mul' us vs) <= length us + length vs)%nat. - Proof. - pose proof add_length_le_max. - induction us; boring. - unfold mul_each. - simpl_list; case_max; boring; omega. - Qed. - - Lemma mul_length: forall us vs, - (length (mul us vs) <= length us + length vs)%nat. - Proof. - intros; unfold mul. - rewrite mul'_length. - rewrite rev_length; omega. - Qed. - -(* Print Assumptions mul_rep. *) - End BaseSystem. -Module Type PolynomialBaseParams. - Parameter b1 : positive. (* the value at which the polynomial is evaluated *) - Parameter baseLength : nat. (* 1 + degree of the polynomial *) - Axiom baseLengthNonzero : NPeano.ltb 0 baseLength = true. -End PolynomialBaseParams. - -Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. - (** PolynomialBaseCoeffs generates base vectors for [BaseSystem] using the extra assumption that $b_{i+j} = b_j b_j$. *) +(* Example : polynomial base system *) +Section PolynomialBaseCoefs. + Context (b1 : positive) (baseLength : nat) (baseLengthNonzero : NPeano.ltb 0 baseLength = true). + (** PolynomialBaseCoefs generates base vectors for [BaseSystem]. *) Definition bi i := (Zpos b1)^(Z.of_nat i). - Definition base := map bi (seq 0 baseLength). + Definition poly_base := map bi (seq 0 baseLength). - Lemma b0_1 : forall x, nth_default x base 0 = 1. - unfold base, bi, nth_default. + Lemma poly_b0_1 : forall x, nth_default x poly_base 0 = 1. + unfold poly_base, bi, nth_default. case_eq baseLength; intros. { assert ((0 < baseLength)%nat) by (rewrite <-NPeano.ltb_lt; apply baseLengthNonzero). @@ -563,9 +102,9 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. auto. Qed. - Lemma base_positive : forall b, In b base -> b > 0. + Lemma poly_base_positive : forall b, In b poly_base -> b > 0. Proof. - unfold base. + unfold poly_base. intros until 0; intro H. rewrite in_map_iff in *. destruct H; destruct H. @@ -573,20 +112,20 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. apply pos_pow_nat_pos. Qed. - Lemma base_defn : forall i, (i < length base)%nat -> - nth_default 0 base i = bi i. + Lemma poly_base_defn : forall i, (i < length poly_base)%nat -> + nth_default 0 poly_base i = bi i. Proof. - unfold base, nth_default; nth_tac. + unfold poly_base, nth_default; nth_tac. Qed. - Lemma base_succ : - forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in + Lemma poly_base_succ : + forall i, ((S i) < length poly_base)%nat -> + let b := nth_default 0 poly_base in let r := (b (S i) / b i) in b (S i) = r * b i. Proof. intros; subst b; subst r. - repeat rewrite base_defn in * by omega. + repeat rewrite poly_base_defn in * by omega. unfold bi. replace (Z.pos b1 ^ Z.of_nat (S i)) with (Z.pos b1 * (Z.pos b1 ^ Z.of_nat i)) by @@ -598,13 +137,13 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. pose proof (Zgt_pos_0 b1); omega. Qed. - Lemma base_good: - forall i j, (i + j < length base)%nat -> - let b := nth_default 0 base in + Lemma poly_base_good: + forall i j, (i + j < length poly_base)%nat -> + let b := nth_default 0 poly_base in let r := (b i * b j) / b (i+j)%nat in b i * b j = r * b (i+j)%nat. Proof. - unfold base, nth_default; nth_tac. + unfold poly_base, nth_default; nth_tac. clear. unfold bi. @@ -614,24 +153,25 @@ Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. rewrite <- Z.neq_mul_0. split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H]. Qed. + + Instance PolyBaseVector : BaseVector poly_base := { + base_positive := poly_base_positive; + b0_1 := poly_b0_1; + base_good := poly_base_good + }. + End PolynomialBaseCoefs. -Module BasePoly2Degree32Params <: PolynomialBaseParams. - Definition b1 := 2%positive. +Import ListNotations. + +Section BaseSystemExample. Definition baseLength := 32%nat. Lemma baseLengthNonzero : NPeano.ltb 0 baseLength = true. compute; reflexivity. Qed. -End BasePoly2Degree32Params. - -Import ListNotations. - -Module BaseSystemExample. - Module BasePoly2Degree32Coefs := PolynomialBaseCoefs BasePoly2Degree32Params. - Module BasePoly2Degree32 := BaseSystem BasePoly2Degree32Coefs. - Import BasePoly2Degree32. + Definition base2 := poly_base 2 baseLength. - Example three_times_two : mul [1;1;0] [0;1;0] = [0;1;1;0;0]. + Example three_times_two : mul base2 [1;1;0] [0;1;0] = [0;1;1;0;0]. Proof. reflexivity. Qed. @@ -639,20 +179,20 @@ Module BaseSystemExample. (* python -c "e = lambda x: '['+''.join(reversed(bin(x)[2:])).replace('1','1;').replace('0','0;')[:-1]+']'; print(e(19259)); print(e(41781))" *) Definition a := [1;1;0;1;1;1;0;0;1;1;0;1;0;0;1]. Definition b := [1;0;1;0;1;1;0;0;1;1;0;0;0;1;0;1]. - Example da : decode a = 19259. + Example da : decode base2 a = 19259. Proof. reflexivity. Qed. - Example db : decode b = 41781. + Example db : decode base2 b = 41781. Proof. reflexivity. Qed. Example encoded_ab : - mul a b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1]. + mul base2 a b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1]. Proof. reflexivity. Qed. - Example dab : decode (mul a b) = 804660279. + Example dab : decode base2 (mul base2 a b) = 804660279. Proof. reflexivity. Qed. diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v new file mode 100644 index 000000000..ab56cb711 --- /dev/null +++ b/src/BaseSystemProofs.v @@ -0,0 +1,503 @@ +Require Import List. +Require Import Util.ListUtil Util.CaseUtil Util.ZUtil. +Require Import ZArith.ZArith ZArith.Zdiv. +Require Import Omega NPeano Arith. +Require Import Crypto.BaseSystem. +Local Open Scope Z. + +Local Infix ".+" := add (at level 50). + +Local Hint Extern 1 (@eq Z _ _) => ring. + +Section BaseSystemProofs. + Context `(base_vector : BaseVector). + + Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us). + Proof. + unfold decode'; intros; f_equal; apply combine_truncate_l. + Qed. + + Lemma decode'_splice : forall xs ys bs, + decode' bs (xs ++ ys) = + decode' (firstn (length xs) bs) xs + decode' (skipn (length xs) bs) ys. + Proof. + unfold decode'. + induction xs; destruct ys, bs; boring. + + rewrite combine_truncate_r. + do 2 rewrite Z.add_0_r; auto. + + unfold accumulate. + apply Z.add_assoc. + Qed. + + Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs. + Proof. + unfold decode', accumulate; induction bs; destruct us, vs; boring; ring. + Qed. + + Lemma decode_nil : forall bs, decode' bs nil = 0. + auto. + Qed. + Hint Rewrite decode_nil. + + Lemma decode_base_nil : forall us, decode' nil us = 0. + Proof. + intros; rewrite decode'_truncate; auto. + Qed. + Hint Rewrite decode_base_nil. + + Lemma mul_each_rep : forall bs u vs, + decode' bs (mul_each u vs) = u * decode' bs vs. + Proof. + unfold decode', accumulate; induction bs; destruct vs; boring; ring. + Qed. + + Lemma base_eq_1cons: base = 1 :: skipn 1 base. + Proof. + pose proof (b0_1 0) as H. + destruct base; compute in H; try discriminate; boring. + Qed. + + Lemma decode'_cons : forall x1 x2 xs1 xs2, + decode' (x1 :: xs1) (x2 :: xs2) = x1 * x2 + decode' xs1 xs2. + Proof. + unfold decode', accumulate; boring; ring. + Qed. + Hint Rewrite decode'_cons. + + Lemma decode_cons : forall x us, + decode base (x :: us) = x + decode base (0 :: us). + Proof. + unfold decode; intros. + rewrite base_eq_1cons. + autorewrite with core; ring_simplify; auto. + Qed. + + Lemma sub_rep : forall bs us vs, decode' bs (sub us vs) = decode' bs us - decode' bs vs. + Proof. + induction bs; destruct us; destruct vs; boring; ring. + Qed. + + Lemma encode_rep : forall z, decode base (encode z) = z. + Proof. + pose proof base_eq_1cons. + unfold decode, encode; destruct z; boring. + Qed. + + Lemma mul_each_base : forall us bs c, + decode' bs (mul_each c us) = decode' (mul_each c bs) us. + Proof. + induction us; destruct bs; boring; ring. + Qed. + + Hint Rewrite (@nth_default_nil Z). + Hint Rewrite (@firstn_nil Z). + Hint Rewrite (@skipn_nil Z). + + Lemma base_app : forall us low high, + decode' (low ++ high) us = decode' low (firstn (length low) us) + decode' high (skipn (length low) us). + Proof. + induction us; destruct low; boring. + Qed. + + Lemma base_mul_app : forall low c us, + decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + + c * decode' low (skipn (length low) us). + Proof. + intros. + rewrite base_app; f_equal. + rewrite <- mul_each_rep. + rewrite mul_each_base. + reflexivity. + Qed. + + Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0. + induction bs; destruct n; boring. + Qed. + Lemma length_zeros : forall n, length (zeros n) = n. + induction n; boring. + Qed. + Hint Rewrite length_zeros. + + Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m). + Proof. + induction n; boring. + Qed. + Hint Rewrite app_zeros_zeros. + + Lemma zeros_app0 : forall m, zeros m ++ 0 :: nil = zeros (S m). + Proof. + induction m; boring. + Qed. + Hint Rewrite zeros_app0. + + Lemma rev_zeros : forall n, rev (zeros n) = zeros n. + Proof. + induction n; boring. + Qed. + Hint Rewrite rev_zeros. + + Hint Unfold nth_default. + + Lemma decode_single : forall n bs x, + decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. + Proof. + induction n; destruct bs; boring. + Qed. + Hint Rewrite decode_single. + + Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. + Proof. + boring. + Qed. + Hint Rewrite zeros_rep peel_decode. + + Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs. + Proof. + induction xs; destruct bs; boring. + Qed. + + Lemma mul_bi'_zeros : forall n m, mul_bi' base n (zeros m) = zeros m. + induction m; boring. + Qed. + Hint Rewrite mul_bi'_zeros. + + Lemma nth_error_base_nonzero : forall n x, + nth_error base n = Some x -> x <> 0. + Proof. + eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive. + Qed. + + Hint Rewrite plus_0_r. + + Lemma mul_bi_single : forall m n x, + (n + m < length base)%nat -> + decode base (mul_bi base n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. + Proof. + unfold mul_bi, decode. + destruct m; simpl; simpl_list; simpl; intros. { + pose proof nth_error_base_nonzero as nth_nonzero. + case_eq base; [intros; boring | intros z l base_eq]. + specialize (b0_1 0); intro b0_1'. + rewrite base_eq in *. + rewrite nth_default_cons in b0_1'. + rewrite b0_1' in *. + unfold crosscoef. + autounfold; autorewrite with core. + unfold nth_default. + nth_tac. + rewrite Z.mul_1_r. + rewrite Z_div_same_full. + destruct x; ring. + eapply nth_nonzero; eauto. + } { + ssimpl_list. + autorewrite with core. + rewrite app_assoc. + autorewrite with core. + unfold crosscoef; simpl; ring_simplify. + rewrite Nat.add_1_r. + rewrite base_good by auto. + rewrite Z_div_mult by (apply base_positive; rewrite nth_default_eq; apply nth_In; auto). + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_comm. + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_assoc. + destruct (Z.eq_dec x 0); subst; try ring. + rewrite Z.mul_cancel_l by auto. + rewrite <- base_good by auto. + ring. + } + Qed. + + Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil). + induction vs; boring; f_equal; ring. + Qed. + + Lemma set_higher : forall bs vs x, + decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. + Proof. + intros. + rewrite set_higher'. + rewrite add_rep. + f_equal. + apply decode_single. + Qed. + + Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n. + induction n; auto. + simpl; f_equal; auto. + Qed. + + Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil. + Proof. + unfold mul_bi; auto. + Qed. + Hint Rewrite mul_bi'_n_nil. + + Lemma add_nil_l : forall us, nil .+ us = us. + induction us; auto. + Qed. + Hint Rewrite add_nil_l. + + Lemma add_nil_r : forall us, us .+ nil = us. + induction us; auto. + Qed. + Hint Rewrite add_nil_r. + + Lemma add_first_terms : forall us vs a b, + (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs). + auto. + Qed. + Hint Rewrite add_first_terms. + + Lemma mul_bi'_cons : forall n x us, + mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us. + Proof. + unfold mul_bi'; auto. + Qed. + + Lemma add_same_length : forall us vs l, (length us = l) -> (length vs = l) -> + length (us .+ vs) = l. + Proof. + induction us, vs; boring. + erewrite (IHus vs (pred l)); boring. + Qed. + + Hint Rewrite app_nil_l. + Hint Rewrite app_nil_r. + + Lemma add_snoc_same_length : forall l us vs a b, + (length us = l) -> (length vs = l) -> + (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil. + Proof. + induction l, us, vs; boring; discriminate. + Qed. + + Lemma mul_bi'_add : forall us n vs l + (Hlus: length us = l) + (Hlvs: length vs = l), + mul_bi' base n (rev (us .+ vs)) = + mul_bi' base n (rev us) .+ mul_bi' base n (rev vs). + Proof. + (* TODO(adamc): please help prettify this *) + induction us using rev_ind; + try solve [destruct vs; boring; congruence]. + destruct vs using rev_ind; boring; clear IHvs; simpl_list. + erewrite (add_snoc_same_length (pred l) us vs _ _); simpl_list. + repeat rewrite mul_bi'_cons; rewrite add_first_terms; simpl_list. + rewrite (IHus n vs (pred l)). + replace (length us) with (pred l). + replace (length vs) with (pred l). + rewrite (add_same_length us vs (pred l)). + f_equal; ring. + + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + erewrite length_snoc; eauto. + Qed. + + Lemma zeros_cons0 : forall n, 0 :: zeros n = zeros (S n). + auto. + Qed. + + Lemma add_leading_zeros : forall n us vs, + (zeros n ++ us) .+ (zeros n ++ vs) = zeros n ++ (us .+ vs). + Proof. + induction n; boring. + Qed. + + Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) -> + (rev us) .+ (rev vs) = rev (us .+ vs). + Proof. + induction us, vs; boring; try solve [subst; discriminate]. + rewrite (add_snoc_same_length (pred l) _ _ _ _) by (subst; simpl_list; omega). + rewrite (IHus vs (pred l)) by omega; auto. + Qed. + Hint Rewrite rev_add_rev. + + Lemma mul_bi'_length : forall us n, length (mul_bi' base n us) = length us. + Proof. + induction us, n; boring. + Qed. + Hint Rewrite mul_bi'_length. + + Lemma add_comm : forall us vs, us .+ vs = vs .+ us. + Proof. + induction us, vs; boring; f_equal; auto. + Qed. + + Hint Rewrite rev_length. + + Lemma mul_bi_add_same_length : forall n us vs l, + (length us = l) -> (length vs = l) -> + mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs. + Proof. + unfold mul_bi; boring. + rewrite add_leading_zeros. + erewrite mul_bi'_add; boring. + erewrite rev_add_rev; boring. + Qed. + + Lemma add_zeros_same_length : forall us, us .+ (zeros (length us)) = us. + Proof. + induction us; boring; f_equal; omega. + Qed. + + Hint Rewrite add_zeros_same_length. + Hint Rewrite minus_diag. + + Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat -> + us .+ vs = us .+ (vs ++ (zeros (length us - length vs))). + Proof. + induction us, vs; boring; f_equal; boring. + Qed. + + Lemma length_add_ge : forall us vs, (length us >= length vs)%nat -> + (length (us .+ vs) <= length us)%nat. + Proof. + intros. + rewrite add_trailing_zeros by trivial. + erewrite add_same_length by (pose proof app_length; boring); omega. + Qed. + + Lemma add_length_le_max : forall us vs, + (length (us .+ vs) <= max (length us) (length vs))%nat. + Proof. + intros; case_max; (rewrite add_comm; apply length_add_ge; omega) || + (apply length_add_ge; omega) . + Qed. + + Lemma sub_nil_length: forall us : digits, length (sub nil us) = length us. + Proof. + induction us; boring. + Qed. + + Lemma sub_length_le_max : forall us vs, + (length (sub us vs) <= max (length us) (length vs))%nat. + Proof. + induction us, vs; boring. + rewrite sub_nil_length; auto. + Qed. + + Lemma mul_bi_length : forall us n, length (mul_bi base n us) = (length us + n)%nat. + Proof. + pose proof mul_bi'_length; unfold mul_bi. + destruct us; repeat progress (simpl_list; boring). + Qed. + Hint Rewrite mul_bi_length. + + Lemma mul_bi_trailing_zeros : forall m n us, + mul_bi base n us ++ zeros m = mul_bi base n (us ++ zeros m). + Proof. + unfold mul_bi. + induction m; intros; try solve [boring]. + rewrite <- zeros_app0. + rewrite app_assoc. + repeat progress (boring; rewrite rev_app_distr). + Qed. + + Lemma mul_bi_add_longer : forall n us vs, + (length us >= length vs)%nat -> + mul_bi base n (us .+ vs) = mul_bi base n us .+ mul_bi base n vs. + Proof. + boring. + rewrite add_trailing_zeros by auto. + rewrite (add_trailing_zeros (mul_bi base n us) (mul_bi base n vs)) + by (repeat (rewrite mul_bi_length); omega). + erewrite mul_bi_add_same_length by + (eauto; simpl_list; rewrite length_zeros; omega). + rewrite mul_bi_trailing_zeros. + repeat (f_equal; boring). + Qed. + + Lemma mul_bi_add : forall n us vs, + mul_bi base n (us .+ vs) = (mul_bi base n us) .+ (mul_bi base n vs). + Proof. + intros; pose proof mul_bi_add_longer. + destruct (le_ge_dec (length us) (length vs)). { + rewrite add_comm. + rewrite (add_comm (mul_bi base n us)). + boring. + } { + boring. + } + Qed. + + Lemma mul_bi_rep : forall i vs, + (i + length vs < length base)%nat -> + decode base (mul_bi base i vs) = decode base vs * nth_default 0 base i. + Proof. + unfold decode. + induction vs using rev_ind; intros; try solve [unfold mul_bi; boring]. + assert (i + length vs < length base)%nat by + (rewrite app_length in *; boring). + + rewrite set_higher. + ring_simplify. + rewrite <- IHvs by auto; clear IHvs. + rewrite <- mul_bi_single by auto. + rewrite <- add_rep. + rewrite <- mul_bi_add. + rewrite set_higher'. + auto. + Qed. + + (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) + Fixpoint mul' (usr vs:digits) : digits := + match usr with + | u::usr' => + mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs + | _ => nil + end. + Definition mul us := mul' (rev us). + + Lemma mul'_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode base (mul' (rev us) vs) = decode base us * decode base vs. + Proof. + unfold decode. + induction us using rev_ind; boring. + + assert (length us + length vs < length base)%nat by + (rewrite app_length in *; boring). + + ssimpl_list. + rewrite add_rep. + boring. + rewrite set_higher. + rewrite mul_each_rep. + rewrite mul_bi_rep by auto. + unfold decode; ring. + Qed. + + Lemma mul_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode base (mul us vs) = decode base us * decode base vs. + Proof. + exact mul'_rep. + Qed. + + Lemma mul'_length: forall us vs, + (length (mul' us vs) <= length us + length vs)%nat. + Proof. + pose proof add_length_le_max. + induction us; boring. + unfold mul_each. + simpl_list; case_max; boring; omega. + Qed. + + Lemma mul_length: forall us vs, + (length (mul us vs) <= length us + length vs)%nat. + Proof. + intros; unfold mul. + rewrite mul'_length. + rewrite rev_length; omega. + Qed. + + +End BaseSystemProofs. diff --git a/src/BoundedIterOp.v b/src/BoundedIterOp.v deleted file mode 100644 index bd4f7d66e..000000000 --- a/src/BoundedIterOp.v +++ /dev/null @@ -1,94 +0,0 @@ -Require Import Crypto.Tactics.VerdiTactics. -Require Import Coq.Numbers.BinNums Coq.NArith.NArith Coq.NArith.Nnat Coq.ZArith.ZArith. - -Definition testbit_rev p i bound := Pos.testbit_nat p (bound - i). - -Lemma testbit_rev_succ : forall p i b, (i < S b) -> - testbit_rev p i (S b) = - match p with - | xI p' => testbit_rev p' i b - | xO p' => testbit_rev p' i b - | 1%positive => false - end. -Proof. - unfold testbit_rev; intros; destruct p; rewrite <- minus_Sn_m by omega; auto. -Qed. - -(* implements Pos.iter_op only using testbit, not destructing the positive *) -Definition iter_op {A} (op : A -> A -> A) (zero : A) (bound : nat) (p : positive) := - fix iter (i : nat) (a : A) {struct i} : A := - match i with - | O => zero - | S i' => let ret := iter i' (op a a) in - if testbit_rev p i bound - then op a ret - else ret - end. - -Lemma iter_op_step : forall {A} op z b p i (a : A), (i < S b) -> - iter_op op z (S b) p i a = - match p with - | xI p' => iter_op op z b p' i a - | xO p' => iter_op op z b p' i a - | 1%positive => z - end. -Proof. - destruct p; unfold iter_op; (induction i; [ auto |]); intros; rewrite testbit_rev_succ by omega; rewrite IHi by omega; reflexivity. -Qed. - -Lemma pos_size_gt0 : forall p, 0 < Pos.size_nat p. -Proof. - destruct p; intros; auto; try apply Lt.lt_0_Sn. -Qed. -Hint Resolve pos_size_gt0. - -Lemma iter_op_spec : forall b p {A} op z (a : A) (zero_id : forall x : A, op x z = x), (Pos.size_nat p <= b) -> - iter_op op z b p b a = Pos.iter_op op p a. -Proof. - induction b; intros; [pose proof (pos_size_gt0 p); omega |]. - destruct p; simpl; rewrite iter_op_step by omega; unfold testbit_rev; rewrite Minus.minus_diag; try rewrite IHb; simpl in *; auto; omega. -Qed. - -Lemma xO_neq1 : forall p, (1 < p~0)%positive. -Proof. - induction p; auto; apply Pos.lt_succ_diag_r. -Qed. - -Lemma xI_neq1 : forall p, (1 < p~1)%positive. -Proof. - induction p; auto; eapply Pos.lt_trans; apply Pos.lt_succ_diag_r. -Qed. - -Lemma xI_is_succ : forall n p, Pos.of_succ_nat n = p~1%positive -> - (Pos.to_nat (2 * p))%nat = n. -Proof. - induction n; intros; try discriminate. - rewrite <- Pnat.Nat2Pos.id by apply NPeano.Nat.neq_succ_0. - rewrite Pnat.Pos2Nat.inj_iff. - rewrite <- Pos.of_nat_succ. - apply Pos.succ_inj. - rewrite <- Pos.xI_succ_xO. - auto. -Qed. - -Lemma xO_is_succ : forall n p, Pos.of_succ_nat n = p~0%positive -> - Pos.to_nat (Pos.pred_double p) = n. -Proof. - induction n; intros; try discriminate. - rewrite Pos.pred_double_spec. - rewrite <- Pnat.Pos2Nat.inj_iff in *. - rewrite Pnat.Pos2Nat.inj_xO in *. - rewrite Pnat.SuccNat2Pos.id_succ in *. - rewrite Pnat.Pos2Nat.inj_pred by apply xO_neq1. - rewrite <- NPeano.Nat.succ_inj_wd. - rewrite Pnat.Pos2Nat.inj_xO. - omega. -Qed. - -Lemma size_of_succ : forall n, - Pos.size_nat (Pos.of_nat n) <= Pos.size_nat (Pos.of_nat (S n)). -Proof. - intros; induction n; [simpl; auto|]. - apply Pos.size_nat_monotone. - apply Pos.lt_succ_diag_r. -Qed. diff --git a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v index 3740f5a29..f70479c3a 100644 --- a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v +++ b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v @@ -7,107 +7,294 @@ Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Tactics.VerdiTactics. -Section CompleteEdwardsCurveTheorems. - Context {prm:TwistedEdwardsParams}. - Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *) - Existing Instance prime_q. +Module E. + Section CompleteEdwardsCurveTheorems. + Context {prm:TwistedEdwardsParams}. + Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *) + Existing Instance prime_q. + + Add Field Ffield_p' : (@Ffield_theory q _) + (morphism (@Fring_morph q), + preprocess [Fpreprocess], + postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], + constants [Fconstant], + div (@Fmorph_div_theory q), + power_tac (@Fpower_theory q) [Fexp_tac]). + + Add Field Ffield_notConstant : (OpaqueFieldTheory q) + (constants [notConstant]). + + Ltac clear_prm := + generalize dependent a; intro a; intros; + generalize dependent d; intro d; intros; + generalize dependent prime_q; intro prime_q; intros; + generalize dependent q; intro q; intros; + clear prm. + + Lemma point_eq : forall xy1 xy2 pf1 pf2, + xy1 = xy2 -> exist E.onCurve xy1 pf1 = exist E.onCurve xy2 pf2. + Proof. + destruct xy1, xy2; intros; find_injection; intros; subst. apply f_equal. + apply UIP_dec, F_eq_dec. (* this is a hack. We actually don't care about the equality of the proofs. However, we *can* prove it, and knowing it lets us use the universal equality instead of a type-specific equivalence, which makes many things nicer. *) + Qed. Hint Resolve point_eq. + + Definition point_eqb (p1 p2:E.point) : bool := andb + (F_eqb (fst (proj1_sig p1)) (fst (proj1_sig p2))) + (F_eqb (snd (proj1_sig p1)) (snd (proj1_sig p2))). + + Local Ltac t := + unfold point_eqb; + repeat match goal with + | _ => progress intros + | _ => progress simpl in * + | _ => progress subst + | [P:E.point |- _ ] => destruct P + | [x: (F q * F q)%type |- _ ] => destruct x + | [H: _ /\ _ |- _ ] => destruct H + | [H: _ |- _ ] => rewrite Bool.andb_true_iff in H + | [H: _ |- _ ] => apply F_eqb_eq in H + | _ => rewrite F_eqb_refl + end; eauto. + + Lemma point_eqb_sound : forall p1 p2, point_eqb p1 p2 = true -> p1 = p2. + Proof. + t. + Qed. + + Lemma point_eqb_complete : forall p1 p2, p1 = p2 -> point_eqb p1 p2 = true. + Proof. + t. + Qed. + + Lemma point_eqb_neq : forall p1 p2, point_eqb p1 p2 = false -> p1 <> p2. + Proof. + intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. + apply point_eqb_complete in H0; congruence. + Qed. + + Lemma point_eqb_neq_complete : forall p1 p2, p1 <> p2 -> point_eqb p1 p2 = false. + Proof. + intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. + apply point_eqb_sound in Hneq. congruence. + Qed. + + Lemma point_eqb_refl : forall p, point_eqb p p = true. + Proof. + t. + Qed. + + Definition point_eq_dec (p1 p2:E.point) : {p1 = p2} + {p1 <> p2}. + destruct (point_eqb p1 p2) eqn:H; match goal with + | [ H: _ |- _ ] => apply point_eqb_sound in H + | [ H: _ |- _ ] => apply point_eqb_neq in H + end; eauto. + Qed. + + Lemma point_eqb_correct : forall p1 p2, point_eqb p1 p2 = if point_eq_dec p1 p2 then true else false. + Proof. + intros. destruct (point_eq_dec p1 p2); eauto using point_eqb_complete, point_eqb_neq_complete. + Qed. + + Ltac Edefn := unfold E.add, E.add', E.zero; intros; + repeat match goal with + | [ p : E.point |- _ ] => + let x := fresh "x" p in + let y := fresh "y" p in + let pf := fresh "pf" p in + destruct p as [[x y] pf]; unfold E.onCurve in pf + | _ => eapply point_eq, (f_equal2 pair) + | _ => eapply point_eq + end. + Lemma add_comm : forall A B, (A+B = B+A)%E. + Proof. + Edefn; apply (f_equal2 div); ring. + Qed. + + Ltac unifiedAdd_nonzero := match goal with + | [ |- (?op 1 (d * _ * _ * _ * _ * + inv (1 - d * ?xA * ?xB * ?yA * ?yB) * inv (1 + d * ?xA * ?xB * ?yA * ?yB)))%F <> 0%F] + => let Hadd := fresh "Hadd" in + pose proof (@unifiedAdd'_onCurve _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d (xA, yA) (xB, yB)) as Hadd; + simpl in Hadd; + match goal with + | [H : (1 - d * ?xC * xB * ?yC * yB)%F <> 0%F |- (?op 1 ?other)%F <> 0%F] => + replace other with + (d * xC * ((xA * yB + yA * xB) / (1 + d * xA * xB * yA * yB)) + * yC * ((yA * yB - a * xA * xB) / (1 - d * xA * xB * yA * yB)))%F by (subst; unfold div; ring); + auto + end + end. + + Lemma add_assoc : forall A B C, (A+(B+C) = (A+B)+C)%E. + Proof. + Edefn; F_field_simplify_eq; try abstract (rewrite ?@F_pow_2_r in *; clear_prm; F_nsatz); + pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d); + pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d); + cbv beta iota in *; + repeat split; field_nonzero idtac; unifiedAdd_nonzero. + Qed. + + Lemma add_0_r : forall P, (P + E.zero = P)%E. + Proof. + Edefn; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r, + ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; exact eq_refl. + Qed. - Add Field Ffield_p' : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + Lemma add_0_l : forall P, (E.zero + P)%E = P. + Proof. + intros; rewrite add_comm. apply add_0_r. + Qed. - Add Field Ffield_notConstant : (OpaqueFieldTheory q) - (constants [notConstant]). + Lemma mul_0_l : forall P, (0 * P = E.zero)%E. + Proof. + auto. + Qed. - Ltac clear_prm := - generalize dependent a; intro a; intros; - generalize dependent d; intro d; intros; - generalize dependent prime_q; intro prime_q; intros; - generalize dependent q; intro q; intros; - clear prm. + Lemma mul_S_l : forall n P, (S n * P)%E = (P + n * P)%E. + Proof. + auto. + Qed. - Lemma point_eq : forall p1 p2, p1 = p2 -> forall pf1 pf2, - mkPoint p1 pf1 = mkPoint p2 pf2. - Proof. - destruct p1, p2; intros; find_injection; intros; subst; apply f_equal. - apply UIP_dec, F_eq_dec. (* this is a hack. We actually don't care about the equality of the proofs. However, we *can* prove it, and knowing it lets us use the universal equality instead of a type-specific equivalence, which makes many things nicer. *) - Qed. - Hint Resolve point_eq. + Lemma mul_add_l : forall a b P, ((a + b)%nat * P)%E = E.add (a * P)%E (b * P)%E. + Proof. + induction a; intros; rewrite ?plus_Sn_m, ?plus_O_n, ?mul_S_l, ?mul_0_l, ?add_0_l, ?mul_S_, ?IHa, ?add_assoc; auto. + Qed. - Ltac Edefn := unfold unifiedAdd, unifiedAdd', zero; intros; - repeat match goal with - | [ p : point |- _ ] => - let x := fresh "x" p in - let y := fresh "y" p in - let pf := fresh "pf" p in - destruct p as [[x y] pf]; unfold onCurve in pf - | _ => eapply point_eq, (f_equal2 pair) - | _ => eapply point_eq - end. - Lemma twistedAddComm : forall A B, (A+B = B+A)%E. - Proof. - Edefn; apply (f_equal2 div); ring. - Qed. + Lemma mul_assoc : forall (n m : nat) P, (n * (m * P) = (n * m)%nat * P)%E. + Proof. + induction n; intros; auto. + rewrite ?mul_S_l, ?Mult.mult_succ_l, ?mul_add_l, ?IHn, add_comm. reflexivity. + Qed. - Lemma twistedAddAssoc : forall A B C, (A+(B+C) = (A+B)+C)%E. - Proof. - (* The Ltac takes ~15s, the Qed no longer takes longer than I have had patience for *) - Edefn; F_field_simplify_eq; try abstract (rewrite ?@F_pow_2_r in *; clear_prm; F_nsatz); - repeat split; match goal with [ |- _ = 0%F -> False ] => admit end; - fail "unreachable". - Qed. - - Lemma zeroIsIdentity : forall P, (P + zero = P)%E. - Proof. - Edefn; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r, - ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; exact eq_refl. - Qed. - - (* solve for x ^ 2 *) - Definition solve_for_x2 (y : F q) := ((y ^ 2 - 1) / (d * (y ^ 2) - a))%F. - - Lemma d_y2_a_nonzero : (forall y, 0 <> d * y ^ 2 - a)%F. - intros ? eq_zero. - pose proof prime_q. - destruct square_a as [sqrt_a sqrt_a_id]. - rewrite <- sqrt_a_id in eq_zero. - destruct (Fq_square_mul_sub _ _ _ eq_zero) as [ [sqrt_d sqrt_d_id] | a_zero]. - + pose proof (nonsquare_d sqrt_d); auto. - + subst. - rewrite Fq_pow_zero in sqrt_a_id by congruence. - auto using nonzero_a. - Qed. - - Lemma a_d_y2_nonzero : (forall y, a - d * y ^ 2 <> 0)%F. - Proof. - intros y eq_zero. - pose proof prime_q. - eapply F_minus_swap in eq_zero. - eauto using (d_y2_a_nonzero y). - Qed. - - Lemma solve_correct : forall x y, onCurve (x, y) <-> - (x ^ 2 = solve_for_x2 y)%F. - Proof. - split. - + intro onCurve_x_y. + Lemma mul_zero_r : forall m, (m * E.zero = E.zero)%E. + Proof. + induction m; rewrite ?mul_S_l, ?add_0_l; auto. + Qed. + + (* solve for x ^ 2 *) + Definition solve_for_x2 (y : F q) := ((y ^ 2 - 1) / (d * (y ^ 2) - a))%F. + + Lemma d_y2_a_nonzero : (forall y, 0 <> d * y ^ 2 - a)%F. + intros ? eq_zero. pose proof prime_q. - unfold onCurve in onCurve_x_y. - eapply F_div_mul; auto using (d_y2_a_nonzero y). - replace (x ^ 2 * (d * y ^ 2 - a))%F with ((d * x ^ 2 * y ^ 2) - (a * x ^ 2))%F by ring. - rewrite F_sub_add_swap. - replace (y ^ 2 + a * x ^ 2)%F with (a * x ^ 2 + y ^ 2)%F by ring. - rewrite onCurve_x_y. - ring. - + intro x2_eq. - unfold onCurve, solve_for_x2 in *. - rewrite x2_eq. - field. - auto using d_y2_a_nonzero. - Qed. - -End CompleteEdwardsCurveTheorems. + destruct square_a as [sqrt_a sqrt_a_id]. + rewrite <- sqrt_a_id in eq_zero. + destruct (Fq_square_mul_sub _ _ _ eq_zero) as [ [sqrt_d sqrt_d_id] | a_zero]. + + pose proof (nonsquare_d sqrt_d); auto. + + subst. + rewrite Fq_pow_zero in sqrt_a_id by congruence. + auto using nonzero_a. + Qed. + + Lemma a_d_y2_nonzero : (forall y, a - d * y ^ 2 <> 0)%F. + Proof. + intros y eq_zero. + pose proof prime_q. + eapply F_minus_swap in eq_zero. + eauto using (d_y2_a_nonzero y). + Qed. + + Lemma solve_correct : forall x y, E.onCurve (x, y) <-> + (x ^ 2 = solve_for_x2 y)%F. + Proof. + split. + + intro onCurve_x_y. + pose proof prime_q. + unfold E.onCurve in onCurve_x_y. + eapply F_div_mul; auto using (d_y2_a_nonzero y). + replace (x ^ 2 * (d * y ^ 2 - a))%F with ((d * x ^ 2 * y ^ 2) - (a * x ^ 2))%F by ring. + rewrite F_sub_add_swap. + replace (y ^ 2 + a * x ^ 2)%F with (a * x ^ 2 + y ^ 2)%F by ring. + rewrite onCurve_x_y. + ring. + + intro x2_eq. + unfold E.onCurve, solve_for_x2 in *. + rewrite x2_eq. + field. + auto using d_y2_a_nonzero. + Qed. + + + Program Definition opp (P:E.point) : E.point := let '(x, y) := proj1_sig P in (opp x, y). + Next Obligation. Proof. + pose (proj2_sig P) as H; rewrite <-Heq_anonymous in H; simpl in H. + rewrite F_square_opp; trivial. + Qed. + + Definition sub P Q := (P + opp Q)%E. + + Lemma opp_zero : opp E.zero = E.zero. + Proof. + pose proof @F_opp_0. + unfold opp, E.zero; eapply point_eq; congruence. + Qed. + + Lemma add_opp_r : forall P, (P + opp P = E.zero)%E. + Proof. + unfold opp; Edefn; rewrite ?@F_pow_2_r in *; (F_field_simplify_eq; [clear_prm; F_nsatz|..]); + rewrite <-?@F_pow_2_r in *; + pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP); + pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP); + field_nonzero idtac. + Qed. + + Lemma add_opp_l : forall P, (opp P + P = E.zero)%E. + Proof. + intros. rewrite add_comm. eapply add_opp_r. + Qed. + + Lemma add_cancel_r : forall A B C, (B+A = C+A -> B = C)%E. + Proof. + intros. + assert ((B + A) + opp A = (C + A) + opp A)%E as Hc by congruence. + rewrite <-!add_assoc, !add_opp_r, !add_0_r in Hc; exact Hc. + Qed. + + Lemma add_cancel_l : forall A B C, (A+B = A+C -> B = C)%E. + Proof. + intros. + rewrite (add_comm A C) in H. + rewrite (add_comm A B) in H. + eauto using add_cancel_r. + Qed. + + Lemma shuffle_eq_add_opp : forall P Q R, (P + Q = R <-> Q = opp P + R)%E. + Proof. + split; intros. + { assert (opp P + (P + Q) = opp P + R)%E as Hc by congruence. + rewrite add_assoc, add_opp_l, add_comm, add_0_r in Hc; exact Hc. } + { subst. rewrite add_assoc, add_opp_r, add_comm, add_0_r; reflexivity. } + Qed. + + Lemma opp_opp : forall P, opp (opp P) = P. + Proof. + intros. + pose proof (add_opp_r P%E) as H. + rewrite add_comm in H. + rewrite shuffle_eq_add_opp in H. + rewrite add_0_r in H. + congruence. + Qed. + + Lemma opp_add : forall P Q, opp (P + Q)%E = (opp P + opp Q)%E. + Proof. + intros. + pose proof (add_opp_r (P+Q)%E) as H. + rewrite <-!add_assoc in H. + rewrite add_comm in H. + rewrite <-!add_assoc in H. + rewrite shuffle_eq_add_opp in H. + rewrite add_comm in H. + rewrite shuffle_eq_add_opp in H. + rewrite add_0_r in H. + assumption. + Qed. + + Lemma opp_mul : forall n P, opp (E.mul n P) = E.mul n (opp P). + Proof. + pose proof opp_add; pose proof opp_zero. + induction n; simpl; intros; congruence. + Qed. + End CompleteEdwardsCurveTheorems. +End E. +Infix "-" := E.sub : E_scope.
\ No newline at end of file diff --git a/src/CompleteEdwardsCurve/DoubleAndAdd.v b/src/CompleteEdwardsCurve/DoubleAndAdd.v index 84c1289f6..50027349d 100644 --- a/src/CompleteEdwardsCurve/DoubleAndAdd.v +++ b/src/CompleteEdwardsCurve/DoubleAndAdd.v @@ -1,71 +1,30 @@ Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.BoundedIterOp. +Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Util.IterAssocOp. Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. Require Import Coq.Numbers.BinNums Coq.NArith.NArith Coq.NArith.Nnat Coq.ZArith.ZArith. Section EdwardsDoubleAndAdd. Context {prm:TwistedEdwardsParams}. - Definition doubleAndAdd (b n : nat) (P : point) : point := - match N.of_nat n with - | 0%N => zero - | N.pos p => iter_op unifiedAdd zero b p b P - end. + Definition doubleAndAdd (bound n : nat) (P : E.point) : E.point := + iter_op E.add E.zero N.testbit_nat (N.of_nat n) P bound. - Lemma scalarMult_double : forall n P, scalarMult (n + n) P = scalarMult n (P + P)%E. + Lemma scalarMult_double : forall n P, E.mul (n + n) P = E.mul n (P + P)%E. Proof. intros. replace (n + n)%nat with (n * 2)%nat by omega. induction n; simpl; auto. - rewrite twistedAddAssoc. + rewrite E.add_assoc. f_equal; auto. Qed. - Lemma iter_op_double : forall p P, - Pos.iter_op unifiedAdd (p + p) P = Pos.iter_op unifiedAdd p (P + P)%E. + Lemma doubleAndAdd_spec : forall bound n P, N.size_nat (N.of_nat n) <= bound -> + E.mul n P = doubleAndAdd bound n P. Proof. - intros. - rewrite Pos.add_diag. - unfold Pos.iter_op; simpl. - auto. - Qed. - - Lemma doubleAndAdd_spec : forall n b P, (Pos.size_nat (Pos.of_nat n) <= b) -> - scalarMult n P = doubleAndAdd b n P. - Proof. - induction n; auto; intros. - unfold doubleAndAdd; simpl. - rewrite Pos.of_nat_succ. - rewrite iter_op_spec by (auto using zeroIsIdentity). - case_eq (Pos.of_nat (S n)); intros. { - simpl; f_equal. - rewrite (IHn b) by (pose proof (size_of_succ n); omega). - unfold doubleAndAdd. - rewrite H0 in H. - rewrite <- Pos.of_nat_succ in H0. - rewrite <- (xI_is_succ n p) by apply H0. - rewrite Znat.positive_nat_N; simpl. - rewrite iter_op_spec; auto using zeroIsIdentity. - } { - simpl; f_equal. - rewrite (IHn b) by (pose proof (size_of_succ n); omega). - unfold doubleAndAdd. - rewrite <- (xO_is_succ n p) by (rewrite Pos.of_nat_succ; auto). - rewrite Znat.positive_nat_N; simpl. - rewrite <- Pos.succ_pred_double in H0. - rewrite H0 in H. - rewrite iter_op_spec by (auto using zeroIsIdentity; - pose proof (Pos.lt_succ_diag_r (Pos.pred_double p)); - apply Pos.size_nat_monotone in H1; omega; auto). - rewrite <- iter_op_double. - rewrite Pos.add_diag. - rewrite <- Pos.succ_pred_double. - rewrite Pos.iter_op_succ by apply twistedAddAssoc; auto. - } { - rewrite <- Pnat.Pos2Nat.inj_iff in H0. - rewrite Pnat.Nat2Pos.id in H0 by auto. - rewrite Pnat.Pos2Nat.inj_1 in H0. - assert (n = 0)%nat by omega; subst. - auto using zeroIsIdentity. - } + induction n; auto; intros; unfold doubleAndAdd; + rewrite iter_op_spec with (scToN := fun x => x); ( + unfold Morphisms.Proper, Morphisms.respectful, Equivalence.equiv; + intros; subst; try rewrite Nat2N.id; + reflexivity || assumption || apply E.add_assoc + || rewrite E.add_comm; apply E.add_0_r). Qed. End EdwardsDoubleAndAdd.
\ No newline at end of file diff --git a/src/CompleteEdwardsCurve/ExtendedCoordinates.v b/src/CompleteEdwardsCurve/ExtendedCoordinates.v index e918ac128..e91bc084b 100644 --- a/src/CompleteEdwardsCurve/ExtendedCoordinates.v +++ b/src/CompleteEdwardsCurve/ExtendedCoordinates.v @@ -3,7 +3,7 @@ Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.ModularArithmetic.FField. Require Import Crypto.Tactics.VerdiTactics. -Require Import Util.IterAssocOp BinNat NArith. +Require Import Util.IterAssocOp BinNat NArith. Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence. Local Open Scope equiv_scope. Local Open Scope F_scope. @@ -19,10 +19,10 @@ Section ExtendedCoordinates. postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], constants [Fconstant], div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + power_tac (@Fpower_theory q) [Fexp_tac]). Add Field Ffield_notConstant : (OpaqueFieldTheory q) - (constants [notConstant]). + (constants [notConstant]). (** [extended] represents a point on an elliptic curve using extended projective * Edwards coordinates with twist a=-1 (see <https://eprint.iacr.org/2008/522.pdf>). *) @@ -44,8 +44,8 @@ Section ExtendedCoordinates. Local Hint Unfold twistedToExtended extendedToTwisted rep. Local Notation "P '~=' rP" := (rep P rP) (at level 70). - Ltac unfoldExtended := - repeat progress (autounfold; unfold onCurve, unifiedAdd, unifiedAdd', rep in *; intros); + Ltac unfoldExtended := + repeat progress (autounfold; unfold E.onCurve, E.add, E.add', rep in *; intros); repeat match goal with | [ p : (F q*F q)%type |- _ ] => let x := fresh "x" p in @@ -61,7 +61,7 @@ Section ExtendedCoordinates. | [ H: @eq (F q * F q)%type _ _ |- _ ] => invcs H | [ H: @eq F q ?x _ |- _ ] => isVar x; rewrite H; clear H end. - + Ltac solveExtended := unfoldExtended; repeat match goal with | [ |- _ /\ _ ] => split @@ -83,14 +83,14 @@ Section ExtendedCoordinates. solveExtended. Qed. - Definition extendedPoint := { P:extended | rep P (extendedToTwisted P) /\ onCurve (extendedToTwisted P) }. + Definition extendedPoint := { P:extended | rep P (extendedToTwisted P) /\ E.onCurve (extendedToTwisted P) }. - Program Definition mkExtendedPoint : point -> extendedPoint := twistedToExtended. + Program Definition mkExtendedPoint : E.point -> extendedPoint := twistedToExtended. Next Obligation. destruct x; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep. Qed. - Program Definition unExtendedPoint : extendedPoint -> point := extendedToTwisted. + Program Definition unExtendedPoint : extendedPoint -> E.point := extendedToTwisted. Next Obligation. destruct x; simpl; intuition. Qed. @@ -103,7 +103,7 @@ Section ExtendedCoordinates. Lemma unExtendedPoint_mkExtendedPoint : forall P, unExtendedPoint (mkExtendedPoint P) = P. Proof. - destruct P; eapply point_eq; simpl; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep. + destruct P; eapply E.point_eq; simpl; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep. Qed. Global Instance Proper_mkExtendedPoint : Proper (eq==>equiv) mkExtendedPoint. @@ -116,6 +116,8 @@ Section ExtendedCoordinates. repeat (econstructor || intro); unfold extendedPoint_eq in *; congruence. Qed. + Definition twice_d := d + d. + Section TwistMinus1. Context (a_eq_minus1 : a = opp 1). (** Second equation from <http://eprint.iacr.org/2008/522.pdf> section 3.1, also <https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#addition-add-2008-hwcd-3> and <https://tools.ietf.org/html/draft-josefsson-eddsa-ed25519-03> *) @@ -124,8 +126,8 @@ Section ExtendedCoordinates. let '(X2, Y2, Z2, T2) := P2 in let A := (Y1-X1)*(Y2-X2) in let B := (Y1+X1)*(Y2+X2) in - let C := T1*ZToField 2*d*T2 in - let D := Z1*ZToField 2 *Z2 in + let C := T1*twice_d*T2 in + let D := Z1*(Z2+Z2) in let E := B-A in let F := D-C in let G := D+C in @@ -135,47 +137,30 @@ Section ExtendedCoordinates. let T3 := E*H in let Z3 := F*G in (X3, Y3, Z3, T3). - Local Hint Unfold unifiedAdd. + Local Hint Unfold E.add. + + Local Ltac tnz := repeat apply Fq_mul_nonzero_nonzero; auto using (@char_gt_2 q two_lt_q). - Lemma unifiedAddM1'_rep: forall P Q rP rQ, onCurve rP -> onCurve rQ -> - P ~= rP -> Q ~= rQ -> (unifiedAddM1' P Q) ~= (unifiedAdd' rP rQ). + Lemma F_mul_2_l : forall x : F q, ZToField 2 * x = x + x. + intros. ring. + Qed. + + Lemma unifiedAddM1'_rep: forall P Q rP rQ, E.onCurve rP -> E.onCurve rQ -> + P ~= rP -> Q ~= rQ -> (unifiedAddM1' P Q) ~= (E.add' rP rQ). Proof. intros P Q rP rQ HoP HoQ HrP HrQ. pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d). pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d). - unfoldExtended; rewrite a_eq_minus1 in *; simpl in *. + unfoldExtended; unfold twice_d; rewrite a_eq_minus1 in *; simpl in *. repeat rewrite <-F_mul_2_l. repeat split; repeat apply (f_equal2 pair); try F_field; repeat split; auto; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r, - ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r. - - Ltac tnz := repeat apply Fq_mul_nonzero_nonzero; auto using (@char_gt_2 q two_lt_q). - (* If we we had reasoning modulo associativity and commutativity, - * the following tactic would probably solve all remaining goals here: - repeat match goal with [H1: @eq (F p) _ _, H2: @eq (F p) _ _ |- _ ] => - let H := fresh "H" in ( - pose proof (edwardsAddCompletePlus _ _ _ _ H1 H2) as H; - match type of H with ?xs <> 0 => ac_rewrite (eq_refl xs) end - ) || ( - pose proof (edwardsAddCompleteMinus _ _ _ _ H1 H2) as H; - match type of H with ?xs <> 0 => ac_rewrite (eq_refl xs) end - ); tnz - end. *) - - - replace (ZP * ZQ * ZP * ZQ + d * XP * XQ * YP * YQ) with (ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZQ * ZP * ZQ - d * XP * XQ * YP * YQ) with (ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ - XP * YP / ZP * ZToField 2 * d * (XQ * YQ / ZQ) ) with (ZToField 2*ZQ*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - replace (ZP * ZToField 2 * ZQ + XP * YP / ZP * ZToField 2 * d * (XQ * YQ / ZQ) ) with (ZToField 2*ZQ*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) + XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 + d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. - - replace (ZP * ZToField 2 * ZQ * (ZP * ZQ) - XP * YP * ZToField 2 * d * (XQ * YQ)) with (ZToField 2*ZQ*ZQ*ZP*ZP* (1 - d * (XQ / ZQ) * (XP / ZP) * (YQ / ZQ) * (YP / ZP))) by (field; tnz); tnz. + ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; + field_nonzero tnz. Qed. - Lemma unifiedAdd'_onCurve : forall P Q, onCurve P -> onCurve Q -> onCurve (unifiedAdd' P Q). + Lemma unifiedAdd'_onCurve : forall P Q, E.onCurve P -> E.onCurve Q -> E.onCurve (E.add' P Q). Proof. - intros; pose proof (proj2_sig (unifiedAdd (mkPoint _ H) (mkPoint _ H0))); eauto. + intros; pose proof (proj2_sig (E.add (exist _ _ H) (exist _ _ H0))); eauto. Qed. Program Definition unifiedAddM1 : extendedPoint -> extendedPoint -> extendedPoint := unifiedAddM1'. @@ -188,9 +173,9 @@ Section ExtendedCoordinates. eauto using unifiedAdd'_onCurve. Qed. - Lemma unifiedAddM1_rep : forall P Q, unifiedAdd (unExtendedPoint P) (unExtendedPoint Q) = unExtendedPoint (unifiedAddM1 P Q). + Lemma unifiedAddM1_rep : forall P Q, E.add (unExtendedPoint P) (unExtendedPoint Q) = unExtendedPoint (unifiedAddM1 P Q). Proof. - destruct P, Q; unfold unExtendedPoint, unifiedAdd, unifiedAddM1; eapply point_eq; simpl in *; intuition. + destruct P, Q; unfold unExtendedPoint, E.add, unifiedAddM1; eapply E.point_eq; simpl in *; intuition. pose proof (unifiedAddM1'_rep x x0 (extendedToTwisted x) (extendedToTwisted x0)); destruct (unifiedAddM1' x x0); unfold rep in *; intuition. @@ -201,34 +186,54 @@ Section ExtendedCoordinates. repeat (econstructor || intro). repeat match goal with [H: _ === _ |- _ ] => inversion H; clear H end; unfold equiv, extendedPoint_eq. rewrite <-!unifiedAddM1_rep. - destruct x, y, x0, y0; simpl in *; eapply point_eq; congruence. + destruct x, y, x0, y0; simpl in *; eapply E.point_eq; congruence. Qed. - Lemma unifiedAddM1_0_r : forall P, unifiedAddM1 P (mkExtendedPoint zero) === P. + Lemma unifiedAddM1_0_r : forall P, unifiedAddM1 P (mkExtendedPoint E.zero) === P. unfold equiv, extendedPoint_eq; intros. - rewrite <-!unifiedAddM1_rep, unExtendedPoint_mkExtendedPoint, zeroIsIdentity; auto. + rewrite <-!unifiedAddM1_rep, unExtendedPoint_mkExtendedPoint, E.add_0_r; auto. Qed. - Lemma unifiedAddM1_0_l : forall P, unifiedAddM1 (mkExtendedPoint zero) P === P. + Lemma unifiedAddM1_0_l : forall P, unifiedAddM1 (mkExtendedPoint E.zero) P === P. unfold equiv, extendedPoint_eq; intros. - rewrite <-!unifiedAddM1_rep, twistedAddComm, unExtendedPoint_mkExtendedPoint, zeroIsIdentity; auto. + rewrite <-!unifiedAddM1_rep, E.add_comm, unExtendedPoint_mkExtendedPoint, E.add_0_r; auto. Qed. Lemma unifiedAddM1_assoc : forall a b c, unifiedAddM1 a (unifiedAddM1 b c) === unifiedAddM1 (unifiedAddM1 a b) c. Proof. unfold equiv, extendedPoint_eq; intros. - rewrite <-!unifiedAddM1_rep, twistedAddAssoc; auto. + rewrite <-!unifiedAddM1_rep, E.add_assoc; auto. + Qed. + + Lemma testbit_conversion_identity : forall x i, N.testbit_nat x i = N.testbit_nat ((fun a => a) x) i. + Proof. + trivial. Qed. - - Definition scalarMultM1 := iter_op unifiedAddM1 (mkExtendedPoint zero). - Definition scalarMultM1_spec := iter_op_spec unifiedAddM1 unifiedAddM1_assoc (mkExtendedPoint zero) unifiedAddM1_0_l. - Lemma scalarMultM1_rep : forall n P, unExtendedPoint (scalarMultM1 (N.of_nat n) P) = scalarMult n (unExtendedPoint P). - intros; rewrite scalarMultM1_spec, Nat2N.id. + + Definition scalarMultM1 := iter_op unifiedAddM1 (mkExtendedPoint E.zero) N.testbit_nat. + Definition scalarMultM1_spec := + iter_op_spec unifiedAddM1 unifiedAddM1_assoc (mkExtendedPoint E.zero) unifiedAddM1_0_l + N.testbit_nat (fun x => x) testbit_conversion_identity. + Lemma scalarMultM1_rep : forall n P, unExtendedPoint (scalarMultM1 (N.of_nat n) P (N.size_nat (N.of_nat n))) = E.mul n (unExtendedPoint P). + intros; rewrite scalarMultM1_spec, Nat2N.id; auto. induction n; [simpl; rewrite !unExtendedPoint_mkExtendedPoint; reflexivity|]. - unfold scalarMult; fold scalarMult. + unfold E.mul; fold E.mul. rewrite <-IHn, unifiedAddM1_rep; auto. Qed. End TwistMinus1. -End ExtendedCoordinates.
\ No newline at end of file + Definition negateExtended' P := let '(X, Y, Z, T) := P in (opp X, Y, Z, opp T). + Program Definition negateExtended (P:extendedPoint) : extendedPoint := negateExtended' (proj1_sig P). + Next Obligation. + Proof. + unfold negateExtended', rep; destruct P as [[X Y Z T] H]; simpl. destruct H as [[[] []] ?]; subst. + repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; repeat split; trivial. + Qed. + + Lemma negateExtended_correct : forall P, E.opp (unExtendedPoint P) = unExtendedPoint (negateExtended P). + Proof. + unfold E.opp, unExtendedPoint, negateExtended; destruct P as [[]]; simpl; intros. + eapply E.point_eq; repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; trivial. + Qed. +End ExtendedCoordinates. diff --git a/src/EdDSAProofs.v b/src/EdDSAProofs.v index 83467bf6d..dba71b49c 100644 --- a/src/EdDSAProofs.v +++ b/src/EdDSAProofs.v @@ -37,7 +37,7 @@ Section EdDSAProofs. Lemma decode_sign_split2 : forall sk {n} (M : word n), split2 b b (sign (public sk) sk M) = let r : nat := H (prngKey sk ++ M) in (* secret nonce *) - let R : point := (r * B)%E in (* commitment to nonce *) + let R : E.point := (r * B)%E in (* commitment to nonce *) let s : nat := curveKey sk in (* secret scalar *) let S : F (Z.of_nat l) := ZToField (Z.of_nat (r + H (enc R ++ public sk ++ M) * s)) in enc S. @@ -46,62 +46,21 @@ Section EdDSAProofs. Qed. Hint Rewrite decode_sign_split2. - Lemma zero_times : forall P, (0 * P = zero)%E. - Proof. - auto. - Qed. - - Lemma zero_plus : forall P, (zero + P = P)%E. - Proof. - intros; rewrite twistedAddComm; apply zeroIsIdentity. - Qed. - - Lemma times_S : forall n m, S n * m = m + n * m. - Proof. - auto. - Qed. - - Lemma times_S_nat : forall n m, (S n * m = m + n * m)%nat. - Proof. - auto. - Qed. - - Hint Rewrite plus_O_n plus_Sn_m times_S times_S_nat. - Hint Rewrite zeroIsIdentity zero_times zero_plus twistedAddAssoc. - - Lemma scalarMult_distr : forall n0 m, ((n0 + m)%nat * B)%E = unifiedAdd (n0 * B)%E (m * B)%E. - Proof. - unfold scalarMult; induction n0; arith. - Qed. - Hint Rewrite scalarMult_distr. - - Lemma scalarMult_assoc : forall (n0 m : nat), (n0 * (m * B) = (n0 * m)%nat * B)%E. - Proof. - induction n0; arith; simpl; arith. - Qed. - Hint Rewrite scalarMult_assoc. - - Lemma scalarMult_zero : forall m, (m * zero = zero)%E. - Proof. - unfold scalarMult; induction m; arith. - Qed. - Hint Rewrite scalarMult_zero. + Hint Rewrite E.add_0_r E.add_0_l E.add_assoc. + Hint Rewrite E.mul_assoc E.mul_add_l E.mul_0_l E.mul_zero_r. + Hint Rewrite plus_O_n plus_Sn_m mult_0_l mult_succ_l. Hint Rewrite l_order_B. - - Lemma l_order_B' : forall x, (l * x * B = zero)%E. + Lemma l_order_B' : forall x, (l * x * B = E.zero)%E. Proof. - intros; rewrite Mult.mult_comm. rewrite <- scalarMult_assoc. arith. - Qed. - - Hint Rewrite l_order_B'. + intros; rewrite Mult.mult_comm. rewrite <- E.mul_assoc. arith. + Qed. Hint Rewrite l_order_B'. Lemma scalarMult_mod_l : forall n0, (n0 mod l * B = n0 * B)%E. Proof. intros. rewrite (div_mod n0 l) at 2 by (generalize l_odd; omega). arith. - Qed. - Hint Rewrite scalarMult_mod_l. + Qed. Hint Rewrite scalarMult_mod_l. Hint Rewrite @encoding_valid. Hint Rewrite @FieldToZ_ZToField. diff --git a/src/Encoding/EncodingTheorems.v b/src/Encoding/EncodingTheorems.v index f53ad0319..52ac91ada 100644 --- a/src/Encoding/EncodingTheorems.v +++ b/src/Encoding/EncodingTheorems.v @@ -1,7 +1,7 @@ Require Import Crypto.Spec.Encoding. Section EncodingTheorems. - Context {A B : Type} {E : encoding of A as B}. + Context {A B : Type} {E : canonical encoding of A as B}. Lemma encoding_inj : forall x y, enc x = enc y -> x = y. Proof. diff --git a/src/Encoding/ModularWordEncodingPre.v b/src/Encoding/ModularWordEncodingPre.v new file mode 100644 index 000000000..417344b43 --- /dev/null +++ b/src/Encoding/ModularWordEncodingPre.v @@ -0,0 +1,45 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil Crypto.Util.WordUtil. +Require Import Crypto.Spec.Encoding. + +Local Open Scope nat_scope. + +Section ModularWordEncodingPre. + Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}. + + Let Fm_enc (x : F m) : word sz := NToWord sz (Z.to_N (FieldToZ x)). + + Let Fm_dec (x_ : word sz) : option (F m) := + let z := Z.of_N (wordToN (x_)) in + if Z_lt_dec z m + then Some (ZToField z) + else None + . + + Lemma Fm_encoding_valid : forall x, Fm_dec (Fm_enc x) = Some x. + Proof. + unfold Fm_dec, Fm_enc; intros. + pose proof (FieldToZ_range x m_pos). + rewrite wordToN_NToWord_idempotent by (apply bound_check_nat_N; + assert (Z.to_nat x < Z.to_nat m) by (apply Z2Nat.inj_lt; omega); omega). + rewrite Z2N.id by omega. + rewrite ZToField_idempotent. + break_if; auto; omega. + Qed. + + Lemma Fm_encoding_canonical : forall w x, Fm_dec w = Some x -> Fm_enc x = w. + Proof. + unfold Fm_dec, Fm_enc; intros ? ? dec_Some. + break_if; [ | congruence ]. + inversion dec_Some. + rewrite FieldToZ_ZToField. + rewrite Z.mod_small by (pose proof (N2Z.is_nonneg (wordToN w)); try omega). + rewrite N2Z.id. + apply NToWord_wordToN. + Qed. + +End ModularWordEncodingPre. diff --git a/src/Encoding/ModularWordEncodingTheorems.v b/src/Encoding/ModularWordEncodingTheorems.v new file mode 100644 index 000000000..7251ac1e6 --- /dev/null +++ b/src/Encoding/ModularWordEncodingTheorems.v @@ -0,0 +1,54 @@ +Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Spec.Encoding. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Spec.ModularWordEncoding. + + +Local Open Scope F_scope. + +Section SignBit. + Context {m : Z} {prime_m : prime m} {two_lt_m : (2 < m)%Z} {sz : nat} {bound_check : (Z.to_nat m < 2 ^ sz)%nat}. + + Lemma sign_bit_parity : forall x, @sign_bit m sz x = Z.odd x. + Proof. + unfold sign_bit, Fm_enc; intros. + pose proof (shatter_word (NToWord sz (Z.to_N x))) as shatter. + case_eq sz; intros; subst; rewrite shatter. + + pose proof (prime_ge_2 m prime_m). + simpl in bound_check. + assert (m < 1)%Z by (apply Z2Nat.inj_lt; try omega; assumption). + omega. + + assert (0 < m)%Z as m_pos by (pose proof prime_ge_2 m prime_m; omega). + pose proof (FieldToZ_range x m_pos). + destruct (FieldToZ x); auto. + - destruct p; auto. + - pose proof (Pos2Z.neg_is_neg p); omega. + Qed. + + Lemma sign_bit_zero : @sign_bit m sz 0 = false. + Proof. + rewrite sign_bit_parity; auto. + Qed. + + Lemma sign_bit_opp : forall (x : F m), x <> 0 -> negb (@sign_bit m sz x) = @sign_bit m sz (opp x). + Proof. + intros. + pose proof sign_bit_zero as sign_zero. + rewrite !sign_bit_parity in *. + pose proof (F_opp_spec x) as opp_spec_x. + apply F_eq in opp_spec_x. + rewrite FieldToZ_add in opp_spec_x. + rewrite <-opp_spec_x, Z_odd_mod in sign_zero by (pose proof prime_ge_2 m prime_m; omega). + replace (Z.odd m) with true in sign_zero by (destruct (ZUtil.prime_odd_or_2 m prime_m); auto || omega). + rewrite Z.odd_add, F_FieldToZ_add_opp, Z.div_same, Bool.xorb_true_r in sign_zero by assumption || omega. + apply Bool.xorb_eq. + rewrite <-Bool.negb_xorb_l. + assumption. + Qed. + +End SignBit.
\ No newline at end of file diff --git a/src/Encoding/PointEncodingPre.v b/src/Encoding/PointEncodingPre.v new file mode 100644 index 000000000..73ced869b --- /dev/null +++ b/src/Encoding/PointEncodingPre.v @@ -0,0 +1,275 @@ +Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Program.Equality. +Require Import Crypto.Encoding.EncodingTheorems. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Encoding.ModularWordEncodingTheorems. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. + +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic. + +Local Open Scope F_scope. + +Section PointEncoding. + Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat} + {bound_check : (Z.to_nat q < 2 ^ sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z} + {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1} + {FqEncoding : canonical encoding of (F q) as (word sz)} + {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false} + {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}. + Existing Instance prime_q. + + Add Field Ffield : (@Ffield_theory q _) + (morphism (@Fring_morph q), + preprocess [Fpreprocess], + postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], + constants [Fconstant], + div (@Fmorph_div_theory q), + power_tac (@Fpower_theory q) [Fexp_tac]). + + Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F. + + Lemma solve_sqrt_valid : forall p, E.onCurve p -> + sqrt_valid (E.solve_for_x2 (snd p)). + Proof. + intros ? onCurve_xy. + destruct p as [x y]; simpl. + rewrite (E.solve_correct x y) in onCurve_xy. + rewrite <- onCurve_xy. + unfold sqrt_valid. + eapply sqrt_mod_q_valid; eauto. + unfold isSquare; eauto. + Grab Existential Variables. eauto. + Qed. + + Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> + E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y). + Proof. + intros. + unfold sqrt_valid in *. + apply E.solve_correct; auto. + Qed. + + Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> + E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y). + Proof. + intros y sqrt_valid_x2. + unfold sqrt_valid in *. + apply E.solve_correct. + rewrite <- sqrt_valid_x2 at 2. + ring. + Qed. + + Definition point_enc_coordinates (p : (F q * F q)) : Word.word (S sz) := let '(x,y) := p in + Word.WS (sign_bit x) (enc y). + + Let point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in + Word.WS (sign_bit x) (enc y). + + Definition point_dec_coordinates (sign_bit : F q -> bool) (w : Word.word (S sz)) : option (F q * F q) := + match dec (Word.wtl w) with + | None => None + | Some y => let x2 := E.solve_for_x2 y in + let x := sqrt_mod_q x2 in + if F_eq_dec (x ^ 2) x2 + then + let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in + if (andb (F_eqb x 0) (whd w)) + then None (* special case for 0, since its opposite has the same sign; if the sign bit of 0 is 1, produce None.*) + else Some p + else None + end. + + Ltac inversion_Some_eq := match goal with [H: Some ?x = Some ?y |- _] => inversion H; subst end. + + Lemma point_dec_coordinates_onCurve : forall w p, point_dec_coordinates sign_bit w = Some p -> E.onCurve p. + Proof. + unfold point_dec_coordinates; intros. + edestruct dec; [ | congruence]. + break_if; [ | congruence]. + break_if; [ congruence | ]. + break_if; inversion_Some_eq; auto using solve_onCurve, solve_opp_onCurve. + Qed. + + Lemma prod_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'}) + (x y : (A * A)), {x = y} + {x <> y}. + Proof. + decide equality. + Qed. + + Lemma option_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'}) + (x y : option A), {x = y} + {x <> y}. + Proof. + decide equality. + Qed. + + Definition point_dec' w p : option E.point := + match (option_eq_dec (prod_eq_dec F_eq_dec) (point_dec_coordinates sign_bit w) (Some p)) with + | left EQ => Some (exist _ p (point_dec_coordinates_onCurve w p EQ)) + | right _ => None (* this case is never reached *) + end. + + Definition point_dec (w : word (S sz)) : option E.point := + match (point_dec_coordinates sign_bit w) with + | Some p => point_dec' w p + | None => None + end. + + Lemma point_coordinates_encoding_canonical : forall w p, + point_dec_coordinates sign_bit w = Some p -> point_enc_coordinates p = w. + Proof. + unfold point_dec_coordinates, point_enc_coordinates; intros ? ? coord_dec_Some. + case_eq (dec (wtl w)); [ intros ? dec_Some | intros dec_None; rewrite dec_None in *; congruence ]. + destruct p. + rewrite (shatter_word w). + f_equal; rewrite dec_Some in *; + do 2 (break_if; try congruence); inversion coord_dec_Some; subst. + + destruct (F_eq_dec (sqrt_mod_q (E.solve_for_x2 f1)) 0%F) as [sqrt_0 | ?]. + - rewrite sqrt_0 in *. + apply sqrt_mod_q_root_0 in sqrt_0; try assumption. + rewrite sqrt_0 in *. + break_if; [symmetry; auto using Bool.eqb_prop | ]. + rewrite sign_bit_zero in *. + simpl in Heqb; rewrite Heqb in *. + discriminate. + - break_if. + symmetry; auto using Bool.eqb_prop. + rewrite <- sign_bit_opp by assumption. + destruct (whd w); inversion Heqb0; break_if; auto. + + inversion coord_dec_Some; subst. + auto using encoding_canonical. +Qed. + + Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w. + Proof. + (* + unfold point_enc; intros. + unfold point_dec in *. + assert (point_dec_coordinates w = Some (proj1_sig x)). { + set (y := point_dec_coordinates w) in *. + revert H. + dependent destruction y. intros. + rewrite H0 in H. + *) + Admitted. + +Lemma point_dec_coordinates_correct w + : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates sign_bit w. +Proof. + unfold point_dec, option_map. + do 2 break_match; try congruence; unfold point_dec' in *; + break_match; try congruence. + inversion_Some_eq. + reflexivity. +Qed. + +Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p). +Proof. + intros. + destruct p as [x y]; simpl. + exact (encoding_valid y). +Qed. + +Lemma sign_bit_opp_eq_iff : forall x y, y <> 0 -> + (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)). +Proof. + split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y); + try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp in * by auto; + rewrite y_sign, x_sign in *; reflexivity || discriminate. +Qed. + +Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 -> + sign_bit x = sign_bit y -> x = y. +Proof. + intros ? ? y_nonzero squares_eq sign_match. + destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto. + assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto). + apply sign_bit_opp_eq_iff in sign_mismatch; auto. + congruence. +Qed. + +Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) -> + sign_bit x = sign_bit x' -> x = x'. +Proof. + intros ? ? ? onCurve_x onCurve_x' sign_match. + apply E.solve_correct in onCurve_x. + apply E.solve_correct in onCurve_x'. + destruct (F_eq_dec x' 0). + + subst. + rewrite Fq_pow_zero in onCurve_x' by congruence. + rewrite <- onCurve_x' in *. + eapply Fq_root_zero; eauto. + + apply sign_bit_squares; auto. + rewrite onCurve_x, onCurve_x'. + reflexivity. +Qed. + +Lemma point_encoding_coordinates_valid : forall p, E.onCurve p -> + point_dec_coordinates sign_bit (point_enc_coordinates p) = Some p. +Proof. + intros p onCurve_p. + unfold point_dec_coordinates. + rewrite y_decode. + pose proof (solve_sqrt_valid p onCurve_p) as solve_sqrt_valid_p. + destruct p as [x y]. + unfold sqrt_valid in *. + simpl. + replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption). + case_eq (F_eqb x 0); intro eqb_x_0. + + apply F_eqb_eq in eqb_x_0; rewrite eqb_x_0 in *. + rewrite !Fq_pow_zero, sqrt_mod_q_of_0, Fq_pow_zero by congruence. + rewrite if_F_eq_dec_if_F_eqb, sign_bit_zero. + reflexivity. + + assert (sqrt_mod_q (x ^ 2) <> 0) by (intro false_eq; apply sqrt_mod_q_root_0 in false_eq; try assumption; + apply Fq_root_zero in false_eq; rewrite false_eq, F_eqb_refl in eqb_x_0; congruence). + replace (F_eqb (sqrt_mod_q (x ^ 2)) 0) with false by (symmetry; + apply F_eqb_neq_complete; assumption). + break_if. + - simpl. + f_equal. + break_if. + * rewrite Bool.eqb_true_iff in Heqb. + pose proof (solve_onCurve y solve_sqrt_valid_p). + f_equal. + apply (sign_bit_match _ _ y); auto. + apply E.solve_correct in onCurve_p; rewrite onCurve_p in *. + assumption. + * rewrite Bool.eqb_false_iff in Heqb. + pose proof (solve_opp_onCurve y solve_sqrt_valid_p). + f_equal. + apply sign_bit_opp_eq_iff in Heqb; try assumption. + apply (sign_bit_match _ _ y); auto. + apply E.solve_correct in onCurve_p. + rewrite onCurve_p; auto. + - simpl in solve_sqrt_valid_p. + replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption). + congruence. +Qed. + +Lemma point_dec'_valid : forall p, + point_dec' (point_enc_coordinates (proj1_sig p)) (proj1_sig p) = Some p. +Proof. + unfold point_dec'; intros. + break_match. + + f_equal. + destruct p. + apply E.point_eq. + reflexivity. + + rewrite point_encoding_coordinates_valid in n by apply (proj2_sig p). + congruence. +Qed. + +Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p. +Proof. + intros. + unfold point_dec. + replace (point_enc p) with (point_enc_coordinates (proj1_sig p)) by reflexivity. + break_match; rewrite point_encoding_coordinates_valid in * by apply (proj2_sig p); try congruence. + inversion_Some_eq. + eapply point_dec'_valid. +Qed. + +End PointEncoding. diff --git a/src/Encoding/PointEncodingTheorems.v b/src/Encoding/PointEncodingTheorems.v new file mode 100644 index 000000000..ccea1d81b --- /dev/null +++ b/src/Encoding/PointEncodingTheorems.v @@ -0,0 +1,207 @@ +Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Program.Equality. +Require Import Crypto.Encoding.EncodingTheorems. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Tactics.VerdiTactics. + +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularArithmetic Crypto.Spec.CompleteEdwardsCurve. + +Local Open Scope F_scope. + +Section PointEncoding. + Context {prm: CompleteEdwardsCurve.TwistedEdwardsParams} {sz : nat} + {FqEncoding : canonical encoding of ModularArithmetic.F (CompleteEdwardsCurve.q) as Word.word sz} + {q_5mod8 : (CompleteEdwardsCurve.q mod 8 = 5)%Z} + {sqrt_minus1_valid : (@ZToField CompleteEdwardsCurve.q 2 ^ BinInt.Z.to_N (CompleteEdwardsCurve.q / 4)) ^ 2 = opp 1}. + Existing Instance CompleteEdwardsCurve.prime_q. + + Add Field Ffield : (@PrimeFieldTheorems.Ffield_theory CompleteEdwardsCurve.q _) + (morphism (@ModularArithmeticTheorems.Fring_morph CompleteEdwardsCurve.q), + preprocess [ModularArithmeticTheorems.Fpreprocess], + postprocess [ModularArithmeticTheorems.Fpostprocess; try exact PrimeFieldTheorems.Fq_1_neq_0; try assumption], + constants [ModularArithmeticTheorems.Fconstant], + div (@ModularArithmeticTheorems.Fmorph_div_theory CompleteEdwardsCurve.q), + power_tac (@ModularArithmeticTheorems.Fpower_theory CompleteEdwardsCurve.q) [ModularArithmeticTheorems.Fexp_tac]). + + Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F. + + Lemma solve_sqrt_valid : forall (p : E.point), + sqrt_valid (E.solve_for_x2 (snd (proj1_sig p))). + Proof. + intros. + destruct p as [[x y] onCurve_xy]; simpl. + rewrite (E.solve_correct x y) in onCurve_xy. + rewrite <- onCurve_xy. + unfold sqrt_valid. + eapply sqrt_mod_q_valid; eauto. + unfold isSquare; eauto. + Grab Existential Variables. eauto. + Qed. + + Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> + E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y). + Proof. + intros. + unfold sqrt_valid in *. + apply E.solve_correct; auto. + Qed. + + Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> + E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y). + Proof. + intros y sqrt_valid_x2. + unfold sqrt_valid in *. + apply E.solve_correct. + rewrite <- sqrt_valid_x2 at 2. + ring. + Qed. + +Definition sign_bit (x : F q) := (wordToN (enc (opp x)) <? wordToN (enc x))%N. +Definition point_enc (p : E.point) : word (S sz) := let '(x,y) := proj1_sig p in + WS (sign_bit x) (enc y). +Definition point_dec_coordinates (w : word (S sz)) : option (F q * F q) := + match dec (wtl w) with + | None => None + | Some y => let x2 := E.solve_for_x2 y in + let x := sqrt_mod_q x2 in + if F_eq_dec (x ^ 2) x2 + then + let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in + Some p + else None + end. + +Definition point_dec (w : word (S sz)) : option E.point := + match dec (wtl w) with + | None => None + | Some y => let x2 := E.solve_for_x2 y in + let x := sqrt_mod_q x2 in + match (F_eq_dec (x ^ 2) x2) with + | right _ => None + | left EQ => if Bool.eqb (whd w) (sign_bit x) + then Some (exist _ (x, y) (solve_onCurve y EQ)) + else Some (exist _ (opp x, y) (solve_opp_onCurve y EQ)) + end + end. + +Lemma point_dec_coordinates_correct w + : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates w. +Proof. + unfold point_dec, point_dec_coordinates. + edestruct dec; [ | reflexivity ]. + edestruct @F_eq_dec; [ | reflexivity ]. + edestruct @Bool.eqb; reflexivity. +Qed. + +Lemma y_decode : forall p, dec (wtl (point_enc p)) = Some (snd (proj1_sig p)). +Proof. + intros. + destruct p as [[x y] onCurve_p]; simpl. + exact (encoding_valid y). +Qed. + + +Lemma wordToN_enc_neq_opp : forall x, x <> 0 -> (wordToN (enc (opp x)) <> wordToN (enc x))%N. +Proof. + intros x x_nonzero. + intro false_eq. + apply x_nonzero. + apply F_eq_opp_zero; try apply two_lt_q. + apply wordToN_inj in false_eq. + apply encoding_inj in false_eq. + auto. +Qed. + +Lemma sign_bit_opp_negb : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x). +Proof. + intros x x_nonzero. + unfold sign_bit. + rewrite <- N.leb_antisym. + rewrite N.ltb_compare, N.leb_compare. + rewrite F_opp_involutive. + case_eq (wordToN (enc x) ?= wordToN (enc (opp x)))%N; auto. + intro wordToN_enc_eq. + pose proof (wordToN_enc_neq_opp x x_nonzero). + apply N.compare_eq_iff in wordToN_enc_eq. + congruence. +Qed. + +Lemma sign_bit_opp : forall x y, y <> 0 -> + (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)). +Proof. + split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y); + try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp_negb in * by auto; + rewrite y_sign, x_sign in *; reflexivity || discriminate. +Qed. + +Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 -> + sign_bit x = sign_bit y -> x = y. +Proof. + intros ? ? y_nonzero squares_eq sign_match. + destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto. + assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto). + apply sign_bit_opp in sign_mismatch; auto. + congruence. +Qed. + +Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) -> + sign_bit x = sign_bit x' -> x = x'. +Proof. + intros ? ? ? onCurve_x onCurve_x' sign_match. + apply E.solve_correct in onCurve_x. + apply E.solve_correct in onCurve_x'. + destruct (F_eq_dec x' 0). + + subst. + rewrite Fq_pow_zero in onCurve_x' by congruence. + rewrite <- onCurve_x' in *. + eapply Fq_root_zero; eauto. + + apply sign_bit_squares; auto. + rewrite onCurve_x, onCurve_x'. + reflexivity. +Qed. + +Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p. +Proof. + intros. + unfold point_dec. + rewrite y_decode. + pose proof solve_sqrt_valid p as solve_sqrt_valid_p. + unfold sqrt_valid in *. + destruct p as [[x y] onCurve_p]. + simpl in *. + destruct (F_eq_dec ((sqrt_mod_q (E.solve_for_x2 y)) ^ 2) (E.solve_for_x2 y)); intuition. + break_if; f_equal; apply E.point_eq. + + rewrite Bool.eqb_true_iff in Heqb. + pose proof (solve_onCurve y solve_sqrt_valid_p). + f_equal. + apply (sign_bit_match _ _ y); auto. + + rewrite Bool.eqb_false_iff in Heqb. + pose proof (solve_opp_onCurve y solve_sqrt_valid_p). + f_equal. + apply sign_bit_opp in Heqb. + apply (sign_bit_match _ _ y); auto. + intro eq_zero. + apply E.solve_correct in onCurve_p. + rewrite eq_zero in *. + rewrite Fq_pow_zero in solve_sqrt_valid_p by congruence. + rewrite <- solve_sqrt_valid_p in onCurve_p. + apply Fq_root_zero in onCurve_p. + rewrite onCurve_p in Heqb; auto. +Qed. + +(* Waiting on canonicalization *) +Lemma point_encoding_canonical : forall (x_enc : word (S sz)) (x : E.point), +point_dec x_enc = Some x -> point_enc x = x_enc. +Admitted. + +Instance point_encoding : canonical encoding of E.point as (word (S sz)) := { + enc := point_enc; + dec := point_dec; + encoding_valid := point_encoding_valid; + encoding_canonical := point_encoding_canonical +}. + +End PointEncoding. diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v new file mode 100644 index 000000000..2e65df9bd --- /dev/null +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -0,0 +1,162 @@ +Require Import Zpower ZArith. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import VerdiTactics. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section ExtendedBaseVector. + Context `{prm : PseudoMersenneBaseParams}. + + (* This section defines a new BaseVector that has double the length of the BaseVector + * used to construct [params]. The coefficients of the new vector are as follows: + * + * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i] + * + * The purpose of this construction is that it allows us to multiply numbers expressed + * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as + * vectors of digits; the value of a digit vector is obtained by doing a dot product with + * the base vector.) So if x, y are digit vectors: + * + * (x \dot base) * (y \dot base) = (z \dot ext_base) + * + * Then we can separate z into its first and second halves: + * + * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) + * + * Now, if we want to reduce the product modulo 2 ^ k - c: + * + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) + * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) + * + * This sum may be short enough to express using base; if not, we can reduce again. + *) + Definition ext_base := base ++ (map (Z.mul (2^k)) base). + + Lemma ext_base_positive : forall b, In b ext_base -> b > 0. + Proof. + unfold ext_base. intros b In_b_base. + rewrite in_app_iff in In_b_base. + destruct In_b_base as [In_b_base | In_b_extbase]. + + eapply BaseSystem.base_positive. + eapply In_b_base. + + eapply in_map_iff in In_b_extbase. + destruct In_b_extbase as [b' [b'_2k_b In_b'_base]]. + subst. + specialize (BaseSystem.base_positive b' In_b'_base); intro base_pos. + replace 0 with (2 ^ k * 0) by ring. + apply (Zmult_gt_compat_l b' 0 (2 ^ k)); [| apply base_pos; intuition]. + rewrite Z.gt_lt_iff. + apply Z.pow_pos_nonneg; intuition. + pose proof k_nonneg; omega. + Qed. + + Lemma base_length_nonzero : (0 < length base)%nat. + Proof. + assert (nth_default 0 base 0 = 1) by (apply BaseSystem.b0_1). + unfold nth_default in H. + case_eq (nth_error base 0); intros; + try (rewrite H0 in H; omega). + apply (nth_error_value_length _ 0 base z); auto. + Qed. + + Lemma b0_1 : forall x, nth_default x ext_base 0 = 1. + Proof. + intros. unfold ext_base. + rewrite nth_default_app. + assert (0 < length base)%nat by (apply base_length_nonzero). + destruct (lt_dec 0 (length base)); try apply BaseSystem.b0_1; try omega. + Qed. + + Lemma two_k_nonzero : 2^k <> 0. + Proof. + pose proof (Z.pow_eq_0 2 k k_nonneg). + intuition. + Qed. + + Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> + nth_default 0 (map (Z.mul (2 ^ k)) base) n = + (2 ^ k) * (nth_default 0 base n). + Proof. + intros. + erewrite map_nth_default; auto. + Qed. + + Lemma base_good_over_boundary : forall + (i : nat) + (l : (i < length base)%nat) + (j' : nat) + (Hj': (i + j' < length base)%nat) + , + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = + 2 ^ k * (nth_default 0 base i * nth_default 0 base j') / + (2 ^ k * nth_default 0 base (i + j')) * + (2 ^ k * nth_default 0 base (i + j')) + . + Proof. + intros. + remember (nth_default 0 base) as b. + rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). + replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) + with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. + rewrite Z.mul_cancel_l by (exact two_k_nonzero). + replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) + with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. + subst b. + apply (BaseSystem.base_good i j'); omega. + Qed. + + Lemma ext_base_good : + forall i j, (i+j < length ext_base)%nat -> + let b := nth_default 0 ext_base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + intros. + subst b. subst r. + unfold ext_base in *. + rewrite app_length in H; rewrite map_length in H. + repeat rewrite nth_default_app. + destruct (lt_dec i (length base)); + destruct (lt_dec j (length base)); + destruct (lt_dec (i + j) (length base)); + try omega. + { (* i < length base, j < length base, i + j < length base *) + apply BaseSystem.base_good; auto. + } { (* i < length base, j < length base, i + j >= length base *) + rewrite (map_nth_default _ _ _ _ 0) by omega. + apply base_matches_modulus; omega. + } { (* i < length base, j >= length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (j - length base)%nat as j'. + replace (i + j - length base)%nat with (i + j')%nat by omega. + replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) + with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } { (* i >= length base, j < length base, i + j >= length base *) + do 2 rewrite map_nth_default_base_high by omega. + remember (i - length base)%nat as i'. + replace (i + j - length base)%nat with (j + i')%nat by omega. + replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) + with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) + by ring. + eapply base_good_over_boundary; eauto; omega. + } + Qed. + + Instance ExtBaseVector : BaseSystem.BaseVector ext_base := { + base_positive := ext_base_positive; + b0_1 := b0_1; + base_good := ext_base_good + }. + + Lemma extended_base_length: + length ext_base = (length base + length base)%nat. + Proof. + unfold ext_base; rewrite app_length; rewrite map_length; auto. + Qed. +End ExtendedBaseVector. + diff --git a/src/ModularArithmetic/ModularArithmeticTheorems.v b/src/ModularArithmetic/ModularArithmeticTheorems.v index ddb689547..dabfcf883 100644 --- a/src/ModularArithmetic/ModularArithmeticTheorems.v +++ b/src/ModularArithmetic/ModularArithmeticTheorems.v @@ -6,6 +6,7 @@ Require Import Crypto.Tactics.VerdiTactics. Require Import Coq.ZArith.BinInt Coq.ZArith.Zdiv Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *) Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid. Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Field_theory Coq.setoid_ring.Field_tac. +Require Export Crypto.Util.IterAssocOp. Section ModularArithmeticPreliminaries. Context {m:Z}. @@ -209,6 +210,14 @@ Section FandZ. reflexivity. Qed. + Lemma pow_nat_iter_op_correct: forall (x:F m) n, (nat_iter_op mul 1) (N.to_nat n) x = x^n. + Proof. + induction n using N.peano_ind; + destruct (F_pow_spec x) as [pow_0 pow_succ]; + rewrite ?N2Nat.inj_succ, ?pow_0, <-?N.add_1_l, ?pow_succ; + simpl; congruence. + Qed. + Lemma mod_plus_zero_subproof a b : 0 mod m = (a + b) mod m -> b mod m = (- a) mod m. Proof. @@ -513,6 +522,11 @@ Section VariousModulo. ring. Qed. + Lemma F_opp_0 : opp (0 : F m) = 0%F. + Proof. + intros; ring. + Qed. + Lemma F_opp_swap : forall x y : F m, opp x = y <-> x = opp y. Proof. split; intro; subst; ring. @@ -523,6 +537,23 @@ Section VariousModulo. intros; ring. Qed. + Lemma F_square_opp : forall x : F m, (opp x ^ 2 = x ^ 2)%F. + Proof. + intros; ring. + Qed. + + Lemma F_mul_opp_r : forall x y : F m, (x * opp y = opp (x * y))%F. + intros; ring. + Qed. + + Lemma F_mul_opp_l : forall x y : F m, (opp x * y = opp (x * y))%F. + intros; ring. + Qed. + + Lemma F_mul_opp_both : forall x y : F m, (opp x * opp y = x * y)%F. + intros; ring. + Qed. + Lemma F_add_0_r : forall x : F m, (x + 0)%F = x. Proof. intros; ring. @@ -675,4 +706,17 @@ Section VariousModulo. replace y with ((y - b) + b) by ring. rewrite Hxayb; ring. Qed. + + Lemma F_FieldToZ_add_opp : forall x : F m, x <> 0 -> (FieldToZ x + FieldToZ (opp x) = m)%Z. + Proof. + intros. + rewrite FieldToZ_opp. + rewrite Z_mod_nz_opp_full, mod_FieldToZ; try omega. + rewrite mod_FieldToZ. + replace 0%Z with (@FieldToZ m 0) by auto. + intro false_eq. + rewrite <-F_eq in false_eq. + congruence. + Qed. + End VariousModulo. diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index b2bea21cf..558b9a5a2 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -3,623 +3,115 @@ Require Import Coq.Lists.List. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.Tactics.VerdiTactics. Local Open Scope Z_scope. -Module Type PseudoMersenneBaseParams (Import B:BaseCoefs) (Import M:Modulus). - Parameter k : Z. - Parameter c : Z. - Axiom modulus_pseudomersenne : modulus = 2^k - c. - - Axiom base_matches_modulus : - forall i j, - (i < length base)%nat -> - (j < length base)%nat -> - (i+j >= length base)%nat-> - let b := nth_default 0 base in - let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in - b i * b j = r * (2^k * b (i+j-length base)%nat). - - Axiom base_succ : forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in - b (S i) mod b i = 0. - - Axiom base_tail_matches_modulus: - 2^k mod nth_default 0 base (pred (length base)) = 0. - - (* Probably implied by modulus_pseudomersenne. *) - Axiom k_nonneg : 0 <= k. - -End PseudoMersenneBaseParams. - -Module Type RepZMod (Import M:Modulus). - Parameter T : Type. - Parameter encode : F modulus -> T. - Parameter decode : T -> F modulus. - - Parameter rep : T -> F modulus -> Prop. - Local Notation "u '~=' x" := (rep u x) (at level 70). - Axiom encode_rep : forall x, encode x ~= x. - Axiom rep_decode : forall u x, u ~= x -> decode u = x. - - Parameter add : T -> T -> T. - Axiom add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F. - - Parameter sub : T -> T -> T. - Axiom sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F. - - Parameter mul : T -> T -> T. - Axiom mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. -End RepZMod. - -Module PseudoMersenneBase (BC:BaseCoefs) (Import M:PrimeModulus) (P:PseudoMersenneBaseParams BC M) <: RepZMod M. - Module EC <: BaseCoefs. - Definition base := BC.base ++ (map (Z.mul (2^(P.k))) BC.base). - - Lemma base_positive : forall b, In b base -> b > 0. - Proof. - unfold base. intros. - rewrite in_app_iff in H. - destruct H. { - apply BC.base_positive; auto. - } { - specialize BC.base_positive. - induction BC.base; intros. { - simpl in H; intuition. - } { - simpl in H. - destruct H; subst. - replace 0 with (2 ^ P.k * 0) by auto. - apply (Zmult_gt_compat_l a 0 (2 ^ P.k)). - rewrite Z.gt_lt_iff. - apply Z.pow_pos_nonneg; intuition. - pose proof P.k_nonneg; omega. - apply H0; left; auto. - apply IHl; auto. - intros. apply H0; auto. right; auto. - } - } - Qed. - - Lemma base_length_nonzero : (0 < length BC.base)%nat. - Proof. - assert (nth_default 0 BC.base 0 = 1) by (apply BC.b0_1). - unfold nth_default in H. - case_eq (nth_error BC.base 0); intros; - try (rewrite H0 in H; omega). - apply (nth_error_value_length _ 0 BC.base z); auto. - Qed. - - Lemma b0_1 : forall x, nth_default x base 0 = 1. - Proof. - intros. unfold base. - rewrite nth_default_app. - assert (0 < length BC.base)%nat by (apply base_length_nonzero). - destruct (lt_dec 0 (length BC.base)); try apply BC.b0_1; try omega. - Qed. - - Lemma two_k_nonzero : 2^P.k <> 0. - pose proof (Z.pow_eq_0 2 P.k P.k_nonneg). - intuition. - Qed. - - Lemma map_nth_default_base_high : forall n, (n < (length BC.base))%nat -> - nth_default 0 (map (Z.mul (2 ^ P.k)) BC.base) n = - (2 ^ P.k) * (nth_default 0 BC.base n). - Proof. - intros. - erewrite map_nth_default; auto. - Qed. - - Lemma base_succ : forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in - b (S i) mod b i = 0. - Proof. - intros; subst b; unfold base. - repeat rewrite nth_default_app. - do 2 break_if; try apply P.base_succ; try omega. { - destruct (lt_eq_lt_dec (S i) (length BC.base)). { - destruct s; intuition. - rewrite map_nth_default_base_high by omega. - replace i with (pred(length BC.base)) by omega. - rewrite <- Zmult_mod_idemp_l. - rewrite P.base_tail_matches_modulus; simpl. - rewrite Zmod_0_l; auto. - } { - assert (length BC.base <= i)%nat by (apply lt_n_Sm_le; auto); omega. - } - } { - unfold base in H; rewrite app_length, map_length in H. - repeat rewrite map_nth_default_base_high by omega. - rewrite Zmult_mod_distr_l. - rewrite <- minus_Sn_m by omega. - rewrite P.base_succ by omega; auto. - } - Qed. - - Lemma base_good_over_boundary : forall - (i : nat) - (l : (i < length BC.base)%nat) - (j' : nat) - (Hj': (i + j' < length BC.base)%nat) - , - 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') = - 2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j') / - (2 ^ P.k * nth_default 0 BC.base (i + j')) * - (2 ^ P.k * nth_default 0 BC.base (i + j')) - . - Proof. intros. - remember (nth_default 0 BC.base) as b. - rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). - replace (b i * b j' / b (i + j')%nat * (2 ^ P.k * b (i + j')%nat)) - with ((2 ^ P.k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by (exact two_k_nonzero). - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (BC.base_good i j'); omega. - Qed. - - Lemma base_good : - forall i j, (i+j < length base)%nat -> - let b := nth_default 0 base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - intros. - subst b. subst r. - unfold base in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - destruct (lt_dec i (length BC.base)); - destruct (lt_dec j (length BC.base)); - destruct (lt_dec (i + j) (length BC.base)); - try omega. - { (* i < length BC.base, j < length BC.base, i + j < length BC.base *) - apply BC.base_good; auto. - } { (* i < length BC.base, j < length BC.base, i + j >= length BC.base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - apply P.base_matches_modulus; omega. - } { (* i < length BC.base, j >= length BC.base, i + j >= length BC.base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length BC.base)%nat as j'. - replace (i + j - length BC.base)%nat with (i + j')%nat by omega. - replace (nth_default 0 BC.base i * (2 ^ P.k * nth_default 0 BC.base j')) - with (2 ^ P.k * (nth_default 0 BC.base i * nth_default 0 BC.base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length BC.base, j < length BC.base, i + j >= length BC.base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length BC.base)%nat as i'. - replace (i + j - length BC.base)%nat with (j + i')%nat by omega. - replace (2 ^ P.k * nth_default 0 BC.base i' * nth_default 0 BC.base j) - with (2 ^ P.k * (nth_default 0 BC.base j * nth_default 0 BC.base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. - End EC. - - Module E := BaseSystem EC. - Module B := BaseSystem BC. - - Definition T := B.digits. - Local Hint Unfold T. - Definition decode (us : T) : F modulus := ZToField (B.decode us). - Local Hint Unfold decode. - Definition rep (us : T) (x : F modulus) := (length us <= length BC.base)%nat /\ decode us = x. +Section PseudoMersenneBase. + Context `{prm :PseudoMersenneBaseParams}. + + Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us). + + Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. Local Notation "u '~=' x" := (rep u x) (at level 70). Local Hint Unfold rep. - Lemma rep_decode : forall us x, us ~= x -> decode us = x. - Proof. - autounfold; intuition. - Qed. - - Definition encode (x : F modulus) := B.encode x. - - Lemma encode_rep : forall x : F modulus, encode x ~= x. - Proof. - intros. unfold encode, rep. - split. { - unfold B.encode; simpl. - apply EC.base_length_nonzero. - } { - unfold decode. - rewrite B.encode_rep. - apply ZToField_idempotent. (* TODO: rename this lemma *) - } - Qed. - - Definition add (us vs : T) := B.add us vs. - Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F. - Proof. - autounfold; intuition. { - unfold add. - rewrite B.add_length_le_max. - case_max; try rewrite Max.max_r; omega. - } - unfold decode in *; unfold B.decode in *. - rewrite B.add_rep. - rewrite ZToField_add. - subst; auto. - Qed. + Definition encode (x : F modulus) := encode x. - Definition sub (us vs : T) := B.sub us vs. - Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> sub u v ~= (x-y)%F. - Proof. - autounfold; intuition. { - rewrite B.sub_length_le_max. - case_max; try rewrite Max.max_r; omega. - } - unfold decode in *; unfold B.decode in *. - rewrite B.sub_rep. - rewrite ZToField_sub. - subst; auto. - Qed. + (* Converts from length of extended base to length of base by reduction modulo M.*) + Definition reduce (us : digits) : digits := + let high := skipn (length base) us in + let low := firstn (length base) us in + let wrap := map (Z.mul c) high in + BaseSystem.add low wrap. - Lemma decode_short : forall (us : B.digits), - (length us <= length BC.base)%nat -> B.decode us = E.decode us. - Proof. - intros. - unfold B.decode, B.decode', E.decode, E.decode'. - rewrite combine_truncate_r. - rewrite (combine_truncate_r us EC.base). - f_equal; f_equal. - unfold EC.base. - rewrite firstn_app_inleft; auto; omega. - Qed. + Definition mul (us vs : digits) := reduce (BaseSystem.mul ext_base us vs). - Lemma extended_base_length: - length EC.base = (length BC.base + length BC.base)%nat. - Proof. - unfold EC.base; rewrite app_length; rewrite map_length; auto. - Qed. + Definition sub (xs : digits) (xs_0_mod : (BaseSystem.decode base xs) mod modulus = 0) (us vs : digits) := + BaseSystem.sub (add xs us) vs. - Lemma mul_rep_extended : forall (us vs : B.digits), - (length us <= length BC.base)%nat -> - (length vs <= length BC.base)%nat -> - B.decode us * B.decode vs = E.decode (E.mul us vs). - Proof. - intros. - rewrite E.mul_rep by (unfold EC.base; simpl_list; omega). - f_equal; rewrite decode_short; auto. - Qed. +End PseudoMersenneBase. - (* Converts from length of E.base to length of B.base by reduction modulo M.*) - Definition reduce (us : E.digits) : B.digits := - let high := skipn (length BC.base) us in - let low := firstn (length BC.base) us in - let wrap := map (Z.mul P.c) high in - B.add low wrap. +Section CarryBasePow2. + Context `{prm :PseudoMersenneBaseParams}. - Lemma two_k_nonzero : 2^P.k <> 0. - pose proof (Z.pow_eq_0 2 P.k P.k_nonneg). - intuition. - Qed. - - Lemma modulus_nonzero : modulus <> 0. - pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. - Qed. - - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) - Lemma pseudomersenne_add: forall x y, (x + ((2^P.k) * y)) mod modulus = (x + (P.c * y)) mod modulus. - Proof. - intros. - replace (2^P.k) with (((2^P.k) - P.c) + P.c) by auto. - rewrite Z.mul_add_distr_r. - rewrite Zplus_mod. - rewrite <- P.modulus_pseudomersenne. - rewrite Z.mul_comm. - rewrite mod_mult_plus; auto using modulus_nonzero. - rewrite <- Zplus_mod; auto. - Qed. - - Lemma extended_shiftadd: forall (us : E.digits), E.decode us = - B.decode (firstn (length BC.base) us) + - (2^P.k * B.decode (skipn (length BC.base) us)). - Proof. - intros. - unfold B.decode, E.decode; rewrite <- B.mul_each_rep. - replace B.decode' with E.decode' by auto. - unfold EC.base. - replace (map (Z.mul (2 ^ P.k)) BC.base) with (E.mul_each (2 ^ P.k) BC.base) by auto. - rewrite E.base_mul_app. - rewrite <- E.mul_each_rep; auto. - Qed. - - Lemma reduce_rep : forall us, B.decode (reduce us) mod modulus = (E.decode us) mod modulus. - Proof. - intros. - rewrite extended_shiftadd. - rewrite pseudomersenne_add. - unfold reduce. - remember (firstn (length BC.base) us) as low. - remember (skipn (length BC.base) us) as high. - unfold B.decode. - rewrite B.add_rep. - replace (map (Z.mul P.c) high) with (B.mul_each P.c high) by auto. - rewrite B.mul_each_rep; auto. - Qed. - - Lemma reduce_length : forall us, - (length us <= length EC.base)%nat -> - (length (reduce us) <= length (BC.base))%nat. - Proof. - intros. - unfold reduce. - remember (map (Z.mul P.c) (skipn (length BC.base) us)) as high. - remember (firstn (length BC.base) us) as low. - assert (length low >= length high)%nat. { - subst. rewrite firstn_length. - rewrite map_length. - rewrite skipn_length. - destruct (le_dec (length BC.base) (length us)). { - rewrite Min.min_l by omega. - rewrite extended_base_length in H. omega. - } { - rewrite Min.min_r by omega. omega. - } - } - assert ((length low <= length BC.base)%nat) - by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). - assert (length high <= length BC.base)%nat - by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; - rewrite extended_base_length in H; omega). - rewrite B.add_trailing_zeros; auto. - rewrite (B.add_same_length _ _ (length low)); auto. - rewrite app_length. - rewrite B.length_zeros; intuition. - Qed. - - Definition mul (us vs : T) := reduce (E.mul us vs). - Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. - Proof. - autounfold; unfold mul; intuition. { - rewrite reduce_length; try omega. - rewrite E.mul_length. - rewrite extended_base_length. - omega. - } - rewrite ZToField_mod, reduce_rep, <-ZToField_mod. - rewrite E.mul_rep; try (rewrite extended_base_length; omega). - subst; auto. - replace (E.decode u) with (B.decode u) by (apply decode_short; omega). - replace (E.decode v) with (B.decode v) by (apply decode_short; omega). - apply ZToField_mul. - Qed. + Definition log_cap i := nth_default 0 limb_widths i. Definition add_to_nth n (x:Z) xs := set_nth n (x + nth_default 0 xs n) xs. - Hint Unfold add_to_nth. - - (* i must be in the domain of BC.base *) - Definition cap i := - if eq_nat_dec i (pred (length BC.base)) - then (2^P.k) / nth_default 0 BC.base i - else nth_default 0 BC.base (S i) / nth_default 0 BC.base i. + + Definition pow2_mod n i := Z.land n (Z.ones i). Definition carry_simple i := fun us => let di := nth_default 0 us i in - let us' := set_nth i (di mod cap i) us in - add_to_nth (S i) ( (di / cap i)) us'. + let us' := set_nth i (pow2_mod di (log_cap i)) us in + add_to_nth (S i) ( (Z.shiftr di (log_cap i))) us'. Definition carry_and_reduce i := fun us => let di := nth_default 0 us i in - let us' := set_nth i (di mod cap i) us in - add_to_nth 0 (P.c * (di / cap i)) us'. + let us' := set_nth i (pow2_mod di (log_cap i)) us in + add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'. - Definition carry i : B.digits -> B.digits := - if eq_nat_dec i (pred (length BC.base)) + Definition carry i : digits -> digits := + if eq_nat_dec i (pred (length base)) then carry_and_reduce i else carry_simple i. - Lemma decode'_splice : forall xs ys bs, - B.decode' bs (xs ++ ys) = - B.decode' (firstn (length xs) bs) xs + - B.decode' (skipn (length xs) bs) ys. - Proof. - induction xs; destruct ys, bs; boring. - unfold B.decode'. - rewrite combine_truncate_r. - ring. - Qed. - - Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - B.decode (set_nth n x us) = - (x - nth_default 0 us n) * nth_default 0 BC.base n + B.decode us. - Proof. - intros. - unfold B.decode. - nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) - unfold splice_nth. - rewrite <- (firstn_skipn n us) at 4. - do 2 rewrite decode'_splice. - remember (length (firstn n us)) as n0. - ring_simplify. - remember (B.decode' (firstn n0 BC.base) (firstn n us)). - rewrite (skipn_nth_default n us 0) by omega. - rewrite firstn_length in Heqn0. - rewrite Min.min_l in Heqn0 by omega; subst n0. - destruct (le_lt_dec (length BC.base) n). { - rewrite nth_default_out_of_bounds by auto. - rewrite skipn_all by omega. - do 2 rewrite B.decode_base_nil. - ring_simplify; auto. - } { - rewrite (skipn_nth_default n BC.base 0) by omega. - do 2 rewrite B.decode'_cons. - ring_simplify; ring. - } - Qed. - - Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - B.decode (add_to_nth n x us) = - x * nth_default 0 BC.base n + B.decode us. - Proof. - unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. - Qed. - - Lemma nth_default_base_positive : forall i, (i < length BC.base)%nat -> - nth_default 0 BC.base i > 0. - Proof. - intros. - pose proof (nth_error_length_exists_value _ _ H). - destruct H0. - pose proof (nth_error_value_In _ _ _ H0). - pose proof (BC.base_positive _ H1). - unfold nth_default. - rewrite H0; auto. - Qed. - - Lemma base_succ_div_mult : forall i, ((S i) < length BC.base)%nat -> - nth_default 0 BC.base (S i) = nth_default 0 BC.base i * - (nth_default 0 BC.base (S i) / nth_default 0 BC.base i). - Proof. - intros. - apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply P.base_succ; auto. - Qed. - - Lemma base_length_lt_pred : (pred (length BC.base) < length BC.base)%nat. - Proof. - pose proof EC.base_length_nonzero; omega. - Qed. - Hint Resolve base_length_lt_pred. - - Lemma cap_positive: forall i, (i < length BC.base)%nat -> cap i > 0. - Proof. - unfold cap; intros; break_if. { - apply div_positive_gt_0; try (subst; apply P.base_tail_matches_modulus). { - rewrite <- two_p_equiv. - apply two_p_gt_ZERO. - apply P.k_nonneg. - } { - apply nth_default_base_positive; subst; auto. - } - } { - apply div_positive_gt_0; try (apply P.base_succ; omega); - try (apply nth_default_base_positive; omega). - } - Qed. - - Lemma cap_div_mod : forall us i, (i < (pred (length BC.base)))%nat -> - let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 BC.base i = - (di / cap i) * nth_default 0 BC.base (S i). - Proof. - intros. - rewrite (Z_div_mod_eq di (cap i)) at 1 by (apply cap_positive; omega); - ring_simplify. - unfold cap; break_if; intuition. - rewrite base_succ_div_mult at 4 by omega; ring. - Qed. - - Lemma carry_simple_decode_eq : forall i us, - (length us = length BC.base) -> - (i < (pred (length BC.base)))%nat -> - B.decode (carry_simple i us) = B.decode us. - Proof. - unfold carry_simple. intros. - rewrite add_to_nth_sum by (rewrite length_set_nth; omega). - rewrite set_nth_sum by omega. - rewrite <- cap_div_mod by auto; ring_simplify; auto. - Qed. - - Lemma two_k_div_mul_last : - 2 ^ P.k = nth_default 0 BC.base (pred (length BC.base)) * - (2 ^ P.k / nth_default 0 BC.base (pred (length BC.base))). - Proof. - intros. - pose proof P.base_tail_matches_modulus. - rewrite (Z_div_mod_eq (2 ^ P.k) (nth_default 0 BC.base (pred (length BC.base)))) at 1 by - (apply nth_default_base_positive; auto); omega. - Qed. - - Lemma cap_div_mod_reduce : forall us, - let i := pred (length BC.base) in - let di := nth_default 0 us i in - (di - (di mod cap i)) * nth_default 0 BC.base i = - (di / cap i) * 2 ^ P.k. - Proof. - intros. - rewrite (Z_div_mod_eq di (cap i)) at 1 by - (apply cap_positive; auto); ring_simplify. - unfold cap; break_if; intuition. - rewrite Z.mul_comm, Z.mul_assoc. - subst i; rewrite <- two_k_div_mul_last; auto. - Qed. - - Lemma carry_decode_eq_reduce : forall us, - (length us = length BC.base) -> - B.decode (carry_and_reduce (pred (length BC.base)) us) mod modulus - = B.decode us mod modulus. - Proof. - unfold carry_and_reduce; intros. - pose proof EC.base_length_nonzero. - rewrite add_to_nth_sum by (rewrite length_set_nth; omega). - rewrite set_nth_sum by omega. - rewrite Zplus_comm; rewrite <- Z.mul_assoc. - rewrite <- pseudomersenne_add. - rewrite BC.b0_1. - rewrite (Z.mul_comm (2 ^ P.k)). - rewrite <- Zred_factor0. - rewrite <- cap_div_mod_reduce by auto; auto. - Qed. - - Lemma carry_length : forall i us, - (length us <= length BC.base)%nat -> - (length (carry i us) <= length BC.base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - Hint Resolve carry_length. - - Lemma carry_rep : forall i us x, - (length us = length BC.base) -> - (i < length BC.base)%nat -> - us ~= x -> carry i us ~= x. - Proof. - pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. - unfold rep, decode, carry in *; intros. - intuition; break_if; subst; eauto; - apply F_eq; simpl; intuition. - Qed. - Hint Resolve carry_rep. - Definition carry_sequence is us := fold_right carry us is. - Lemma carry_sequence_length: forall is us, - (length us <= length BC.base)%nat -> - (length (carry_sequence is us) <= length BC.base)%nat. - Proof. - induction is; boring. - Qed. - Hint Resolve carry_sequence_length. - - Lemma carry_length_exact : forall i us, - (length us = length BC.base)%nat -> - (length (carry i us) = length BC.base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - - Lemma carry_sequence_length_exact: forall is us, - (length us = length BC.base)%nat -> - (length (carry_sequence is us) = length BC.base)%nat. - Proof. - induction is; boring. - apply carry_length_exact; auto. - Qed. - Hint Resolve carry_sequence_length_exact. - - Lemma carry_sequence_rep : forall is us x, - (forall i, In i is -> (i < length BC.base)%nat) -> - (length us = length BC.base) -> - us ~= x -> carry_sequence is us ~= x. - Proof. - induction is; boring. - Qed. -End PseudoMersenneBase.
\ No newline at end of file +End CarryBasePow2. + +Section Canonicalization. + Context `{prm :PseudoMersenneBaseParams}. + + Fixpoint make_chain i := + match i with + | O => nil + | S i' => i' :: make_chain i' + end. + + (* compute at compile time *) + Definition full_carry_chain := make_chain (length limb_widths). + + (* compute at compile time *) + Definition max_ones := Z.ones + ((fix loop current_max lw := + match lw with + | nil => current_max + | w :: lw' => loop (Z.max w current_max) lw' + end + ) 0 limb_widths). + + (* compute at compile time? *) + Definition carry_full := carry_sequence full_carry_chain. + + Definition max_bound i := Z.ones (log_cap i). + + Definition isFull us := + (fix loop full i := + match i with + | O => full (* don't test 0; the test for 0 is the initial value of [full]. *) + | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' + end + ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat. + + Fixpoint range' n m := + match m with + | O => nil + | S m' => (n - m)%nat :: range' n m' + end. + + Definition range n := range' n n. + + Definition land_max_bound and_term i := Z.land and_term (max_bound i). + + Definition freeze us := + let us' := carry_full (carry_full (carry_full us)) in + let and_term := if isFull us' then max_ones else 0 in + (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. + Otherwise, it's all zeroes, and the subtractions do nothing. *) + map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us'). + +End Canonicalization. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v new file mode 100644 index 000000000..981680b4a --- /dev/null +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -0,0 +1,463 @@ +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ExtendedBaseVector. +Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Coq.Lists.List. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Import Coq.QArith.QArith Coq.QArith.Qround. +Require Import Crypto.Tactics.VerdiTactics. +Local Open Scope Z. + +(* Computed versions of some functions. *) + +Definition Z_add_opt := Eval compute in Z.add. +Definition Z_sub_opt := Eval compute in Z.sub. +Definition Z_mul_opt := Eval compute in Z.mul. +Definition Z_div_opt := Eval compute in Z.div. +Definition Z_pow_opt := Eval compute in Z.pow. +Definition Z_opp_opt := Eval compute in Z.opp. +Definition Z_shiftl_opt := Eval compute in Z.shiftl. +Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by. + +Definition nth_default_opt {A} := Eval compute in @nth_default A. +Definition set_nth_opt {A} := Eval compute in @set_nth A. +Definition map_opt {A B} := Eval compute in @map A B. +Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths. + +Definition Let_In {A P} (x : A) (f : forall y : A, P y) + := let y := x in f y. + +(* Some automation that comes in handy when constructing base parameters *) +Ltac opt_step := + match goal with + | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] + => refine (_ : match e with nil => _ | _ => _ end = _); + destruct e + end. + +Ltac brute_force_indices limb_widths := intros; unfold sum_firstn, limb_widths; simpl in *; + repeat match goal with + | _ => progress simpl in * + | _ => reflexivity + | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H + | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x + | [H : (?x < _)%nat |- _ ] => is_var x; destruct x + | _ => omega + end. + + +Definition limb_widths_from_len len k := Eval compute in + (fix loop i prev := + match i with + | O => nil + | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0) + then (k * Z.of_nat (len - i + 1)) / Z.of_nat len + else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in + x - prev:: (loop i' x) + end) len 0. + +Ltac construct_params prime_modulus len k := + let lw := fresh "lw" in set (lw := limb_widths_from_len len k); + cbv in lw; + eapply Build_PseudoMersenneBaseParams with (limb_widths := lw); + [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto) + | abstract (unfold limb_widths; cbv; congruence) + | abstract brute_force_indices lw + | abstract apply prime_modulus + | abstract brute_force_indices lw]. + +Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := + match limb_widths with + | nil => nil + | x :: tail => + 2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail + end. + +Ltac subst_precondition := match goal with + | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H +end. + +Ltac kill_precondition H := + forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; + subst_precondition. + +Ltac compute_formula := + match goal with + | [H : _ -> _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula + | [H : _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula + | [H : @PseudoMersenneBaseRep.rep ?M ?P _ ?result |- @PseudoMersenneBaseRep.rep ?M ?P _ ?result] => + let m := fresh "m" in set (m := M) in H at 1; change M with m at 1; + let p := fresh "p" in set (p := P) in H at 1; change P with p at 1; + let r := fresh "r" in set (r := result) in H |- *; + cbv -[m p r PseudoMersenneBaseRep.rep] in H; + repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H; + exact H + end. + +Section Carries. + Context `{prm : PseudoMersenneBaseParams} + (* allows caller to precompute k and c *) + (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). + + Definition carry_opt_sig + (i : nat) (b : digits) + : { d : digits | (i < length limb_widths)%nat -> d = carry i b }. + Proof. + eexists ; intros. + cbv [carry]. + rewrite <- pull_app_if_sumbool. + cbv beta delta + [carry carry_and_reduce carry_simple add_to_nth log_cap + pow2_mod Z.ones Z.pred base + PseudoMersenneBaseParams.limb_widths]. + change @nth_default with @nth_default_opt in *. + change @set_nth with @set_nth_opt in *. + lazymatch goal with + | [ |- _ = (if ?br then ?c else ?d) ] + => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y + end. + 2:cbv zeta. + 2:break_if; reflexivity. + + change @nth_default with @nth_default_opt. + rewrite c_subst. + change @set_nth with @set_nth_opt. + change @map with @map_opt. + rewrite <- @beq_nat_eq_nat_dec. + change base_from_limb_widths with base_from_limb_widths_opt. + reflexivity. + Defined. + + Definition carry_opt i b + := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b). + + Definition carry_opt_correct i b : (i < length limb_widths)%nat -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b). + + Definition carry_sequence_opt_sig (is : list nat) (us : digits) + : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. + Proof. + eexists. intros H. + cbv [carry_sequence]. + transitivity (fold_right carry_opt us is). + Focus 2. + { induction is; [ reflexivity | ]. + simpl; rewrite IHis, carry_opt_correct. + - reflexivity. + - rewrite base_length in H. + apply H; apply in_eq. + - intros. apply H. right. auto. + } + Unfocus. + reflexivity. + Defined. + + Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in + proj1_sig (carry_sequence_opt_sig is us). + + Definition carry_sequence_opt_correct is us + : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us + := proj2_sig (carry_sequence_opt_sig is us). + + Definition carry_opt_cps_sig + {T} + (i : nat) + (f : digits -> T) + (b : digits) + : { d : T | (i < length base)%nat -> d = f (carry i b) }. + Proof. + eexists. intros H. + rewrite <- carry_opt_correct by (rewrite base_length in H; assumption). + cbv beta iota delta [carry_opt]. + let LHS := match goal with |- ?LHS = ?RHS => LHS end in + let RHS := match goal with |- ?LHS = ?RHS => RHS end in + let RHSf := match (eval pattern (nth_default_opt 0%Z b i) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf). + change Z.shiftl with Z_shiftl_opt. + change (-1) with (Z_opp_opt 1). + change Z.add with Z_add_opt at 8 12 20 24. + reflexivity. + Defined. + + Definition carry_opt_cps {T} i f b + := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b). + + Definition carry_opt_cps_correct {T} i f b : + (i < length base)%nat -> + @carry_opt_cps T i f b = f (carry i b) + := proj2_sig (carry_opt_cps_sig i f b). + + Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits) + : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. + Proof. + eexists. + cbv [carry_sequence]. + transitivity (fold_right carry_opt_cps id (List.rev is) us). + Focus 2. + { + assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { + subst. intros. rewrite <- in_rev in *. auto. } + remember (rev is) as ris eqn:Heq. + rewrite <- (rev_involutive is), <- Heq. + clear H Heq is. + rewrite fold_left_rev_right. + revert us; induction ris; [ reflexivity | ]; intros. + { simpl. + rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. + rewrite carry_opt_cps_correct; [reflexivity|]. + apply Hr; left; reflexivity. + } } + Unfocus. + reflexivity. + Defined. + + Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in + proj1_sig (carry_sequence_opt_cps_sig is us). + + Definition carry_sequence_opt_cps_correct is us + : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us + := proj2_sig (carry_sequence_opt_cps_sig is us). + + + Lemma carry_sequence_opt_cps_rep + : forall (is : list nat) (us : list Z) (x : F modulus), + (forall i : nat, In i is -> i < length base)%nat -> + length us = length base -> + rep us x -> rep (carry_sequence_opt_cps is us) x. + Proof. + intros. + rewrite carry_sequence_opt_cps_correct by assumption. + apply carry_sequence_rep; assumption. + Qed. + +End Carries. + +Section Addition. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}. + + Definition add_opt_sig (us vs : T) : { b : digits | b = add us vs }. + Proof. + eexists. + cbv [BaseSystem.add]. + reflexivity. + Defined. + + Definition add_opt (us vs : T) : digits + := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs). + + Definition add_opt_correct us vs + : add_opt us vs = add us vs + := proj2_sig (add_opt_sig us vs). + + Lemma add_opt_rep: forall (u v : T) (x y : F modulus), + PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y -> + PseudoMersenneBaseRep.rep (add_opt u v) (x + y)%F. + Proof. + intros. + rewrite add_opt_correct. + auto using PseudoMersenneBaseRep.add_rep. + Qed. + +End Addition. + +Section Subtraction. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}. + + Definition sub_opt_sig (us vs : T) : { b : digits | b = sub coeff coeff_mod us vs }. + Proof. + eexists. + cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub]. + reflexivity. + Defined. + + Definition sub_opt (us vs : T) : digits + := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs). + + Definition sub_opt_correct us vs + : sub_opt us vs = sub coeff coeff_mod us vs + := proj2_sig (sub_opt_sig us vs). + + Lemma sub_opt_rep: forall (u v : T) (x y : F modulus), + PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y -> + PseudoMersenneBaseRep.rep (sub_opt u v) (x - y)%F. + Proof. + intros. + rewrite sub_opt_correct. + change (sub coeff coeff_mod) with PseudoMersenneBaseRep.sub. + apply PseudoMersenneBaseRep.sub_rep; auto using coeff_length. + Qed. + +End Subtraction. + +Section Multiplication. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm} + (* allows caller to precompute k and c *) + (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). + Definition mul_bi'_step + (mul_bi' : nat -> digits -> list Z -> list Z) + (i : nat) (vsr : digits) (bs : list Z) + : list Z + := match vsr with + | [] => [] + | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs + end. + + Definition mul_bi'_opt_step_sig + (mul_bi' : nat -> digits -> list Z -> list Z) + (i : nat) (vsr : digits) (bs : list Z) + : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }. + Proof. + eexists. + cbv [mul_bi'_step]. + opt_step. + { reflexivity. } + { cbv [crosscoef ext_base base]. + change Z.div with Z_div_opt. + change Z.mul with Z_mul_opt at 2. + change @nth_default with @nth_default_opt. + reflexivity. } + Defined. + + Definition mul_bi'_opt_step + (mul_bi' : nat -> digits -> list Z -> list Z) + (i : nat) (vsr : digits) (bs : list Z) + : list Z + := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in + proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs). + + Fixpoint mul_bi'_opt + (i : nat) (vsr : digits) (bs : list Z) {struct vsr} + : list Z + := mul_bi'_opt_step mul_bi'_opt i vsr bs. + + Definition mul_bi'_opt_correct + (i : nat) (vsr : digits) (bs : list Z) + : mul_bi'_opt i vsr bs = mul_bi' bs i vsr. + Proof. + revert i; induction vsr as [|vsr vsrs IHvsr]; intros. + { reflexivity. } + { simpl mul_bi'. + rewrite <- IHvsr; clear IHvsr. + unfold mul_bi'_opt, mul_bi'_opt_step. + apply f_equal2; [ | reflexivity ]. + cbv [crosscoef ext_base base]. + change Z.div with Z_div_opt. + change Z.mul with Z_mul_opt at 2. + change @nth_default with @nth_default_opt. + reflexivity. } + Qed. + + Definition mul'_step + (mul' : digits -> digits -> list Z -> digits) + (usr vs : digits) (bs : list Z) + : digits + := match usr with + | [] => [] + | u :: usr' => add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs) + end. + + Lemma map_zeros : forall a n l, + map (Z.mul a) (zeros n ++ l) = zeros n ++ map (Z.mul a) l. + Admitted. + + Definition mul'_opt_step_sig + (mul' : digits -> digits -> list Z -> digits) + (usr vs : digits) (bs : list Z) + : { d : digits | d = mul'_step mul' usr vs bs }. + Proof. + eexists. + cbv [mul'_step]. + match goal with + | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] + => refine (_ : match e with nil => _ | _ => _ end = _); + destruct e + end. + { reflexivity. } + { cbv [mul_each mul_bi]. + rewrite <- mul_bi'_opt_correct. + rewrite map_zeros. + change @map with @map_opt. + cbv [zeros]. + reflexivity. } + Defined. + + Definition mul'_opt_step + (mul' : digits -> digits -> list Z -> digits) + (usr vs : digits) (bs : list Z) + : digits + := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs). + + Fixpoint mul'_opt + (usr vs : digits) (bs : list Z) + : digits + := mul'_opt_step mul'_opt usr vs bs. + + Definition mul'_opt_correct + (usr vs : digits) (bs : list Z) + : mul'_opt usr vs bs = mul' bs usr vs. + Proof. + revert vs; induction usr as [|usr usrs IHusr]; intros. + { reflexivity. } + { simpl. + rewrite <- IHusr; clear IHusr. + apply f_equal2; [ | reflexivity ]. + cbv [mul_each mul_bi]. + rewrite map_zeros. + rewrite <- mul_bi'_opt_correct. + reflexivity. } + Qed. + + Definition mul_opt_sig (us vs : T) : { b : digits | b = mul us vs }. + Proof. + eexists. + cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce]. + rewrite <- mul'_opt_correct. + cbv [base PseudoMersenneBaseParams.limb_widths]. + rewrite map_shiftl by apply k_nonneg. + rewrite c_subst. + rewrite k_subst. + change @map with @map_opt. + change base_from_limb_widths with base_from_limb_widths_opt. + change @Z_shiftl_by with @Z_shiftl_by_opt. + reflexivity. + Defined. + + Definition mul_opt (us vs : T) : digits + := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs). + + Definition mul_opt_correct us vs + : mul_opt us vs = mul us vs + := proj2_sig (mul_opt_sig us vs). + + Lemma mul_opt_rep: + forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y -> + PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F. + Proof. + intros. + rewrite mul_opt_correct. + change mul with PseudoMersenneBaseRep.mul. + auto using PseudoMersenneBaseRep.mul_rep. + Qed. + + Definition carry_mul_opt + (is : list nat) + (us vs : list Z) + : list Z + := carry_sequence_opt_cps c_ is (mul_opt us vs). + + Lemma carry_mul_opt_correct + : forall (is : list nat) (us vs : list Z) (x y: F modulus), + PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y -> + (forall i : nat, In i is -> i < length base)%nat -> + length (mul_opt us vs) = length base -> + PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F. + Proof. + intros is us vs x y; intros. + change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)). + apply carry_sequence_opt_cps_rep, mul_opt_rep; auto. + Qed. +End Multiplication.
\ No newline at end of file diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v new file mode 100644 index 000000000..274acff5a --- /dev/null +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -0,0 +1,1179 @@ +Require Import Zpower ZArith. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import VerdiTactics. +Require Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector. +Local Open Scope Z_scope. + +Section PseudoMersenneProofs. + Context `{prm :PseudoMersenneBaseParams}. + + Local Hint Unfold decode. + Local Notation "u '~=' x" := (rep u x) (at level 70). + Local Notation "u '.+' x" := (add u x) (at level 70). + Local Notation "u '.*' x" := (ModularBaseSystem.mul u x) (at level 70). + Local Hint Unfold rep. + + Lemma rep_decode : forall us x, us ~= x -> decode us = x. + Proof. + autounfold; intuition. + Qed. + + Lemma encode_rep : forall x : F modulus, encode x ~= x. + Proof. + intros. unfold encode, rep. + split. { + unfold encode; simpl. + apply base_length_nonzero. + } { + unfold decode. + rewrite encode_rep. + apply ZToField_FieldToZ. + apply bv. + } + Qed. + + Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> BaseSystem.add u v ~= (x+y)%F. + Proof. + autounfold; intuition. { + unfold add. + rewrite add_length_le_max. + case_max; try rewrite Max.max_r; omega. + } + unfold decode in *; unfold decode in *. + rewrite add_rep. + rewrite ZToField_add. + subst; auto. + Qed. + + Lemma sub_rep : forall c c_0modq, (length c <= length base)%nat -> + forall u v x y, u ~= x -> v ~= y -> + ModularBaseSystem.sub c c_0modq u v ~= (x-y)%F. + Proof. + autounfold; unfold ModularBaseSystem.sub; intuition. { + rewrite sub_length_le_max. + case_max; try rewrite Max.max_r; try omega. + rewrite add_length_le_max. + case_max; try rewrite Max.max_r; omega. + } + unfold decode in *; unfold BaseSystem.decode in *. + rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep. + rewrite ZToField_sub, ZToField_add, ZToField_mod. + rewrite c_0modq, F_add_0_l. + subst; auto. + Qed. + + Lemma decode_short : forall (us : BaseSystem.digits), + (length us <= length base)%nat -> + BaseSystem.decode base us = BaseSystem.decode ext_base us. + Proof. + intros. + unfold BaseSystem.decode, BaseSystem.decode'. + rewrite combine_truncate_r. + rewrite (combine_truncate_r us ext_base). + f_equal; f_equal. + unfold ext_base. + rewrite firstn_app_inleft; auto; omega. + Qed. + + Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), + (length us <= length base)%nat -> + (length vs <= length base)%nat -> + (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs). + Proof. + intros. + rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). + f_equal; rewrite decode_short; auto. + Qed. + + Lemma modulus_nonzero : modulus <> 0. + pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. + Qed. + + (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) + Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. + Proof. + intros. + replace (2^k) with ((2^k - c) + c) by ring. + rewrite Z.mul_add_distr_r. + rewrite Zplus_mod. + unfold c. + rewrite Z.sub_sub_distr, Z.sub_diag. + simpl. + rewrite Z.mul_comm. + rewrite mod_mult_plus; auto using modulus_nonzero. + rewrite <- Zplus_mod; auto. + Qed. + + Lemma extended_shiftadd: forall (us : BaseSystem.digits), + BaseSystem.decode ext_base us = + BaseSystem.decode base (firstn (length base) us) + + (2^k * BaseSystem.decode base (skipn (length base) us)). + Proof. + intros. + unfold BaseSystem.decode; rewrite <- mul_each_rep. + unfold ext_base. + replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. + rewrite base_mul_app. + rewrite <- mul_each_rep; auto. + Qed. + + Lemma reduce_rep : forall us, + BaseSystem.decode base (reduce us) mod modulus = + BaseSystem.decode ext_base us mod modulus. + Proof. + intros. + rewrite extended_shiftadd. + rewrite pseudomersenne_add. + unfold reduce. + remember (firstn (length base) us) as low. + remember (skipn (length base) us) as high. + unfold BaseSystem.decode. + rewrite BaseSystemProofs.add_rep. + replace (map (Z.mul c) high) with (BaseSystem.mul_each c high) by auto. + rewrite mul_each_rep; auto. + Qed. + + Lemma reduce_length : forall us, + (length us <= length ext_base)%nat -> + (length (reduce us) <= length base)%nat. + Proof. + intros. + unfold reduce. + remember (map (Z.mul c) (skipn (length base) us)) as high. + remember (firstn (length base) us) as low. + assert (length low >= length high)%nat. { + subst. rewrite firstn_length. + rewrite map_length. + rewrite skipn_length. + destruct (le_dec (length base) (length us)). { + rewrite Min.min_l by omega. + rewrite extended_base_length in H. omega. + } { + rewrite Min.min_r; omega. + } + } + assert ((length low <= length base)%nat) + by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). + assert (length high <= length base)%nat + by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; + rewrite extended_base_length in H; omega). + rewrite add_trailing_zeros; auto. + rewrite (add_same_length _ _ (length low)); auto. + rewrite app_length. + rewrite length_zeros; intuition. + Qed. + + Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F. + Proof. + autounfold; unfold ModularBaseSystem.mul; intuition. + { + apply reduce_length. + rewrite mul_length, extended_base_length. + omega. + } { + rewrite ZToField_mod, reduce_rep, <-ZToField_mod. + rewrite mul_rep by + (apply ExtBaseVector || rewrite extended_base_length; omega). + subst. + do 2 rewrite decode_short by auto. + apply ZToField_mul. + } + Qed. + + Lemma set_nth_sum : forall n x us, (n < length us)%nat -> + BaseSystem.decode base (set_nth n x us) = + (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. + Proof. + intros. + unfold BaseSystem.decode. + nth_inbounds; auto. (* TODO(andreser): nth_inbounds should do this auto*) + unfold splice_nth. + rewrite <- (firstn_skipn n us) at 4. + do 2 rewrite decode'_splice. + remember (length (firstn n us)) as n0. + ring_simplify. + remember (BaseSystem.decode' (firstn n0 base) (firstn n us)). + rewrite (skipn_nth_default n us 0) by omega. + rewrite firstn_length in Heqn0. + rewrite Min.min_l in Heqn0 by omega; subst n0. + destruct (le_lt_dec (length base) n). { + rewrite nth_default_out_of_bounds by auto. + rewrite skipn_all by omega. + do 2 rewrite decode_base_nil. + ring_simplify; auto. + } { + rewrite (skipn_nth_default n base 0) by omega. + do 2 rewrite decode'_cons. + ring_simplify; ring. + } + Qed. + + Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> + BaseSystem.decode base (add_to_nth n x us) = + x * nth_default 0 base n + BaseSystem.decode base us. + Proof. + unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. + Qed. + + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> + nth_default 0 base i > 0. + Proof. + intros. + pose proof (nth_error_length_exists_value _ _ H). + destruct H0. + pose proof (nth_error_value_In _ _ _ H0). + pose proof (BaseSystem.base_positive _ H1). + unfold nth_default. + rewrite H0; auto. + Qed. + + Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) = nth_default 0 base i * + (nth_default 0 base (S i) / nth_default 0 base i). + Proof. + intros. + apply Z_div_exact_2; try (apply nth_default_base_positive; omega). + apply base_succ; auto. + Qed. + +End PseudoMersenneProofs. + +Section CarryProofs. + Context `{prm : PseudoMersenneBaseParams}. + Local Notation "u '~=' x" := (rep u x) (at level 70). + Hint Unfold log_cap. + + Lemma base_length_lt_pred : (pred (length base) < length base)%nat. + Proof. + pose proof base_length_nonzero; omega. + Qed. + Hint Resolve base_length_lt_pred. + + Lemma log_cap_nonneg : forall i, 0 <= log_cap i. + Proof. + unfold log_cap, nth_default; intros. + case_eq (nth_error limb_widths i); intros; try omega. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> + nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. + Proof. + intros. + repeat rewrite nth_default_base by omega. + rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + destruct (NPeano.Nat.eq_dec i 0). + + subst; f_equal. + unfold sum_firstn, log_cap. + destruct limb_widths; auto. + + erewrite sum_firstn_succ; eauto. + unfold log_cap. + apply nth_error_Some_nth_default. + rewrite <- base_length; omega. + Qed. + + Lemma carry_simple_decode_eq : forall i us, + (length us = length base) -> + (i < (pred (length base)))%nat -> + BaseSystem.decode base (carry_simple i us) = BaseSystem.decode base us. + Proof. + unfold carry_simple. intros. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite nth_default_base_succ by omega. + rewrite Z.mul_assoc. + rewrite (Z.mul_comm _ (2 ^ log_cap i)). + rewrite mul_div_eq; try ring. + apply gt_lt_symmetry. + apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg. + Qed. + + Lemma carry_decode_eq_reduce : forall us, + (length us = length base) -> + BaseSystem.decode base (carry_and_reduce (pred (length base)) us) mod modulus + = BaseSystem.decode base us mod modulus. + Proof. + unfold carry_and_reduce; intros ? length_eq. + pose proof base_length_nonzero. + rewrite add_to_nth_sum by (rewrite length_set_nth; omega). + rewrite set_nth_sum by omega. + rewrite Zplus_comm, <- Z.mul_assoc, <- pseudomersenne_add, BaseSystem.b0_1. + rewrite (Z.mul_comm (2 ^ k)), <- Zred_factor0. + f_equal. + rewrite <- (Z.add_comm (BaseSystem.decode base us)), <- Z.add_assoc, <- Z.add_0_r. + f_equal. + destruct (NPeano.Nat.eq_dec (length base) 0) as [length_zero | length_nonzero]. + + apply length0_nil in length_zero. + pose proof (base_length) as limbs_length. + rewrite length_zero in length_eq, limbs_length. + apply length0_nil in length_eq. + symmetry in limbs_length. + apply length0_nil in limbs_length. + unfold log_cap. + subst; rewrite length_zero, limbs_length, nth_default_nil. + reflexivity. + + rewrite nth_default_base by omega. + rewrite <- Z.add_opp_l, <- Z.opp_sub_distr. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + rewrite Zopp_mult_distr_r. + rewrite Z.mul_comm. + rewrite Z.mul_assoc. + rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + unfold k. + replace (length limb_widths) with (S (pred (length base))) by + (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega). + rewrite sum_firstn_succ with (x:= log_cap (pred (length base))) by + (unfold log_cap; apply nth_error_Some_nth_default; rewrite <- base_length; omega). + rewrite <- Zopp_mult_distr_r. + rewrite Z.mul_comm. + rewrite (Z.add_comm (log_cap (pred (length base)))). + ring. + Qed. + + Lemma carry_length : forall i us, + (length us <= length base)%nat -> + (length (carry i us) <= length base)%nat. + Proof. + unfold carry, carry_simple, carry_and_reduce, add_to_nth. + intros; break_if; subst; repeat (rewrite length_set_nth); auto. + Qed. + Hint Resolve carry_length. + + Lemma carry_rep : forall i us x, + (length us = length base) -> + (i < length base)%nat -> + us ~= x -> carry i us ~= x. + Proof. + pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. + unfold rep, decode, carry in *; intros. + intuition; break_if; subst; eauto; + apply F_eq; simpl; intuition. + Qed. + Hint Resolve carry_rep. + + Lemma carry_sequence_length: forall is us, + (length us <= length base)%nat -> + (length (carry_sequence is us) <= length base)%nat. + Proof. + induction is; boring. + Qed. + Hint Resolve carry_sequence_length. + + Lemma carry_length_exact : forall i us, + (length us = length base)%nat -> + (length (carry i us) = length base)%nat. + Proof. + unfold carry, carry_simple, carry_and_reduce, add_to_nth. + intros; break_if; subst; repeat (rewrite length_set_nth); auto. + Qed. + + Lemma carry_sequence_length_exact: forall is us, + (length us = length base)%nat -> + (length (carry_sequence is us) = length base)%nat. + Proof. + induction is; boring. + apply carry_length_exact; auto. + Qed. + Hint Resolve carry_sequence_length_exact. + + Lemma carry_sequence_rep : forall is us x, + (forall i, In i is -> (i < length base)%nat) -> + (length us = length base) -> + us ~= x -> carry_sequence is us ~= x. + Proof. + induction is; boring. + Qed. + +End CarryProofs. + +Section CanonicalizationProofs. + Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) (c_pos : 0 < c) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B). + + (* TODO : move *) + Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> + nth_default d (set_nth n x l) i = + if (eq_nat_dec i n) then x else nth_default d l i. + Proof. + induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; + destruct i; break_if; try omega; intros; try apply nth_default_cons; + rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. + Qed. + + (* TODO : move *) + Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> + nth_default 0 (add_to_nth n x l) i = + if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Proof. + intros. + unfold add_to_nth. + rewrite set_nth_nth_default by assumption. + break_if; subst; reflexivity. + Qed. + + (* TODO : move *) + Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. + Proof. + unfold add_to_nth; intros; apply length_set_nth. + Qed. + + (* TODO : move *) + Lemma singleton_list : forall {A} (l : list A), length l = 1%nat -> exists x, l = x :: nil. + Proof. + intros; destruct l; simpl in *; try congruence. + eexists; f_equal. + apply length0_nil; omega. + Qed. + + (* BEGIN groundwork proofs *) + + Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i. + Proof. + intros; apply Z.pow_pos_nonneg; auto using log_cap_nonneg; omega. + Qed. + Local Hint Resolve pow_2_log_cap_pos. + + Lemma max_bound_log_cap : forall i, Z.succ (max_bound i) = 2 ^ log_cap i. + Proof. + intros. + unfold max_bound, Z.ones. + rewrite Z.shiftl_1_l. + omega. + Qed. + + Hint Resolve log_cap_nonneg. + Lemma pow2_mod_log_cap_range : forall a i, 0 <= pow2_mod a (log_cap i) <= max_bound i. + Proof. + intros. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + unfold max_bound, Z.ones. + rewrite Z.shiftl_1_l, <-Z.lt_le_pred. + apply Z_mod_lt. + pose proof (pow_2_log_cap_pos i). + omega. + Qed. + + Lemma pow2_mod_log_cap_bounds_lower : forall a i, 0 <= pow2_mod a (log_cap i). + Proof. + intros. + pose proof (pow2_mod_log_cap_range a i); omega. + Qed. + + Lemma pow2_mod_log_cap_bounds_upper : forall a i, pow2_mod a (log_cap i) <= max_bound i. + Proof. + intros. + pose proof (pow2_mod_log_cap_range a i); omega. + Qed. + + Lemma pow2_mod_log_cap_small : forall a i, 0 <= a <= max_bound i -> + pow2_mod a (log_cap i) = a. + Proof. + intros. + unfold pow2_mod. + rewrite Z.land_ones by apply log_cap_nonneg. + apply Z.mod_small. + split; try omega. + rewrite <- max_bound_log_cap. + omega. + Qed. + + Lemma max_bound_nonneg : forall i, 0 <= max_bound i. + Proof. + unfold max_bound; intros; auto using Z_ones_nonneg. + Qed. + Local Hint Resolve max_bound_nonneg. + + Lemma pow2_mod_spec : forall a b, (0 <= b) -> pow2_mod a b = a mod (2 ^ b). + Proof. + intros. + unfold pow2_mod. + rewrite Z.land_ones; auto. + Qed. + + Lemma pow2_mod_upper_bound : forall a b, (0 <= a) -> (0 <= b) -> pow2_mod a b <= a. + Proof. + intros. + unfold pow2_mod. + rewrite Z.land_ones; auto. + apply Z.mod_le; auto. + apply Z.pow_pos_nonneg; omega. + Qed. + + Lemma shiftr_eq_0_max_bound : forall i a, Z.shiftr a (log_cap i) = 0 -> + a <= max_bound i. + Proof. + intros ? ? shiftr_0. + apply Z.shiftr_eq_0_iff in shiftr_0. + intuition; subst; try apply max_bound_nonneg. + match goal with H : Z.log2 _ < log_cap _ |- _ => apply Z.log2_lt_pow2 in H; + replace (2 ^ log_cap i) with (Z.succ (max_bound i)) in H by + (unfold max_bound, Z.ones; rewrite Z.shiftl_1_l; omega) + end; auto. + omega. + Qed. + + Lemma B_compat_log_cap : forall i, 0 <= B - log_cap i. + Proof. + unfold log_cap; intros. + destruct (lt_dec i (length limb_widths)). + + apply Z.le_0_sub. + apply B_compat. + rewrite nth_default_eq. + apply nth_In; assumption. + + replace (nth_default 0 limb_widths i) with 0; try omega. + symmetry; apply nth_default_out_of_bounds. + omega. + Qed. + Local Hint Resolve B_compat_log_cap. + + Lemma max_bound_shiftr_eq_0 : forall i a, 0 <= a -> a <= max_bound i -> + Z.shiftr a (log_cap i) = 0. + Proof. + intros ? ? ? le_max_bound. + apply Z.shiftr_eq_0_iff. + destruct (Z_eq_dec a 0); auto. + right. + split; try omega. + apply Z.log2_lt_pow2; try omega. + rewrite <-max_bound_log_cap. + omega. + Qed. + + (* END groundwork proofs *) + Opaque pow2_mod log_cap max_bound. + + (* automation *) + Ltac carry_length_conditions' := unfold carry_full, add_to_nth; + rewrite ?length_set_nth, ?carry_length_exact, ?carry_sequence_length_exact, ?carry_sequence_length_exact; + try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ]. + Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'. + + Ltac add_set_nth := rewrite ?add_to_nth_nth_default; try solve [carry_length_conditions]; + try break_if; try omega; rewrite ?set_nth_nth_default; try solve [carry_length_conditions]; + try break_if; try omega. + + (* BEGIN defs *) + + Definition c_carry_constraint : Prop := + (c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) + /\ (max_bound 0 + c < 2 ^ (log_cap 0 + 1)) + /\ (c <= max_bound 0 - c). + + Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i < + if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)). + + Lemma pre_carry_bounds_nonzero : forall us, pre_carry_bounds us -> + (forall i, 0 <= nth_default 0 us i). + Proof. + unfold pre_carry_bounds. + intros ? PCB i. + specialize (PCB i). + omega. + Qed. + Hint Resolve pre_carry_bounds_nonzero. + + Definition carry_done us := forall i, (i < length base)%nat -> Z.shiftr (nth_default 0 us i) (log_cap i) = 0. + + Lemma carry_carry_done_done : forall i us, + (length us = length base)%nat -> + (i < length base)%nat -> + (forall i, 0 <= nth_default 0 us i) -> + carry_done us -> carry_done (carry i us). + Proof. + unfold carry_done; intros until 3. intros Hcarry_done ? ?. + unfold carry, carry_simple, carry_and_reduce; break_if; subst. + + rewrite Hcarry_done by omega. + rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + destruct i0; add_set_nth; rewrite ?Z.mul_0_r, ?Z.add_0_l; auto. + match goal with H : S _ = pred (length base) |- _ => rewrite H; auto end. + + rewrite Hcarry_done by omega. + rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + destruct i0; add_set_nth; subst; rewrite ?Z.add_0_l; auto. + Qed. + + (* END defs *) + + (* BEGIN proofs about first carry loop *) + + Lemma nth_default_carry_bound_upper : forall i us, (length us = length base) -> + nth_default 0 (carry i us) i <= max_bound i. + Proof. + unfold carry; intros. + break_if. + + unfold carry_and_reduce. + add_set_nth. + apply pow2_mod_log_cap_bounds_upper. + + unfold carry_simple. + destruct (lt_dec i (length us)). + - add_set_nth. + apply pow2_mod_log_cap_bounds_upper. + - rewrite nth_default_out_of_bounds by carry_length_conditions; auto. + Qed. + + Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) -> + 0 <= nth_default 0 (carry i us) i. + Proof. + unfold carry; intros. + break_if. + + unfold carry_and_reduce. + add_set_nth. + apply pow2_mod_log_cap_bounds_lower. + + unfold carry_simple. + destruct (lt_dec i (length us)). + - add_set_nth. + apply pow2_mod_log_cap_bounds_lower. + - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. + Qed. + + Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) -> + (length us = length base) -> + 0 <= nth_default 0 (carry i us) (S i). + Proof. + unfold carry; intros ? ? PCB ? . + break_if. + + subst. replace (S (pred (length base))) with (length base) by omega. + rewrite nth_default_out_of_bounds; carry_length_conditions. + unfold carry_and_reduce. + add_set_nth. + + unfold carry_simple. + destruct (lt_dec (S i) (length us)). + - add_set_nth. + apply Z.add_nonneg_nonneg; [ apply Z.shiftr_nonneg | ]; unfold pre_carry_bounds in PCB. + * specialize (PCB i). omega. + * specialize (PCB (S i)). omega. + - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. + Qed. + + Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)-> + (length us = length base) -> + nth_default 0 (carry j us) i = nth_default 0 us i. + Proof. + intros. + unfold carry. + break_if. + + unfold carry_and_reduce. + add_set_nth. + + unfold carry_simple. + destruct (lt_dec i (length us)). + - add_set_nth. + - rewrite !nth_default_out_of_bounds by + (omega || rewrite length_add_to_nth; rewrite length_set_nth; pose proof base_length_nonzero; omega). + reflexivity. + Qed. + + Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) -> + nth_default 0 (carry j us) i = nth_default 0 us i. + Proof. + intros. + destruct (lt_dec i (length us)); + [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity]. + unfold carry, carry_simple. + break_if; add_set_nth. + Qed. + + Lemma carry_nothing : forall i j us, (i < length base)%nat -> + (length us = length base)%nat -> + 0 <= nth_default 0 us j <= max_bound j -> + nth_default 0 (carry j us) i = nth_default 0 us i. + Proof. + unfold carry, carry_simple, carry_and_reduce; intros. + break_if; (add_set_nth; + [ rewrite max_bound_shiftr_eq_0 by omega; ring + | subst; apply pow2_mod_log_cap_small; assumption ]). + Qed. + + Lemma carry_bounds_0_upper : forall us j, (length us = length base) -> + (0 < j < length base)%nat -> + nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_bound 0. + Proof. + unfold carry_sequence; induction j; [simpl; intros; omega | ]. + intros. + simpl in *. + destruct (eq_nat_dec 0 j). + + subst. + apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain 0) us); carry_length_conditions. + + rewrite carry_unaffected_low; try omega. + fold (carry_sequence (make_chain j) us); carry_length_conditions. + Qed. + + Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) -> + nth_default 0 (carry_sequence (make_chain j) us) i <= max_bound i. + Proof. + unfold carry_sequence; + induction j; [simpl; intros; omega | ]. + intros. + simpl in *. + assert (i = j \/ i < j)%nat as cases by omega. + destruct cases as [eq_j_i | lt_i_j]; subst. + + apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain j) us); carry_length_conditions. + + rewrite carry_unaffected_low; try omega. + fold (carry_sequence (make_chain j) us); carry_length_conditions. + Qed. + + Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat -> + nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i. + Proof. + induction j; [simpl; intros; omega | ]. + intros. + simpl in *. + rewrite carry_unaffected_high by carry_length_conditions. + apply IHj; omega. + Qed. + + Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) -> + (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. + Proof. + induction j; intros. + + simpl. auto. + + simpl. + destruct (lt_dec (S j) i). + - rewrite carry_unaffected_high by carry_length_conditions. + apply IHj; auto; omega. + - assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega. + destruct cases as [? | [? | ?]]. + * subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions. + intros. + eapply IHj; auto; omega. + * subst. apply nth_default_carry_bound_lower; carry_length_conditions. + * destruct (eq_nat_dec j (pred (length base))); + [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ]. + subst. + unfold carry, carry_and_reduce; break_if; try omega. + add_set_nth; [ | apply IHj; auto; omega ]. + apply Z.add_nonneg_nonneg; [ | apply IHj; auto; omega ]. + apply Z.mul_nonneg_nonneg; try omega. + apply Z.shiftr_nonneg. + apply IHj; auto; omega. + Qed. + + Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) -> + (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. + Proof. + unfold carry_sequence; + induction j; [simpl; intros; omega | ]. + intros. + simpl in *. + destruct (eq_nat_dec i (S j)). + + subst. apply nth_default_carry_bound_succ_lower; auto; + fold (carry_sequence (make_chain j) us); carry_length_conditions. + intros. + apply carry_sequence_bounds_lower; auto; omega. + + assert (i = j \/ i < j)%nat as cases by omega. + destruct cases as [eq_j_i | lt_i_j]; subst; + [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega; + fold (carry_sequence (make_chain j) us); carry_length_conditions. + apply carry_sequence_bounds_lower; auto; omega. + Qed. + + Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) -> + (length us = length base)%nat -> + 0 <= nth_default 0 (carry_full us) i <= max_bound i. + Proof. + unfold carry_full, full_carry_chain; intros. + split; (destruct (lt_dec i (length limb_widths)); + [ | rewrite nth_default_out_of_bounds by carry_length_conditions; omega || auto ]). + + apply carry_bounds_lower; carry_length_conditions. + + apply carry_bounds_upper; carry_length_conditions. + Qed. + + Lemma carry_simple_no_overflow : forall us i, (i < pred (length base))%nat -> + length us = length base -> + 0 <= nth_default 0 us i < 2 ^ B -> + 0 <= nth_default 0 us (S i) < 2 ^ B - 2 ^ (B - log_cap i) -> + 0 <= nth_default 0 (carry i us) (S i) < 2 ^ B. + Proof. + intros. + unfold carry, carry_simple; break_if; try omega. + add_set_nth. + replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega. + split. + + apply Z.add_nonneg_nonneg; try omega. + apply Z.shiftr_nonneg; try omega. + + apply Z.add_lt_mono; try omega. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. + rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). + replace (log_cap i + (B - log_cap i)) with B by ring. + omega. + Qed. + + + Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us -> + (length us = length base) -> + nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B. + Proof. + unfold pre_carry_bounds. + intros ? ? PCB ?. + induction i. + + simpl. specialize (PCB 0%nat). + intuition. + + simpl. + destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ]. + - apply carry_simple_no_overflow; carry_length_conditions. + apply carry_sequence_bounds_lower; carry_length_conditions. + apply carry_sequence_bounds_lower; carry_length_conditions. + rewrite carry_sequence_unaffected; try omega. + specialize (PCB (S i)); rewrite Nat.pred_succ in PCB. + break_if; intuition. + - unfold carry; break_if; try omega. + rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. + subst. + unfold carry_and_reduce. + carry_length_conditions. + - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. + carry_length_conditions. + Qed. + + Lemma carry_full_bounds_0 : forall us, pre_carry_bounds us -> + (length us = length base)%nat -> + 0 <= nth_default 0 (carry_full us) 0 <= max_bound 0 + c * (Z.ones (B - log_cap (pred (length base)))). + Proof. + unfold carry_full, full_carry_chain; intros. + rewrite <- base_length. + replace (length base) with (S (pred (length base))) at 1 2 by omega. + simpl. + unfold carry, carry_and_reduce; break_if; try omega. + add_set_nth. + split. + + apply Z.add_nonneg_nonneg. + - apply Z.mul_nonneg_nonneg; try omega. + apply Z.shiftr_nonneg. + apply carry_sequence_bounds_lower; auto; omega. + - apply carry_sequence_bounds_lower; auto; omega. + + rewrite Z.add_comm. + apply Z.add_le_mono. + - apply carry_bounds_0_upper; auto; omega. + - apply Z.mul_le_mono_pos_l; auto. + apply Z_shiftr_ones; auto; + [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. + split. + * apply carry_bounds_lower; auto; try omega. + * apply carry_sequence_no_overflow; auto. + Qed. + + Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us -> + (length us = length base)%nat -> + 0 <= nth_default 0 (carry_full us) i. + Proof. + destruct i; intros. + + apply carry_full_bounds_0; auto. + + destruct (lt_dec (S i) (length base)). + - apply carry_bounds_lower; carry_length_conditions. + - rewrite nth_default_out_of_bounds; carry_length_conditions. + Qed. + + (* END proofs about first carry loop *) + + (* BEGIN proofs about second carry loop *) + + Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i. + Proof. + induction i; intros; try omega. + simpl. + unfold carry, carry_simple; break_if; try omega. + add_set_nth. + split. + + apply Z.add_nonneg_nonneg. + - apply Z.shiftr_nonneg. + destruct (eq_nat_dec i 0); subst. + * simpl. + apply carry_full_bounds_0; auto. + * apply IHi; auto; omega. + - rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; auto; omega. + + rewrite <-max_bound_log_cap, <-Z.add_1_l. + apply Z.add_le_mono. + - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z_div_floor; auto. + destruct i. + * simpl. + eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. + replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring. + rewrite <-max_bound_log_cap, <-Z.add_1_l. + apply Z.add_lt_le_mono; try omega. + unfold c_carry_constraint in *. + intuition. + * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ]. + apply Z.lt_mul_diag_r; auto; omega. + - rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; auto; omega. + Qed. + + Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (1 < length base)%nat -> + 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_bound 0 + c. + Proof. + intros. + unfold carry_full at 1 3, full_carry_chain. + rewrite <-base_length. + replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega). + simpl. + unfold carry, carry_and_reduce; break_if; try omega. + add_set_nth. + split. + + apply Z.add_nonneg_nonneg. + apply Z.mul_nonneg_nonneg; try omega. + apply Z.shiftr_nonneg. + apply carry_sequence_carry_full_bounds_same; auto; omega. + eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. + intros. + eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + rewrite Z.add_comm. + apply Z.add_le_mono. + - apply carry_bounds_0_upper; carry_length_conditions. + - replace c with (c * 1) at 2 by ring. + apply Z.mul_le_mono_pos_l; try omega. + rewrite Z.shiftr_div_pow2 by auto. + apply Z.div_le_upper_bound; auto. + ring_simplify. + apply carry_sequence_carry_full_bounds_same; auto. + omega. + Qed. + + Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < pred (length base))%nat -> + ((0 < i < length base)%nat -> + 0 <= nth_default 0 + (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= + 2 ^ log_cap i) -> + 0 <= nth_default 0 (carry_simple i + (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i). + Proof. + unfold carry_simple; intros ? ? PCB CCC length_eq ? IH. + add_set_nth. + split. + + apply Z.add_nonneg_nonneg. + apply Z.shiftr_nonneg. + destruct i; + [ simpl; pose proof (carry_full_2_bounds_0 us PCB CCC length_eq); omega | ]. + - assert (0 < S i < length base)%nat as IHpre by omega. + specialize (IH IHpre). + omega. + - rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; carry_length_conditions. + intros. + apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + rewrite <-max_bound_log_cap, <-Z.add_1_l. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z.add_le_mono. + - apply Z.div_le_upper_bound; auto. + ring_simplify. apply IH. omega. + - rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; carry_length_conditions. + intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + Qed. + + Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i. + Proof. + intros; induction i; try omega. + simpl; unfold carry. + break_if; try omega. + split; (destruct (eq_nat_dec i 0); subst; + [ cbv [make_chain carry_sequence fold_right carry_simple]; add_set_nth + | eapply carry_full_2_bounds_succ; eauto; omega]). + + apply Z.add_nonneg_nonneg. + apply Z.shiftr_nonneg. + eapply carry_full_2_bounds_0; eauto. + eapply carry_full_bounds; eauto; carry_length_conditions. + intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + rewrite <-max_bound_log_cap, <-Z.add_1_l. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z.add_le_mono. + - apply Z_div_floor; auto. + eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. + replace (Z.succ 1) with (2 ^ 1) by ring. + rewrite <-Z.pow_add_r by (omega || auto). + unfold c_carry_constraint in *. + intuition. + - apply carry_full_bounds; carry_length_conditions. + intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + Qed. + + Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_bound i. + Proof. + induction j; intros; try omega. + split; (destruct j; [ rewrite Nat.add_1_r; simpl + | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]). + + apply nth_default_carry_bound_lower; carry_length_conditions. + + apply nth_default_carry_bound_upper; carry_length_conditions. + Qed. + + Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat -> + 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_bound i. + Proof. + intros. + replace j with (i + (j - i))%nat by omega. + eapply carry_full_2_bounds'; eauto; omega. + Qed. + + Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> + (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0). + Proof. + induction i; try omega. + intros ? ? length_eq ?; simpl. + destruct i. + + unfold carry. + break_if; + [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ]. + simpl. + unfold carry_simple. + add_set_nth. + apply pow2_mod_log_cap_bounds_lower. + + rewrite carry_unaffected_low by carry_length_conditions. + assert (0 < S i < length base)%nat by omega. + intuition. + Qed. + + Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> + 0 <= nth_default 0 (carry_full (carry_full us)) i. + Proof. + intros. + destruct i. + + apply carry_full_2_bounds_0; auto. + + apply carry_full_bounds; try solve [carry_length_conditions]. + intro j. + destruct j. + - apply carry_full_bounds_0; auto. + - apply carry_full_bounds; carry_length_conditions. + Qed. + + Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat -> (0 < i < length base)%nat -> + (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_bound 0 - c) + \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))). + Proof. + induction i; try omega. + intros ? ? length_eq ?; simpl. + destruct i. + + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_bound 0)). + - right. + unfold carry_done. + intros. + apply max_bound_shiftr_eq_0; simpl; rewrite carry_nothing; try solve [carry_length_conditions]. + * apply carry_full_2_bounds_lower; auto. + * split; try apply carry_full_2_bounds_lower; auto. + * destruct i; auto. + apply carry_full_bounds; try solve [carry_length_conditions]. + auto using carry_full_bounds_lower. + * split; auto. + apply carry_full_2_bounds_lower; auto. + - unfold carry. + break_if; + [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ]. + simpl. + unfold carry_simple. + add_set_nth. left. + remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x. + apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)). + * replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. + rewrite pow2_mod_spec by auto. + rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). + rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small. + apply Z.sub_le_mono_r. + subst; apply carry_full_2_bounds_0; auto. + split; try omega. + pose proof carry_full_2_bounds_0. + apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); + [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; + ring_simplify; unfold c_carry_constraint in *; omega | ]. + ring_simplify; unfold c_carry_constraint in *; omega. + * ring_simplify; unfold c_carry_constraint in *; omega. + + rewrite carry_unaffected_low by carry_length_conditions. + assert (0 < S i < length base)%nat by omega. + intuition. + right. + apply carry_carry_done_done; try solve [carry_length_conditions]. + intro j. + destruct j. + - apply carry_carry_full_2_bounds_0_lower; auto. + - destruct (lt_eq_lt_dec j i) as [[? | ?] | ?]. + * apply carry_full_2_bounds; auto; omega. + * subst. apply carry_full_2_bounds_same; auto; omega. + * rewrite carry_sequence_unaffected; try solve [carry_length_conditions]. + apply carry_full_2_bounds_lower; auto; omega. + Qed. + + (* END proofs about second carry loop *) + + (* BEGIN proofs about third carry loop *) + + Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> c_carry_constraint -> + (length us = length base)%nat ->(i < length base)%nat -> + 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_bound i. + Proof. + intros. + destruct i; [ | apply carry_full_bounds; carry_length_conditions; + do 2 (intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions) ]. + unfold carry_full at 1 4, full_carry_chain. + case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ]. + simpl. + intros ? ? limb_widths_eq. + replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto). + rewrite <- base_length. + unfold carry, carry_and_reduce; break_if; try omega; intros. + add_set_nth. + split. + + apply Z.add_nonneg_nonneg. + - apply Z.mul_nonneg_nonneg; auto; try omega. + apply Z.shiftr_nonneg. + eapply carry_full_2_bounds_same; eauto; omega. + - eapply carry_carry_full_2_bounds_0_lower; eauto; omega. + + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))). + assert (0 < pred (length base) < length base)%nat by omega. + intuition. + - replace (max_bound 0) with (c + (max_bound 0 - c)) by ring. + apply Z.add_le_mono; try assumption. + replace c with (c * 1) at 2 by ring. + apply Z.mul_le_mono_pos_l; try omega. + rewrite Z.shiftr_div_pow2 by auto. + apply Z.div_le_upper_bound; auto. + ring_simplify. + apply carry_full_2_bounds_same; auto. + - match goal with H : carry_done _ |- _ => unfold carry_done in H; rewrite H by omega end. + ring_simplify. + apply shiftr_eq_0_max_bound; auto; omega. + Qed. + + Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x -> + nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x'). + Admitted. + + Lemma nth_error_range : forall {A} i (l : list A), (i < length l)%nat -> + nth_error (range (length l)) i = Some i. + Admitted. + + (* END proofs about third carry loop *) + Opaque carry_full. + + Lemma freeze_in_bounds : forall us i, (us <> nil)%nat -> + 0 <= nth_default 0 (freeze us) i < 2 ^ log_cap i. + Proof. + Admitted. + + Lemma freeze_canonical : forall us vs x, rep us x -> rep vs x -> + freeze us = freeze vs. + Admitted. + +End CanonicalizationProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v index 77d84c455..70a2c4a87 100644 --- a/src/ModularArithmetic/PrimeFieldTheorems.v +++ b/src/ModularArithmetic/PrimeFieldTheorems.v @@ -9,6 +9,7 @@ Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid. Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.ZArith.ZArith Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *) Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Util.NumTheoryUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics. Existing Class prime. @@ -67,7 +68,7 @@ Module FieldModulo (Import M : PrimeModulus). postprocess [Fpostprocess], constants [Fconstant], div morph_div_theory_modulo, - power_tac power_theory_modulo [Fexp_tac]). + power_tac power_theory_modulo [Fexp_tac]). End FieldModulo. Section VariousModPrime. @@ -79,8 +80,8 @@ Section VariousModPrime. postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], constants [Fconstant], div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). - + power_tac (@Fpower_theory q) [Fexp_tac]). + Lemma Fq_mul_eq : forall x y z : F q, z <> 0 -> x * z = y * z -> x = y. Proof. intros ? ? ? z_nonzero mul_z_eq. @@ -119,47 +120,26 @@ Section VariousModPrime. eauto using Fq_inv_unique'. Qed. - Let inv_fermat_powmod (x:Z) : Z := powmod q x (Z.to_N (q-2)). - Lemma FieldToZ_inv_efficient : 2 < q -> - forall x : F q, FieldToZ (inv x) = inv_fermat_powmod x. - Proof. - intros. - rewrite (fun pf => Fq_inv_unique (fun x : F q => ZToField (inv_fermat_powmod (FieldToZ x))) pf); - subst inv_fermat_powmod; intuition; rewrite powmod_Zpow_mod; - replace (Z.of_N (Z.to_N (q - 2))) with (q-2)%Z by (rewrite Z2N.id ; omega). - - (* inv in range *) rewrite FieldToZ_ZToField, Zmod_mod; reflexivity. - - (* inv 0 *) replace (FieldToZ 0) with 0%Z by auto. - rewrite Z.pow_0_l by omega. - rewrite Zmod_0_l; trivial. - - (* inv nonzero *) rewrite <- (fermat_inv q _ x0) by - (rewrite mod_FieldToZ; eauto using FieldToZ_nonzero). - rewrite <-(ZToField_FieldToZ x0). - rewrite <-ZToField_mul. - rewrite ZToField_FieldToZ. - apply ZToField_eqmod. - demod; reflexivity. - Qed. - Lemma Fq_mul_zero_why : forall a b : F q, a*b = 0 -> a = 0 \/ b = 0. intros. assert (Z := F_eq_dec a 0); destruct Z. - + - left; intuition. - + - assert (a * b / a = 0) by ( rewrite H; field; intuition ). - + replace (a*b/a) with (b) in H0 by (field; trivial). right; intuition. Qed. - + Lemma Fq_mul_nonzero_nonzero : forall a b : F q, a <> 0 -> b <> 0 -> a*b <> 0. intros; intuition; subst. apply Fq_mul_zero_why in H1. destruct H1; subst; intuition. Qed. Hint Resolve Fq_mul_nonzero_nonzero. - + Lemma Fq_pow_zero : forall (p: N), p <> 0%N -> (0^p = @ZToField q 0)%F. induction p using N.peano_ind; rewrite <-?N.add_1_l, ?(proj2 (@F_pow_spec q _) _), ?(proj1 (@F_pow_spec q _)). @@ -193,6 +173,43 @@ Section VariousModPrime. + apply IHp; auto. Qed. + Lemma F_inv_0 : inv 0 = (0 : F q). + Proof. + destruct (@F_inv_spec q); auto. + Qed. + + Definition inv_fermat (x:F q) : F q := x ^ Z.to_N (q - 2)%Z. + Lemma Fq_inv_fermat: 2 < q -> forall x : F q, inv x = x ^ Z.to_N (q - 2)%Z. + Proof. + intros. + erewrite (Fq_inv_unique inv_fermat); trivial; split; intros; unfold inv_fermat. + { replace 2%N with (Z.to_N (Z.of_N 2))%N by auto. + rewrite Fq_pow_zero. reflexivity. intro. + assert (Z.of_N (Z.to_N (q-2)) = 0%Z) by (rewrite H0; eauto); rewrite Z2N.id in *; omega. } + { clear x. rename x0 into x. pose proof (fermat_inv _ _ x) as Hf; forward Hf. + { pose proof @ZToField_FieldToZ; pose proof @ZToField_mod; congruence. } + specialize (Hf H1); clear H1. + rewrite <-(ZToField_FieldToZ x). + rewrite ZToField_pow. + replace (Z.of_N (Z.to_N (q - 2))) with (q-2)%Z by (rewrite Z2N.id ; omega). + rewrite <-ZToField_mul. + apply ZToField_eqmod. + rewrite Hf, Zmod_small by omega; reflexivity. + } + Qed. + Lemma Fq_inv_fermat_correct : 2 < q -> forall x : F q, inv_fermat x = inv x. + Proof. + unfold inv_fermat; intros. rewrite Fq_inv_fermat; auto. + Qed. + + Let inv_fermat_powmod (x:Z) : Z := powmod q x (Z.to_N (q-2)). + Lemma FieldToZ_inv_efficient : 2 < q -> + forall x : F q, FieldToZ (inv x) = inv_fermat_powmod x. + Proof. + unfold inv_fermat_powmod; intros. + rewrite Fq_inv_fermat, powmod_Zpow_mod, <-FieldToZ_pow_Zpow_mod; auto. + Qed. + Lemma F_minus_swap : forall x y : F q, x - y = 0 -> y - x = 0. Proof. intros ? ? eq_zero. @@ -232,7 +249,7 @@ Section VariousModPrime. left. eapply Fq_square_mul; eauto. instantiate (1 := x). - assert (x ^ 2 = z * y ^ 2 - x ^ 2 + x ^ 2) as plus_minus_x2 by + assert (x ^ 2 = z * y ^ 2 - x ^ 2 + x ^ 2) as plus_minus_x2 by (rewrite <- eq_zero_sub; ring). rewrite plus_minus_x2; ring. } @@ -248,6 +265,11 @@ Section VariousModPrime. intros; field. (* TODO: Warning: Collision between bound variables ... *) Qed. + Lemma F_div_opp_1 : forall x y : F q, (opp x / y = opp (x / y))%F. + Proof. + intros; destruct (F_eq_dec y 0); [subst;unfold div;rewrite !F_inv_0|]; field. + Qed. + Lemma F_eq_opp_zero : forall x : F q, 2 < q -> (x = opp x <-> x = 0). Proof. split; intro A. @@ -352,6 +374,22 @@ Section VariousModPrime. Proof. econstructor; eauto using Fq_mul_zero_why, Fq_1_neq_0. Qed. + + Lemma add_cancel_mul_r_nonzero {y : F q} (H : y <> 0) (x z : F q) + : x * y + z = (x + (z * (inv y))) * y. + Proof. field. Qed. + + Lemma sub_cancel_mul_r_nonzero {y : F q} (H : y <> 0) (x z : F q) + : x * y - z = (x - (z * (inv y))) * y. + Proof. field. Qed. + + Lemma add_cancel_l_nonzero {y : F q} (H : y <> 0) (z : F q) + : y + z = (1 + (z * (inv y))) * y. + Proof. field. Qed. + + Lemma sub_cancel_l_nonzero {y : F q} (H : y <> 0) (z : F q) + : y - z = (1 - (z * (inv y))) * y. + Proof. field. Qed. End VariousModPrime. Section SquareRootsPrime5Mod8. @@ -367,7 +405,7 @@ Section SquareRootsPrime5Mod8. postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], constants [Fconstant], div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + power_tac (@Fpower_theory q) [Fexp_tac]). (* This is only the square root of -1 if q mod 8 is 3 or 5 *) Definition sqrt_minus1 : F q := ZToField 2 ^ Z.to_N (q / 4). @@ -382,7 +420,7 @@ Section SquareRootsPrime5Mod8. Qed. (* square root mod q relies on the fact that q is 5 mod 8 *) - Definition sqrt_mod_q (a : F q) := + Definition sqrt_mod_q (a : F q) := let b := a ^ Z.to_N (q / 8 + 1) in (match (F_eq_dec (b ^ 2) a) with | left A => b @@ -445,4 +483,85 @@ Section SquareRootsPrime5Mod8. field. Qed. -End SquareRootsPrime5Mod8.
\ No newline at end of file + Lemma sqrt_mod_q_of_0 : sqrt_mod_q 0 = 0. + Proof. + unfold sqrt_mod_q. + rewrite !Fq_pow_zero. + break_if; ring. + + congruence. + intro false_eq. + rewrite <-(N2Z.id 0) in false_eq. + rewrite N2Z.inj_0 in false_eq. + pose proof (prime_ge_2 q prime_q). + apply Z2N.inj in false_eq; zero_bounds. + assert (0 < q / 8 + 1)%Z. + apply Z.add_nonneg_pos; zero_bounds. + omega. + Qed. + + Lemma sqrt_mod_q_root_0 : forall x : F q, sqrt_mod_q x = 0 -> x = 0. + Proof. + unfold sqrt_mod_q; intros. + break_if. + - match goal with [ H : ?sqrt_x ^ 2 = x, H' : ?sqrt_x = 0 |- _ ] => rewrite <-H, H' end. + ring. + - match goal with + | [H : sqrt_minus1 * _ = 0 |- _ ]=> + apply Fq_mul_zero_why in H; destruct H as [sqrt_minus1_zero | ? ]; + [ | eapply Fq_root_zero; eauto ] + end. + unfold sqrt_minus1 in sqrt_minus1_zero. + rewrite sqrt_minus1_zero in sqrt_minus1_valid. + exfalso. + pose proof (@F_opp_spec q 1) as opp_spec_1. + rewrite <-sqrt_minus1_valid in opp_spec_1. + assert (((1 + 0 ^ 2) : F q) = (1 : F q)) as ring_subst by ring. + rewrite ring_subst in *. + apply Fq_1_neq_0; assumption. + Qed. + +End SquareRootsPrime5Mod8. + +Local Open Scope F_scope. +(** Tactics for solving inequalities. *) +Ltac solve_cancel_by_field y tnz := + solve [ generalize dependent y; intros; + field; tnz ]. + +Ltac cancel_nonzero_factors' tnz := + idtac; + match goal with + | [ |- ?x = 0 -> False ] + => change (x <> 0) + | [ |- ?x * ?y <> 0 ] + => apply Fq_mul_nonzero_nonzero + | [ H : ?y <> 0 |- _ ] + => progress rewrite ?(add_cancel_mul_r_nonzero H), ?(sub_cancel_mul_r_nonzero H), ?(add_cancel_l_nonzero H), ?(sub_cancel_l_nonzero H); + apply Fq_mul_nonzero_nonzero; [ | assumption ] + | [ |- ?op (?y * (ZToField (m := ?q) ?n)) ?z <> 0 ] + => unique assert (ZToField (m := q) n <> 0) by tnz; + generalize dependent (ZToField (m := q) n); intros + | [ |- ?op (?x * (?y * ?z)) _ <> 0 ] + => rewrite F_mul_assoc + end. +Ltac cancel_nonzero_factors tnz := repeat cancel_nonzero_factors' tnz. +Ltac specialize_quantified_equalities := + repeat match goal with + | [ H : forall _ _ _ _, _ = _ -> _, H' : _ = _ |- _ ] + => unique pose proof (fun x2 y2 => H _ _ x2 y2 H') + | [ H : forall _ _, _ = _ -> _, H' : _ = _ |- _ ] + => unique pose proof (H _ _ H') + end. +Ltac finish_inequality tnz := + idtac; + match goal with + | [ H : ?x <> 0 |- ?y <> 0 ] + => replace y with x by (field; tnz); exact H + end. +Ltac field_nonzero tnz := + cancel_nonzero_factors tnz; + try (specialize_quantified_equalities; + progress cancel_nonzero_factors tnz); + try solve [ specialize_quantified_equalities; + finish_inequality tnz ]. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v new file mode 100644 index 000000000..1a7b3316e --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -0,0 +1,246 @@ +Require Import Zpower ZArith. +Require Import List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import VerdiTactics. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section PseudoMersenneBaseParamProofs. + Context `{prm : PseudoMersenneBaseParams}. + + Fixpoint base_from_limb_widths limb_widths := + match limb_widths with + | nil => nil + | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw) + end. + + Definition base := base_from_limb_widths limb_widths. + + Lemma base_length : length base = length limb_widths. + Proof. + unfold base. + induction limb_widths; try reflexivity. + simpl; rewrite map_length; auto. + Qed. + + Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b -> + a = b. + Proof. + intros; simpl in *. + unfold value in *. + congruence. + Qed. + + Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat -> + nth_error base i = Some b -> + nth_error limb_widths i = Some w -> + nth_error base (S i) = Some (two_p w * b). + Proof. + unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; + unfold base_from_limb_widths in *; fold base_from_limb_widths in *; + [rewrite (@nil_length0 Z) in *; omega | ]. + simpl in *; rewrite map_length in *. + case_eq i; intros; subst. + + subst; apply nth_error_first in nth_err_w. + apply nth_error_first in nth_err_b; subst. + apply map_nth_error. + case_eq l; intros; subst; [simpl in *; omega | ]. + unfold base_from_limb_widths; fold base_from_limb_widths. + reflexivity. + + simpl in nth_err_w. + apply nth_error_map in nth_err_w. + destruct nth_err_w as [x [A B]]. + subst. + replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring. + apply map_nth_error. + apply IHl; auto; omega. + Qed. + + Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x), + exists l', l = x :: l'. + Proof. + induction l; try discriminate; eexists. + apply nth_error_first in H. + subst; eauto. + Qed. + + Lemma sum_firstn_succ : forall l i x, + nth_error l i = Some x -> + sum_firstn l (S i) = x + sum_firstn l i. + Proof. + unfold sum_firstn; induction l; + [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ]. + intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0). + + subst; simpl in *; unfold value in *. + congruence. + + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto. + rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto. + simpl. simpl in nth_err_x. + specialize (IHl (pred i) x). + rewrite NPeano.Nat.succ_pred in IHl by auto. + destruct (NPeano.Nat.eq_dec (pred i) 0). + - replace i with 1%nat in * by omega. + simpl. replace (pred 1) with 0%nat in * by auto. + apply nth_error_exists_first in nth_err_x. + destruct nth_err_x as [l' ?]. + subst; simpl; ring. + - rewrite IHl by auto; ring. + Qed. + + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. + Proof. + unfold sum_firstn; intros. + apply fold_right_invariant; try omega. + intros y In_y_lw ? ?. + apply Z.add_nonneg_nonneg; try assumption. + apply limb_widths_nonneg. + eapply In_firstn; eauto. + Qed. + + Lemma k_nonneg : 0 <= k. + Proof. + apply sum_firstn_limb_widths_nonneg. + Qed. + + Lemma nth_error_base : forall i, (i < length base)%nat -> + nth_error base i = Some (two_p (sum_firstn limb_widths i)). + Proof. + induction i; intros. + + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. + intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. + + + assert (i < length base)%nat as lt_i_length by omega. + specialize (IHi lt_i_length). + rewrite base_length in lt_i_length. + destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. + erewrite base_from_limb_widths_step; eauto. + f_equal. + simpl. + destruct (NPeano.Nat.eq_dec i 0). + - subst; unfold sum_firstn; simpl. + apply nth_error_exists_first in nth_err_w. + destruct nth_err_w as [l' lw_destruct]; subst. + rewrite lw_destruct. + ring_simplify. + f_equal; simpl; ring. + - erewrite sum_firstn_succ; eauto. + symmetry. + apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + Lemma nth_default_base : forall d i, (i < length base)%nat -> + nth_default d base i = 2 ^ (sum_firstn limb_widths i). + Proof. + intros ? ? i_lt_length. + destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x]. + unfold nth_default. + rewrite nth_err_x. + rewrite nth_error_base in nth_err_x by assumption. + rewrite two_p_correct in nth_err_x. + congruence. + Qed. + + Lemma base_matches_modulus: forall i j, + (i < length base)%nat -> + (j < length base)%nat -> + (i+j >= length base)%nat-> + let b := nth_default 0 base in + let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in + b i * b j = r * (2^k * b (i+j-length base)%nat). + Proof. + intros. + rewrite (Z.mul_comm r). + subst r. + assert (i + j - length base < length base)%nat by omega. + rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos; + [ | subst b; rewrite nth_default_base; try assumption ]; + apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg). + rewrite (Zminus_0_l_reverse (b i * b j)) at 1. + f_equal. + subst b. + repeat rewrite nth_default_base by assumption. + do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg). + symmetry. + apply mod_same_pow. + split. + + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg. + + rewrite base_length in *; apply limb_widths_match_modulus; assumption. + Qed. + + Lemma base_succ : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) mod nth_default 0 base i = 0. + Proof. + intros. + repeat rewrite nth_default_base by omega. + apply mod_same_pow. + split; [apply sum_firstn_limb_widths_nonneg | ]. + destruct (NPeano.Nat.eq_dec i 0); subst. + + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq. + apply Z.add_nonneg_nonneg; try omega. + apply limb_widths_nonneg. + rewrite lw_eq. + apply in_eq. + + assert (i < length base)%nat as i_lt_length by omega. + rewrite base_length in *. + apply nth_error_length_exists_value in i_lt_length. + destruct i_lt_length as [x nth_err_x]. + erewrite sum_firstn_succ; eauto. + apply nth_error_value_In in nth_err_x. + apply limb_widths_nonneg in nth_err_x. + omega. + Qed. + + Lemma nth_error_subst : forall i b, nth_error base i = Some b -> + b = 2 ^ (sum_firstn limb_widths i). + Proof. + intros i b nth_err_b. + pose proof (nth_error_value_length _ _ _ _ nth_err_b). + rewrite nth_error_base in nth_err_b by assumption. + rewrite two_p_correct in nth_err_b. + congruence. + Qed. + + Lemma base_positive : forall b : Z, In b base -> b > 0. + Proof. + intros b In_b_base. + apply In_nth_error_value in In_b_base. + destruct In_b_base as [i nth_err_b]. + apply nth_error_subst in nth_err_b. + rewrite nth_err_b. + apply gt_lt_symmetry. + apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg. + Qed. + + Lemma b0_1 : forall x : Z, nth_default x base 0 = 1. + Proof. + unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity]. + Qed. + + Lemma base_good : forall i j : nat, + (i + j < length base)%nat -> + let b := nth_default 0 base in + let r := b i * b j / b (i + j)%nat in + b i * b j = r * b (i + j)%nat. + Proof. + intros; subst b r. + repeat rewrite nth_default_base by omega. + rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))). + rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg). + rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg. + rewrite mod_same_pow; try ring. + split; [ apply sum_firstn_limb_widths_nonneg | ]. + apply limb_widths_good. + rewrite <- base_length; assumption. + Qed. + + Global Instance bv : BaseSystem.BaseVector base := { + base_positive := base_positive; + b0_1 := b0_1; + base_good := base_good + }. + +End PseudoMersenneBaseParamProofs. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v new file mode 100644 index 000000000..3914d6219 --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v @@ -0,0 +1,24 @@ +Require Import ZArith. +Require Import List. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Definition sum_firstn l n := fold_right Z.add 0 (firstn n l). + +Class PseudoMersenneBaseParams (modulus : Z) := { + limb_widths : list Z; + limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w; + limb_widths_nonnil : limb_widths <> nil; + limb_widths_good : forall i j, (i + j < length limb_widths)%nat -> + sum_firstn limb_widths (i + j) <= + sum_firstn limb_widths i + sum_firstn limb_widths j; + prime_modulus : Znumtheory.prime modulus; + k := sum_firstn limb_widths (length limb_widths); + c := 2 ^ k - modulus; + limb_widths_match_modulus : forall i j, + (i < length limb_widths)%nat -> + (j < length limb_widths)%nat -> + (i + j >= length limb_widths)%nat -> + let w_sum := sum_firstn limb_widths in + k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j +}. diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v new file mode 100644 index 000000000..c16cc8d38 --- /dev/null +++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v @@ -0,0 +1,50 @@ +Require Import ZArith. +Require Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Local Open Scope Z_scope. + +Class RepZMod (modulus : Z) := { + T : Type; + encode : F modulus -> T; + decode : T -> F modulus; + + rep : T -> F modulus -> Prop; + encode_rep : forall x, rep (encode x) x; + rep_decode : forall u x, rep u x -> decode u = x; + + add : T -> T -> T; + add_rep : forall u v x y, rep u x -> rep v y -> rep (add u v) (x+y)%F; + + sub : T -> T -> T; + sub_rep : forall u v x y, rep u x -> rep v y -> rep (sub u v) (x-y)%F; + + mul : T -> T -> T; + mul_rep : forall u v x y, rep u x -> rep v y -> rep (mul u v) (x*y)%F +}. + +Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := { + coeff : BaseSystem.digits; + coeff_length : (length coeff <= length PseudoMersenneBaseParamProofs.base)%nat; + coeff_mod: (BaseSystem.decode PseudoMersenneBaseParamProofs.base coeff) mod m = 0 +}. + +Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) (sc : SubtractionCoefficient m prm) +: RepZMod m := { + T := list Z; + encode := ModularBaseSystem.encode; + decode := ModularBaseSystem.decode; + + rep := ModularBaseSystem.rep; + encode_rep := ModularBaseSystemProofs.encode_rep; + rep_decode := ModularBaseSystemProofs.rep_decode; + + add := BaseSystem.add; + add_rep := ModularBaseSystemProofs.add_rep; + + sub := ModularBaseSystem.sub coeff coeff_mod; + sub_rep := ModularBaseSystemProofs.sub_rep coeff coeff_mod coeff_length; + + mul := ModularBaseSystem.mul; + mul_rep := ModularBaseSystemProofs.mul_rep +}. diff --git a/src/Rep.v b/src/Rep.v new file mode 100644 index 000000000..b7e7f10c5 --- /dev/null +++ b/src/Rep.v @@ -0,0 +1,13 @@ +Class RepConversions (T:Type) (RT:Type) : Type := + { + toRep : T -> RT; + unRep : RT -> T + }. + +Definition RepConversionsOK {T RT} (RC:RepConversions T RT) := forall x, unRep (toRep x) = x. + +Definition RepFunOK {T RT} `(RC:RepConversions T RT) (f:T->T) (rf : RT -> RT) := + forall x, f (unRep x) = unRep (rf x). + +Definition RepBinOpOK {T RT} `(RC:RepConversions T RT) (op:T->T->T) (rop : RT -> RT -> RT) := + forall x y, op (unRep x) (unRep y) = unRep (rop x y). diff --git a/src/Spec/CompleteEdwardsCurve.v b/src/Spec/CompleteEdwardsCurve.v index b7d2c0d8e..3348be1d9 100644 --- a/src/Spec/CompleteEdwardsCurve.v +++ b/src/Spec/CompleteEdwardsCurve.v @@ -16,41 +16,39 @@ Class TwistedEdwardsParams := { nonsquare_d : forall x, x^2 <> d }. -Section TwistedEdwardsCurves. - Context {prm:TwistedEdwardsParams}. - - (* Twisted Edwards curves with complete addition laws. References: - * <https://eprint.iacr.org/2008/013.pdf> - * <http://ed25519.cr.yp.to/ed25519-20110926.pdf> - * <https://eprint.iacr.org/2015/677.pdf> - *) - Definition onCurve P := let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2. - Definition point := { P | onCurve P}. - Definition mkPoint (xy:F q * F q) (pf:onCurve xy) : point := exist onCurve xy pf. - - Definition zero : point := mkPoint (0, 1) (@Pre.zeroOnCurve _ _ _ prime_q). +Module E. + Section TwistedEdwardsCurves. + Context {prm:TwistedEdwardsParams}. + + (* Twisted Edwards curves with complete addition laws. References: + * <https://eprint.iacr.org/2008/013.pdf> + * <http://ed25519.cr.yp.to/ed25519-20110926.pdf> + * <https://eprint.iacr.org/2015/677.pdf> + *) + Definition onCurve P := let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2. + Definition point := { P | onCurve P}. + + Definition zero : point := exist _ (0, 1) (@Pre.zeroOnCurve _ _ _ prime_q). + + Definition add' P1' P2' := + let '(x1, y1) := P1' in + let '(x2, y2) := P2' in + (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))). + + Definition add (P1 P2 : point) : point := + let 'exist P1' pf1 := P1 in + let 'exist P2' pf2 := P2 in + exist _ (add' P1' P2') + (@Pre.unifiedAdd'_onCurve _ _ _ prime_q two_lt_q nonzero_a square_a nonsquare_d _ _ pf1 pf2). + + Fixpoint mul (n:nat) (P : point) : point := + match n with + | O => zero + | S n' => add P (mul n' P) + end. + End TwistedEdwardsCurves. +End E. - Definition unifiedAdd' P1' P2' := - let '(x1, y1) := P1' in - let '(x2, y2) := P2' in - (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))). - - Definition unifiedAdd (P1 P2 : point) : point := - let 'exist P1' pf1 := P1 in - let 'exist P2' pf2 := P2 in - mkPoint (unifiedAdd' P1' P2') - (@Pre.unifiedAdd'_onCurve _ _ _ prime_q two_lt_q nonzero_a square_a nonsquare_d _ _ pf1 pf2). - - Fixpoint scalarMult (n:nat) (P : point) : point := - match n with - | O => zero - | S n' => unifiedAdd P (scalarMult n' P) - end. - - Axiom point_eq_dec : forall P Q : point, {P = Q} + {P <> Q}. -End TwistedEdwardsCurves. - Delimit Scope E_scope with E. -Infix "+" := unifiedAdd : E_scope. -Infix "*" := scalarMult : E_scope. -Infix "==" := point_eq_dec (no associativity, at level 70) : E_scope. +Infix "+" := E.add : E_scope. +Infix "*" := E.mul : E_scope.
\ No newline at end of file diff --git a/src/Spec/Ed25519.v b/src/Spec/Ed25519.v index 6ab47e8e5..4876bb8d1 100644 --- a/src/Spec/Ed25519.v +++ b/src/Spec/Ed25519.v @@ -1,6 +1,7 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. Require Import Coq.Numbers.Natural.Peano.NPeano Coq.NArith.NArith. -Require Import Crypto.Spec.Encoding Crypto.Spec.PointEncoding. +Require Import Crypto.Spec.PointEncoding Crypto.Spec.ModularWordEncoding. +Require Import Crypto.Encoding.ModularWordEncodingTheorems. Require Import Crypto.Spec.EdDSA. Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems. @@ -114,8 +115,10 @@ Proof. compute; omega. Qed. +Require Import Crypto.Spec.Encoding. + Lemma q_pos : (0 < q)%Z. q_bound. Qed. -Definition FqEncoding : encoding of (F q) as word (b-1) := +Definition FqEncoding : canonical encoding of (F q) as word (b-1) := @modular_word_encoding q (b - 1) q_pos b_valid. Lemma l_pos : (0 < Z.of_nat l)%Z. pose proof prime_l; prime_bound. Qed. @@ -127,7 +130,7 @@ Proof. unfold l. apply Z2Nat.inj_lt; compute; congruence. Qed. -Definition FlEncoding : encoding of F (Z.of_nat l) as word b := +Definition FlEncoding : canonical encoding of F (Z.of_nat l) as word b := @modular_word_encoding (Z.of_nat l) b l_pos l_bound. Lemma q_5mod8 : (q mod 8 = 5)%Z. cbv; reflexivity. Qed. @@ -140,12 +143,14 @@ Proof. reflexivity. Qed. -Definition PointEncoding := @point_encoding curve25519params (b - 1) FqEncoding q_5mod8 sqrt_minus1_valid. +Definition PointEncoding : canonical encoding of E.point as (word b) := + (@point_encoding curve25519params (b - 1) q_5mod8 sqrt_minus1_valid FqEncoding sign_bit + (@sign_bit_zero _ prime_q two_lt_q _ b_valid) (@sign_bit_opp _ prime_q two_lt_q _ b_valid)). Definition H : forall n : nat, word n -> word (b + b). Admitted. -Definition B : point. Admitted. (* TODO: B = decodePoint (y=4/5, x="positive") *) -Definition B_nonzero : B <> zero. Admitted. -Definition l_order_B : scalarMult l B = zero. Admitted. +Definition B : E.point. Admitted. (* TODO: B = decodePoint (y=4/5, x="positive") *) +Definition B_nonzero : B <> E.zero. Admitted. +Definition l_order_B : (l * B)%E = E.zero. Admitted. Local Instance ed25519params : EdDSAParams := { E := curve25519params; diff --git a/src/Spec/EdDSA.v b/src/Spec/EdDSA.v index 3decae6a7..99f0766e0 100644 --- a/src/Spec/EdDSA.v +++ b/src/Spec/EdDSA.v @@ -6,6 +6,7 @@ Require Import Crypto.Util.WordUtil. Require Bedrock.Word. Require Coq.ZArith.Znumtheory Coq.ZArith.BinInt. Require Coq.Numbers.Natural.Peano.NPeano. +Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. Coercion Word.wordToNat : Word.word >-> nat. @@ -20,8 +21,8 @@ Section EdDSAParams. b : nat; (* public keys are k bits, signatures are 2*k bits *) b_valid : 2^(b - 1) > BinInt.Z.to_nat q; - FqEncoding : encoding of F q as Word.word (b-1); - PointEncoding : encoding of point as Word.word b; + FqEncoding : canonical encoding of F q as Word.word (b-1); + PointEncoding : canonical encoding of E.point as Word.word b; H : forall {n}, Word.word n -> Word.word (b + b); (* main hash function *) @@ -32,14 +33,14 @@ Section EdDSAParams. n_ge_c : n >= c; n_le_b : n <= b; - B : point; - B_not_identity : B <> zero; + B : E.point; + B_not_identity : B <> E.zero; l : nat; (* order of the subgroup of E generated by B *) l_prime : Znumtheory.prime (BinInt.Z.of_nat l); l_odd : l > 2; - l_order_B : (l*B)%E = zero; - FlEncoding : encoding of F (BinInt.Z.of_nat l) as Word.word b + l_order_B : (l*B)%E = E.zero; + FlEncoding : canonical encoding of F (BinInt.Z.of_nat l) as Word.word b }. End EdDSAParams. @@ -54,6 +55,7 @@ Section EdDSA. Notation secretkey := (Word.word b) (only parsing). Notation publickey := (Word.word b) (only parsing). Notation signature := (Word.word (b + b)) (only parsing). + Local Infix "==" := CompleteEdwardsCurveTheorems.E.point_eq_dec (at level 70) : E_scope . (* TODO: proofread curveKey and definition of n *) Definition curveKey (sk:secretkey) : nat := @@ -65,7 +67,7 @@ Section EdDSA. Definition sign (A_:publickey) sk {n} (M : Word.word n) := let r : nat := H (prngKey sk ++ M) in (* secret nonce *) - let R : point := (r * B)%E in (* commitment to nonce *) + let R : E.point := (r * B)%E in (* commitment to nonce *) let s : nat := curveKey sk in (* secret scalar *) let S : F (BinInt.Z.of_nat l) := ZToField (BinInt.Z.of_nat (r + H (enc R ++ public sk ++ M) * s)) in @@ -75,11 +77,11 @@ Section EdDSA. let R_ := Word.split1 b b sig in let S_ := Word.split2 b b sig in match dec S_ : option (F (BinInt.Z.of_nat l)) with None => false | Some S' => - match dec A_ : option point with None => false | Some A => - match dec R_ : option point with None => false | Some R => + match dec A_ : option E.point with None => false | Some A => + match dec R_ : option E.point with None => false | Some R => if BinInt.Z.to_nat (FieldToZ S') * B == R + (H (R_ ++ A_ ++ M)) * A then true else false end end end%E. -End EdDSA. +End EdDSA.
\ No newline at end of file diff --git a/src/Spec/Encoding.v b/src/Spec/Encoding.v index 14cf9d9d9..b063b638f 100644 --- a/src/Spec/Encoding.v +++ b/src/Spec/Encoding.v @@ -1,61 +1,8 @@ -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems. -Require Import Bedrock.Word. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.NatUtil. -Require Import Crypto.Util.WordUtil. - -Class Encoding (T B:Type) := { +Class CanonicalEncoding (T B:Type) := { enc : T -> B ; dec : B -> option T ; - encoding_valid : forall x, dec (enc x) = Some x + encoding_valid : forall x, dec (enc x) = Some x ; + encoding_canonical : forall x_enc x, dec x_enc = Some x -> enc x = x_enc }. -Notation "'encoding' 'of' T 'as' B" := (Encoding T B) (at level 50). - -Local Open Scope nat_scope. - -Section ModularWordEncoding. - Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}. - - Definition Fm_enc (x : F m) : word sz := natToWord sz (Z.to_nat (FieldToZ x)). - - Definition Fm_dec (x_ : word sz) : option (F m) := - let z := Z.of_nat (wordToNat (x_)) in - if Z_lt_dec z m - then Some (ZToField z) - else None - . - - Lemma bound_check_N : forall x : F m, (N.of_nat (Z.to_nat x) < Npow2 sz)%N. - Proof. - intro. - pose proof (FieldToZ_range x m_pos) as x_range. - rewrite <- Nnat.N2Nat.id. - rewrite Npow2_nat. - apply (Nat2N_inj_lt (Z.to_nat x) (pow2 sz)). - rewrite Zpow_pow2. - destruct x_range as [x_low x_high]. - apply Z2Nat.inj_lt in x_high; try omega. - rewrite <- ZUtil.pow_Z2N_Zpow by omega. - replace (Z.to_nat 2) with 2%nat by auto. - omega. - Qed. - - Lemma Fm_encoding_valid : forall x, Fm_dec (Fm_enc x) = Some x. - Proof. - unfold Fm_dec, Fm_enc; intros. - pose proof (FieldToZ_range x m_pos). - rewrite wordToNat_natToWord_idempotent by apply bound_check_N. - rewrite Z2Nat.id by omega. - rewrite ZToField_idempotent. - break_if; auto; omega. - Qed. - - Instance modular_word_encoding : encoding of F m as word sz := { - enc := Fm_enc; - dec := Fm_dec; - encoding_valid := Fm_encoding_valid - }. -End ModularWordEncoding. +Notation "'canonical' 'encoding' 'of' T 'as' B" := (CanonicalEncoding T B) (at level 50).
\ No newline at end of file diff --git a/src/Spec/ModularWordEncoding.v b/src/Spec/ModularWordEncoding.v new file mode 100644 index 000000000..d6f6bcb3c --- /dev/null +++ b/src/Spec/ModularWordEncoding.v @@ -0,0 +1,40 @@ +Require Import Coq.ZArith.ZArith. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.NatUtil. +Require Import Crypto.Util.WordUtil. +Require Import Crypto.Spec.Encoding. +Require Crypto.Encoding.ModularWordEncodingPre. + +Local Open Scope nat_scope. + +Section ModularWordEncoding. + Context {m : Z} {sz : nat} {m_pos : (0 < m)%Z} {bound_check : Z.to_nat m < 2 ^ sz}. + + Definition Fm_enc (x : F m) : word sz := NToWord sz (Z.to_N (FieldToZ x)). + + Definition Fm_dec (x_ : word sz) : option (F m) := + let z := Z.of_N (wordToN (x_)) in + if Z_lt_dec z m + then Some (ZToField z) + else None + . + + Definition sign_bit (x : F m) := + match (Fm_enc x) with + | Word.WO => false + | Word.WS b _ w' => b + end. + + Instance modular_word_encoding : canonical encoding of F m as word sz := { + enc := Fm_enc; + dec := Fm_dec; + encoding_valid := + @ModularWordEncodingPre.Fm_encoding_valid m sz m_pos bound_check; + encoding_canonical := + @ModularWordEncodingPre.Fm_encoding_canonical m sz bound_check + }. + +End ModularWordEncoding. diff --git a/src/Spec/PointEncoding.v b/src/Spec/PointEncoding.v index 4823ef28f..f4634f52f 100644 --- a/src/Spec/PointEncoding.v +++ b/src/Spec/PointEncoding.v @@ -1,175 +1,47 @@ -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Import Coq.Numbers.Natural.Peano.NPeano Coq.NArith.NArith. -Require Import Crypto.Spec.Encoding Crypto.Encoding.EncodingTheorems. -Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems. -Require Import Crypto.Util.NatUtil Crypto.Util.ZUtil Crypto.Util.NumTheoryUtil. -Require Import Bedrock.Word. -Require Import Crypto.Tactics.VerdiTactics. +Require Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Coq.Numbers.Natural.Peano.NPeano. +Require Crypto.Encoding.EncodingTheorems. +Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Bedrock.Word. +Require Crypto.Tactics.VerdiTactics. +Require Crypto.Encoding.PointEncodingPre. +Obligation Tactic := eauto; exact PointEncodingPre.point_encoding_canonical. + +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding. +Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Spec.ModularArithmetic. Local Open Scope F_scope. Section PointEncoding. - Context {prm:TwistedEdwardsParams} {sz : nat} {FqEncoding : encoding of F q as word sz} {q_5mod8 : q mod 8 = 5} {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1}. + Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat} + {bound_check : (BinInt.Z.to_nat q < NPeano.Nat.pow 2 sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z} + {sqrt_minus1_valid : (@ZToField q 2 ^ BinInt.Z.to_N (q / 4)) ^ 2 = opp 1} + {FqEncoding : canonical encoding of (F q) as (Word.word sz)} + {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false} + {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}. Existing Instance prime_q. - Add Field Ffield : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + Definition point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in + Word.WS (sign_bit x) (enc y). - Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F. + Program Definition point_dec_with_spec : + {point_dec : Word.word (S sz) -> option E.point + | forall w x, point_dec w = Some x -> (point_enc x = w) + } := @PointEncodingPre.point_dec _ _ _ sign_bit. - Lemma solve_sqrt_valid : forall (p : point), - sqrt_valid (solve_for_x2 (snd (proj1_sig p))). - Proof. - intros. - destruct p as [[x y] onCurve_xy]; simpl. - rewrite (solve_correct x y) in onCurve_xy. - rewrite <- onCurve_xy. - unfold sqrt_valid. - eapply sqrt_mod_q_valid; eauto. - unfold isSquare; eauto. - Grab Existential Variables. eauto. - Qed. + Definition point_dec := Eval hnf in (proj1_sig point_dec_with_spec). - Lemma solve_onCurve: forall (y : F q), sqrt_valid (solve_for_x2 y) -> - onCurve (sqrt_mod_q (solve_for_x2 y), y). - Proof. - intros. - unfold sqrt_valid in *. - apply solve_correct; auto. - Qed. + Definition point_encoding_valid : forall p : E.point, point_dec (point_enc p) = Some p := + @PointEncodingPre.point_encoding_valid _ _ q_5mod8 sqrt_minus1_valid _ _ sign_bit_zero sign_bit_opp. - Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (solve_for_x2 y) -> - onCurve (opp (sqrt_mod_q (solve_for_x2 y)), y). - Proof. - intros y sqrt_valid_x2. - unfold sqrt_valid in *. - apply solve_correct. - rewrite <- sqrt_valid_x2 at 2. - ring. - Qed. + Definition point_encoding_canonical : forall x_enc x, point_dec x_enc = Some x -> point_enc x = x_enc := + PointEncodingPre.point_encoding_canonical. -Definition sign_bit (x : F q) := (wordToN (enc (opp x)) <? wordToN (enc x))%N. -Definition point_enc (p : point) : word (S sz) := let '(x,y) := proj1_sig p in - WS (sign_bit x) (enc y). -Definition point_dec (w : word (S sz)) : option point := - match dec (wtl w) with - | None => None - | Some y => let x2 := solve_for_x2 y in - let x := sqrt_mod_q x2 in - match (F_eq_dec (x ^ 2) x2) with - | right _ => None - | left EQ => if Bool.eqb (whd w) (sign_bit x) - then Some (mkPoint (x, y) (solve_onCurve y EQ)) - else Some (mkPoint (opp x, y) (solve_opp_onCurve y EQ)) - end - end. - -Lemma y_decode : forall p, dec (wtl (point_enc p)) = Some (snd (proj1_sig p)). -Proof. - intros. - destruct p as [[x y] onCurve_p]; simpl. - exact (encoding_valid y). -Qed. - - -Lemma wordToN_enc_neq_opp : forall x, x <> 0 -> (wordToN (enc (opp x)) <> wordToN (enc x))%N. -Proof. - intros x x_nonzero. - intro false_eq. - apply x_nonzero. - apply F_eq_opp_zero; try apply two_lt_q. - apply wordToN_inj in false_eq. - apply encoding_inj in false_eq. - auto. -Qed. - -Lemma sign_bit_opp_negb : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x). -Proof. - intros x x_nonzero. - unfold sign_bit. - rewrite <- N.leb_antisym. - rewrite N.ltb_compare, N.leb_compare. - rewrite F_opp_involutive. - case_eq (wordToN (enc x) ?= wordToN (enc (opp x)))%N; auto. - intro wordToN_enc_eq. - pose proof (wordToN_enc_neq_opp x x_nonzero). - apply N.compare_eq_iff in wordToN_enc_eq. - congruence. -Qed. - -Lemma sign_bit_opp : forall x y, y <> 0 -> - (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)). -Proof. - split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y); - try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp_negb in * by auto; - rewrite y_sign, x_sign in *; reflexivity || discriminate. -Qed. - -Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 -> - sign_bit x = sign_bit y -> x = y. -Proof. - intros ? ? y_nonzero squares_eq sign_match. - destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto. - assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto). - apply sign_bit_opp in sign_mismatch; auto. - congruence. -Qed. - -Lemma sign_bit_match : forall x x' y : F q, onCurve (x, y) -> onCurve (x', y) -> - sign_bit x = sign_bit x' -> x = x'. -Proof. - intros ? ? ? onCurve_x onCurve_x' sign_match. - apply solve_correct in onCurve_x. - apply solve_correct in onCurve_x'. - destruct (F_eq_dec x' 0). - + subst. - rewrite Fq_pow_zero in onCurve_x' by congruence. - rewrite <- onCurve_x' in *. - eapply Fq_root_zero; eauto. - + apply sign_bit_squares; auto. - rewrite onCurve_x, onCurve_x'. - reflexivity. -Qed. - -Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p. -Proof. - intros. - unfold point_dec. - rewrite y_decode. - pose proof solve_sqrt_valid p as solve_sqrt_valid_p. - unfold sqrt_valid in *. - destruct p as [[x y] onCurve_p]. - simpl in *. - destruct (F_eq_dec ((sqrt_mod_q (solve_for_x2 y)) ^ 2) (solve_for_x2 y)); intuition. - break_if; f_equal; apply point_eq. - + rewrite Bool.eqb_true_iff in Heqb. - pose proof (solve_onCurve y solve_sqrt_valid_p). - f_equal. - apply (sign_bit_match _ _ y); auto. - + rewrite Bool.eqb_false_iff in Heqb. - pose proof (solve_opp_onCurve y solve_sqrt_valid_p). - f_equal. - apply sign_bit_opp in Heqb. - apply (sign_bit_match _ _ y); auto. - intro eq_zero. - apply solve_correct in onCurve_p. - rewrite eq_zero in *. - rewrite Fq_pow_zero in solve_sqrt_valid_p by congruence. - rewrite <- solve_sqrt_valid_p in onCurve_p. - apply Fq_root_zero in onCurve_p. - rewrite onCurve_p in Heqb; auto. -Qed. - -Instance point_encoding : encoding of point as (word (S sz)) := { - enc := point_enc; - dec := point_dec; - encoding_valid := point_encoding_valid -}. - -End PointEncoding. + Instance point_encoding : canonical encoding of E.point as (Word.word (S sz)) := { + enc := point_enc; + dec := point_dec; + encoding_valid := point_encoding_valid; + encoding_canonical := point_encoding_canonical + }. +End PointEncoding.
\ No newline at end of file diff --git a/src/Specific/Ed25519.v b/src/Specific/Ed25519.v index a705ceb90..3b90b5cdf 100644 --- a/src/Specific/Ed25519.v +++ b/src/Specific/Ed25519.v @@ -2,29 +2,64 @@ Require Import Bedrock.Word. Require Import Crypto.Spec.Ed25519. Require Import Crypto.Tactics.VerdiTactics. Require Import BinNat BinInt NArith Crypto.Spec.ModularArithmetic. +Require Import ModularArithmetic.ModularArithmeticTheorems. +Require Import ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.Spec.CompleteEdwardsCurve. -Require Import Crypto.Spec.Encoding Crypto.Spec.PointEncoding. +Require Import Crypto.Encoding.PointEncodingPre. +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.PointEncoding. Require Import Crypto.CompleteEdwardsCurve.ExtendedCoordinates. -Require Import Crypto.Util.IterAssocOp. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Crypto.Util.IterAssocOp Crypto.Util.WordUtil Crypto.Rep. Local Infix "++" := Word.combine. Local Notation " a '[:' i ']' " := (Word.split1 i _ a) (at level 40). Local Notation " a '[' i ':]' " := (Word.split2 i _ a) (at level 40). Local Arguments H {_} _. -Local Arguments scalarMultM1 {_} {_} _ _. +Local Arguments scalarMultM1 {_} {_} _ _ _. Local Arguments unifiedAddM1 {_} {_} _ _. -Ltac set_evars := +Local Ltac set_evars := repeat match goal with | [ |- appcontext[?E] ] => is_evar E; let e := fresh "e" in set (e := E) end. -Ltac subst_evars := +Local Ltac subst_evars := repeat match goal with | [ e := ?E |- _ ] => is_evar E; subst e end. -Axiom point_eqb : forall {prm : TwistedEdwardsParams}, point -> point -> bool. -Axiom point_eqb_correct : forall P Q, point_eqb P Q = if point_eq_dec P Q then true else false. +Lemma funexp_proj {T T'} (proj : T -> T') (f : T -> T) (f' : T' -> T') x n + (f_proj : forall a, proj (f a) = f' (proj a)) + : proj (funexp f x n) = funexp f' (proj x) n. +Proof. + revert x; induction n as [|n IHn]; simpl; congruence. +Qed. + +Lemma iter_op_proj {T T' S} (proj : T -> T') (op : T -> T -> T) (op' : T' -> T' -> T') x y z + (testbit : S -> nat -> bool) (bound : nat) + (op_proj : forall a b, proj (op a b) = op' (proj a) (proj b)) + : proj (iter_op op x testbit y z bound) = iter_op op' (proj x) testbit y (proj z) bound. +Proof. + unfold iter_op. + simpl. + lazymatch goal with + | [ |- ?proj (snd (funexp ?f ?x ?n)) = snd (funexp ?f' _ ?n) ] + => pose proof (fun x0 x1 => funexp_proj (fun x => (fst x, proj (snd x))) f f' (x0, x1)) as H' + end. + simpl in H'. + rewrite <- H'. + { reflexivity. } + { intros [??]; simpl. + repeat match goal with + | [ |- context[match ?n with _ => _ end] ] + => destruct n eqn:? + | _ => progress simpl + | _ => progress subst + | _ => reflexivity + | _ => rewrite op_proj + end. } +Qed. + +Lemma B_proj : proj1_sig B = (fst(proj1_sig B), snd(proj1_sig B)). destruct B as [[]]; reflexivity. Qed. Require Import Coq.Setoids.Setoid. Require Import Coq.Classes.Morphisms. @@ -52,21 +87,43 @@ Axiom decode_scalar : word b -> option N. Local Existing Instance Ed25519.FlEncoding. Axiom decode_scalar_correct : forall x, decode_scalar x = option_map (fun x : F (Z.of_nat Ed25519.l) => Z.to_N x) (dec x). -Local Infix "==?" := point_eqb (at level 70) : E_scope. +Local Infix "==?" := E.point_eqb (at level 70) : E_scope. +Local Infix "==?" := ModularArithmeticTheorems.F_eq_dec (at level 70) : F_scope. -Axiom negate : point -> point. -Definition point_sub P Q := (P + negate Q)%E. -Infix "-" := point_sub : E_scope. -Axiom solve_for_R : forall A B C, (A ==? B + C)%E = (B ==? A - C)%E. +Lemma solve_for_R_eq : forall A B C, (A = B + C <-> B = A - C)%E. +Proof. + intros; split; intros; subst; unfold E.sub; + rewrite <-E.add_assoc, ?E.add_opp_r, ?E.add_opp_l, E.add_0_r; reflexivity. +Qed. -Axiom negateExtended : extendedPoint -> extendedPoint. -Axiom negateExtended_correct : forall P, negate (unExtendedPoint P) = unExtendedPoint (negateExtended P). +Lemma solve_for_R : forall A B C, (A ==? B + C)%E = (B ==? A - C)%E. +Proof. + intros. + repeat match goal with |- context [(?P ==? ?Q)%E] => + let H := fresh "H" in + destruct (E.point_eq_dec P Q) as [H|H]; + (rewrite (E.point_eqb_complete _ _ H) || rewrite (E.point_eqb_neq_complete _ _ H)) + end; rewrite solve_for_R_eq in H; congruence. +Qed. + +Local Notation "'(' X ',' Y ',' Z ',' T ')'" := (mkExtended X Y Z T). +Local Notation "2" := (ZToField 2) : F_scope. Local Existing Instance PointEncoding. -Axiom decode_point_eq : forall (P_ Q_ : word (S (b-1))) (P Q:point), dec P_ = Some P -> dec Q_ = Some Q -> weqb P_ Q_ = (P ==? Q)%E. +Lemma decode_point_eq : forall (P_ Q_ : word (S (b-1))) (P Q:E.point), + dec P_ = Some P -> + dec Q_ = Some Q -> + weqb P_ Q_ = (P ==? Q)%E. +Proof. + intros. + replace P_ with (enc P) in * by (auto using encoding_canonical). + replace Q_ with (enc Q) in * by (auto using encoding_canonical). + rewrite E.point_eqb_correct. + edestruct E.point_eq_dec; (apply weqb_true_iff || apply weqb_false_iff); congruence. +Qed. -Lemma decode_test_encode_test : forall S_ X, option_rect (fun _ : option point => bool) - (fun S : point => (S ==? X)%E) false (dec S_) = weqb S_ (enc X). +Lemma decode_test_encode_test : forall S_ X, option_rect (fun _ : option E.point => bool) + (fun S : E.point => (S ==? X)%E) false (dec S_) = weqb S_ (enc X). Proof. intros. destruct (dec S_) eqn:H. @@ -76,80 +133,449 @@ Proof. apply weqb_true_iff in Heqb. subst. rewrite encoding_valid in H; discriminate. } Qed. -Lemma sharper_verify : forall pk l msg sig, { verify | verify = ed25519_verify pk l msg sig}. +Definition enc' : F q * F q -> word b. Proof. - eexists; intros. - cbv [ed25519_verify EdDSA.verify - ed25519params curve25519params - EdDSA.E EdDSA.B EdDSA.b EdDSA.l EdDSA.H - EdDSA.PointEncoding EdDSA.FlEncoding EdDSA.FqEncoding]. - - etransitivity. - Focus 2. - { repeat match goal with - | [ |- ?x = ?x ] => reflexivity - | [ |- _ = match ?a with None => ?N | Some x => @?S x end :> ?T ] - => etransitivity; - [ - | refine (_ : option_rect (fun _ => T) _ N a = _); - let S' := match goal with |- option_rect _ ?S' _ _ = _ => S' end in - refine (option_rect (fun a' => option_rect (fun _ => T) S' N a' = match a' with None => N | Some x => S x end) - (fun x => _) _ a); intros; simpl @option_rect ]; - [ reflexivity | .. ] - end. - set_evars. - rewrite<- point_eqb_correct. - rename x0 into R. rename x1 into S. rename x into A. - rewrite solve_for_R. - let p1 := constr:(scalarMultM1_rep eq_refl) in - let p2 := constr:(unifiedAddM1_rep eq_refl) in - repeat match goal with - | |- context [(_ * ?P)%E] => - rewrite <-(unExtendedPoint_mkExtendedPoint P); - rewrite <-p1 - | |- context [(?P + unExtendedPoint _)%E] => - rewrite <-(unExtendedPoint_mkExtendedPoint P); - rewrite p2 - end; - rewrite ?Znat.Z_nat_N, <-?Word.wordToN_nat; - subst_evars; - reflexivity. - } Unfocus. + intro x. + let enc' := (eval hnf in (@enc (@E.point curve25519params) _ _)) in + match (eval cbv [proj1_sig] in (fun pf => enc' (exist _ x pf))) with + | (fun _ => ?enc') => exact enc' + end. +Defined. - etransitivity. - Focus 2. - { lazymatch goal with |- _ = option_rect _ _ ?false ?dec => - symmetry; etransitivity; [|eapply (option_rect_option_map (fun (x:F _) => Z.to_N x) _ false dec)] - end. - eapply option_rect_Proper_nd; [intro|reflexivity..]. +Definition enc'_correct : @enc (@E.point curve25519params) _ _ = (fun x => enc' (proj1_sig x)) + := eq_refl. + +Definition Let_In {A P} (x : A) (f : forall a : A, P a) : P x := let y := x in f y. +Global Instance Let_In_Proper_nd {A P} + : Proper (eq ==> pointwise_relation _ eq ==> eq) (@Let_In A (fun _ => P)). +Proof. + lazy; intros; congruence. +Qed. +Lemma option_rect_function {A B C S' N' v} f + : f (option_rect (fun _ : option A => option B) S' N' v) + = option_rect (fun _ : option A => C) (fun x => f (S' x)) (f N') v. +Proof. destruct v; reflexivity. Qed. +Local Ltac commute_option_rect_Let_In := (* pull let binders out side of option_rect pattern matching *) + idtac; + lazymatch goal with + | [ |- ?LHS = option_rect ?P ?S ?N (Let_In ?x ?f) ] + => (* we want to just do a [change] here, but unification is stupid, so we have to tell it what to unfold in what order *) + cut (LHS = Let_In x (fun y => option_rect P S N (f y))); cbv beta; + [ set_evars; + let H := fresh in + intro H; + rewrite H; + clear; + abstract (cbv [Let_In]; reflexivity) + | ] + end. +Local Ltac replace_let_in_with_Let_In := + repeat match goal with + | [ |- context G[let x := ?y in @?z x] ] + => let G' := context G[Let_In y z] in change G' + | [ |- _ = Let_In _ _ ] + => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ] + end. +Local Ltac simpl_option_rect := (* deal with [option_rect _ _ _ None] and [option_rect _ _ _ (Some _)] *) + repeat match goal with + | [ |- context[option_rect ?P ?S ?N None] ] + => change (option_rect P S N None) with N + | [ |- context[option_rect ?P ?S ?N (Some ?x) ] ] + => change (option_rect P S N (Some x)) with (S x); cbv beta + end. + +Section Ed25519Frep. + Generalizable All Variables. + Context `(rcS:RepConversions N SRep) (rcSOK:RepConversionsOK rcS). + Context `(rcF:RepConversions (F (Ed25519.q)) FRep) (rcFOK:RepConversionsOK rcF). + Context (FRepAdd FRepSub FRepMul:FRep->FRep->FRep) (FRepAdd_correct:RepBinOpOK rcF add FRepMul). + Context (FRepSub_correct:RepBinOpOK rcF sub FRepSub) (FRepMul_correct:RepBinOpOK rcF mul FRepMul). + Local Notation rep2F := (unRep : FRep -> F (Ed25519.q)). + Local Notation F2Rep := (toRep : F (Ed25519.q) -> FRep). + Local Notation rep2S := (unRep : SRep -> N). + Local Notation S2Rep := (toRep : N -> SRep). + + Axiom FRepOpp : FRep -> FRep. + Axiom FRepOpp_correct : forall x, opp (rep2F x) = rep2F (FRepOpp x). + + Axiom wltu : forall {b}, word b -> word b -> bool. + Axiom wltu_correct : forall {b} (x y:word b), wltu x y = (wordToN x <? wordToN y)%N. + + Axiom compare_enc : forall x y, F_eqb x y = weqb (@enc _ _ FqEncoding x) (@enc _ _ FqEncoding y). + + Axiom wire2FRep : word (b-1) -> option FRep. + Axiom wire2FRep_correct : forall x, Fm_dec x = option_map rep2F (wire2FRep x). + + Axiom FRep2wire : FRep -> word (b-1). + Axiom FRep2wire_correct : forall x, FRep2wire x = @enc _ _ FqEncoding (rep2F x). + + Axiom SRep_testbit : SRep -> nat -> bool. + Axiom SRep_testbit_correct : forall (x0 : SRep) (i : nat), SRep_testbit x0 i = N.testbit_nat (unRep x0) i. + + Definition FSRepPow x n := iter_op FRepMul (toRep 1%F) SRep_testbit n x 255. + Lemma FSRepPow_correct : forall x n, (N.size_nat (unRep n) <= 255)%nat -> (unRep x ^ unRep n)%F = unRep (FSRepPow x n). + Proof. (* this proof derives the required formula, which I copy-pasted above to be able to reference it without the length precondition *) + unfold FSRepPow; intros. + erewrite <-pow_nat_iter_op_correct by auto. + erewrite <-(fun x => iter_op_spec (scalar := SRep) (mul (m:=Ed25519.q)) F_mul_assoc _ F_mul_1_l _ unRep SRep_testbit_correct n x 255%nat) by auto. + rewrite <-(rcFOK 1%F) at 1. + erewrite <-iter_op_proj by auto. + reflexivity. + Qed. + + Definition FRepInv x : FRep := FSRepPow x (S2Rep (Z.to_N (Ed25519.q - 2))). + Lemma FRepInv_correct : forall x, inv (rep2F x)%F = rep2F (FRepInv x). + unfold FRepInv; intros. + rewrite <-FSRepPow_correct; rewrite rcSOK; try reflexivity. + pose proof @Fq_inv_fermat_correct as H; unfold inv_fermat in H; rewrite H by + auto using Ed25519.prime_q, Ed25519.two_lt_q. + reflexivity. + Qed. + + Lemma unfoldDiv : forall {m} (x y:F m), (x/y = x * inv y)%F. Proof. unfold div. congruence. Qed. + + Definition rep2E (r:FRep * FRep * FRep * FRep) : extended := + match r with (((x, y), z), t) => mkExtended (rep2F x) (rep2F y) (rep2F z) (rep2F t) end. + + Lemma if_map : forall {T U} (f:T->U) (b:bool) (x y:T), (if b then f x else f y) = f (if b then x else y). + Proof. + destruct b; trivial. + Qed. + + Local Ltac Let_In_unRep := match goal with - | [ |- ?RHS = ?e ?v ] - => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in - unify e RHS' + | [ |- appcontext G[Let_In (unRep ?x) ?f] ] + => change (Let_In (unRep x) f) with (Let_In x (fun y => f (unRep y))); cbv beta end. + + + (** TODO: Move me *) + Lemma pull_Let_In {B C} (f : B -> C) A (v : A) (b : A -> B) + : Let_In v (fun v' => f (b v')) = f (Let_In v b). + Proof. reflexivity. - } Unfocus. - rewrite <-decode_scalar_correct. - - etransitivity. - Focus 2. - { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - symmetry; apply decode_test_encode_test. - } Unfocus. - - etransitivity. - Focus 2. - { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - unfold point_sub. rewrite negateExtended_correct. - let p := constr:(unifiedAddM1_rep eq_refl) in rewrite p. + Qed. + + Lemma Let_app_In {A B T} (g:A->B) (f:B->T) (x:A) : + @Let_In _ (fun _ => T) (g x) f = + @Let_In _ (fun _ => T) x (fun p => f (g x)). + Proof. reflexivity. Qed. + + Lemma Let_app2_In {A B C D T} (g1:A->C) (g2:B->D) (f:C*D->T) (x:A) (y:B) : + @Let_In _ (fun _ => T) (g1 x, g2 y) f = + @Let_In _ (fun _ => T) (x, y) (fun p => f ((g1 (fst p), g2 (snd p)))). + Proof. reflexivity. Qed. + + Create HintDb FRepOperations discriminated. + Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : FRepOperations. + + Create HintDb EdDSA_opts discriminated. + Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : EdDSA_opts. + + Lemma unifiedAddM1Rep_sig : forall a b : FRep * FRep * FRep * FRep, { unifiedAddM1Rep | rep2E unifiedAddM1Rep = unifiedAddM1' (rep2E a) (rep2E b) }. + Proof. + destruct a as [[[]]]; destruct b as [[[]]]. + eexists. + lazymatch goal with |- ?LHS = ?RHS :> ?T => + evar (e:T); replace LHS with e; [subst e|] + end. + unfold rep2E. cbv beta delta [unifiedAddM1']. + pose proof (rcFOK twice_d) as H; rewrite <-H; clear H. (* XXX: this is a hack -- rewrite misresolves typeclasses? *) + + { etransitivity; [|replace_let_in_with_Let_In; reflexivity]. + repeat ( + autorewrite with FRepOperations; + Let_In_unRep; + eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [Proper respectful pointwise_relation]; intro]). + lazymatch goal with |- ?LHS = (unRep ?x, unRep ?y, unRep ?z, unRep ?t) => + change (LHS = (rep2E (((x, y), z), t))) + end. + reflexivity. } + + subst e. + Local Opaque Let_In. + repeat setoid_rewrite (pull_Let_In rep2E). + Local Transparent Let_In. reflexivity. - } Unfocus. + Defined. + + Definition unifiedAddM1Rep (a b:FRep * FRep * FRep * FRep) : FRep * FRep * FRep * FRep := Eval hnf in proj1_sig (unifiedAddM1Rep_sig a b). + Definition unifiedAddM1Rep_correct a b : rep2E (unifiedAddM1Rep a b) = unifiedAddM1' (rep2E a) (rep2E b) := Eval hnf in proj2_sig (unifiedAddM1Rep_sig a b). + + Definition rep2T (P:FRep * FRep) := (rep2F (fst P), rep2F (snd P)). + Definition erep2trep (P:FRep * FRep * FRep * FRep) := Let_In P (fun P => Let_In (FRepInv (snd (fst P))) (fun iZ => (FRepMul (fst (fst (fst P))) iZ, FRepMul (snd (fst (fst P))) iZ))). + Lemma erep2trep_correct : forall P, rep2T (erep2trep P) = extendedToTwisted (rep2E P). + Proof. + unfold rep2T, rep2E, erep2trep, extendedToTwisted; destruct P as [[[]]]; simpl. + rewrite !unfoldDiv, <-!FRepMul_correct, <-FRepInv_correct. reflexivity. + Qed. + + (** TODO: possibly move me, remove local *) + Local Ltac replace_option_match_with_option_rect := + idtac; + lazymatch goal with + | [ |- _ = ?RHS :> ?T ] + => lazymatch RHS with + | match ?a with None => ?N | Some x => @?S x end + => replace RHS with (option_rect (fun _ => T) S N a) by (destruct a; reflexivity) + end + end. + + (** TODO: Move me, remove Local *) + Definition proj1_sig_unmatched {A P} := @proj1_sig A P. + Definition proj1_sig_nounfold {A P} := @proj1_sig A P. + Definition proj1_sig_unfold {A P} := Eval cbv [proj1_sig] in @proj1_sig A P. + Local Ltac unfold_proj1_sig_exist := + (** Change the first [proj1_sig] into [proj1_sig_unmatched]; if it's applied to [exist], mark it as unfoldable, otherwise mark it as not unfoldable. Then repeat. Finally, unfold. *) + repeat (change @proj1_sig with @proj1_sig_unmatched at 1; + match goal with + | [ |- context[proj1_sig_unmatched (exist _ _ _)] ] + => change @proj1_sig_unmatched with @proj1_sig_unfold + | _ => change @proj1_sig_unmatched with @proj1_sig_nounfold + end); + (* [proj1_sig_nounfold] is a thin wrapper around [proj1_sig]; unfolding it restores [proj1_sig]. Unfolding [proj1_sig_nounfold] exposes the pattern match, which is reduced by ι. *) + cbv [proj1_sig_nounfold proj1_sig_unfold]. + + (** TODO: possibly move me, remove Local *) + Local Ltac reflexivity_when_unification_is_stupid_about_evars + := repeat first [ reflexivity + | apply f_equal ]. + + + Local Existing Instance eq_Reflexive. (* To get some of the [setoid_rewrite]s below to work, we need to infer [Reflexive eq] before [Reflexive Equivalence.equiv] *) + + (* TODO: move me *) + Lemma fold_rep2E x y z t + : (rep2F x, rep2F y, rep2F z, rep2F t) = rep2E (((x, y), z), t). + Proof. reflexivity. Qed. + Lemma commute_negateExtended'_rep2E x y z t + : negateExtended' (rep2E (((x, y), z), t)) + = rep2E (((FRepOpp x, y), z), FRepOpp t). + Proof. simpl; autorewrite with FRepOperations; reflexivity. Qed. + Lemma fold_rep2E_ffff x y z t + : (x, y, z, t) = rep2E (((toRep x, toRep y), toRep z), toRep t). + Proof. simpl; rewrite !rcFOK; reflexivity. Qed. + Lemma fold_rep2E_rrfr x y z t + : (rep2F x, rep2F y, z, rep2F t) = rep2E (((x, y), toRep z), t). + Proof. simpl; rewrite !rcFOK; reflexivity. Qed. + Lemma fold_rep2E_0fff y z t + : (0%F, y, z, t) = rep2E (((toRep 0%F, toRep y), toRep z), toRep t). + Proof. apply fold_rep2E_ffff. Qed. + Lemma fold_rep2E_ff1f x y t + : (x, y, 1%F, t) = rep2E (((toRep x, toRep y), toRep 1%F), toRep t). + Proof. apply fold_rep2E_ffff. Qed. + Lemma commute_negateExtended'_rep2E_rrfr x y z t + : negateExtended' (unRep x, unRep y, z, unRep t) + = rep2E (((FRepOpp x, y), toRep z), FRepOpp t). + Proof. rewrite <- commute_negateExtended'_rep2E; simpl; rewrite !rcFOK; reflexivity. Qed. + + Hint Rewrite @F_mul_0_l commute_negateExtended'_rep2E_rrfr fold_rep2E_0fff (@fold_rep2E_ff1f (fst (proj1_sig B))) @if_F_eq_dec_if_F_eqb compare_enc (if_map unRep) (fun T => Let_app2_In (T := T) unRep unRep) @F_pow_2_r @unfoldDiv : EdDSA_opts. + Hint Rewrite <- unifiedAddM1Rep_correct erep2trep_correct (fun x y z bound => iter_op_proj rep2E unifiedAddM1Rep unifiedAddM1' x y z N.testbit_nat bound unifiedAddM1Rep_correct) FRep2wire_correct: EdDSA_opts. + + Lemma sharper_verify : forall pk l msg sig, { verify | verify = ed25519_verify pk l msg sig}. + Proof. + eexists; intros. + cbv [ed25519_verify EdDSA.verify + ed25519params curve25519params + EdDSA.E EdDSA.B EdDSA.b EdDSA.l EdDSA.H + EdDSA.PointEncoding EdDSA.FlEncoding EdDSA.FqEncoding]. + + etransitivity. + Focus 2. + { repeat match goal with + | [ |- ?x = ?x ] => reflexivity + | _ => replace_option_match_with_option_rect + | [ |- _ = option_rect _ _ _ _ ] + => eapply option_rect_Proper_nd; [ intro | reflexivity.. ] + end. + set_evars. + rewrite<- E.point_eqb_correct. + rewrite solve_for_R; unfold E.sub. + rewrite E.opp_mul. + let p1 := constr:(scalarMultM1_rep eq_refl) in + let p2 := constr:(unifiedAddM1_rep eq_refl) in + repeat match goal with + | |- context [(_ * E.opp ?P)%E] => + rewrite <-(unExtendedPoint_mkExtendedPoint P); + rewrite negateExtended_correct; + rewrite <-p1 + | |- context [(_ * ?P)%E] => + rewrite <-(unExtendedPoint_mkExtendedPoint P); + rewrite <-p1 + | _ => rewrite p2 + end; + rewrite ?Znat.Z_nat_N, <-?Word.wordToN_nat; + subst_evars; + reflexivity. + } Unfocus. + + etransitivity. + Focus 2. + { lazymatch goal with |- _ = option_rect _ _ ?false ?dec => + symmetry; etransitivity; [|eapply (option_rect_option_map (fun (x:F _) => Z.to_N x) _ false dec)] + end. + eapply option_rect_Proper_nd; [intro|reflexivity..]. + match goal with + | [ |- ?RHS = ?e ?v ] + => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in + unify e RHS' + end. + reflexivity. + } Unfocus. + rewrite <-decode_scalar_correct. + + etransitivity. + Focus 2. + { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). + symmetry; apply decode_test_encode_test. + } Unfocus. + + rewrite enc'_correct. + cbv [unExtendedPoint unifiedAddM1 negateExtended scalarMultM1]. + unfold_proj1_sig_exist. + + etransitivity. + Focus 2. + { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). + set_evars. + repeat match goal with + | [ |- appcontext[@proj1_sig ?A ?P (@iter_op ?T ?f ?neutral ?T' ?testbit ?exp ?base ?bound)] ] + => erewrite (@iter_op_proj T _ _ (@proj1_sig _ _)) by reflexivity + end. + subst_evars. + reflexivity. } + Unfocus. + + cbv [mkExtendedPoint E.zero]. + unfold_proj1_sig_exist. + rewrite B_proj. + + etransitivity. + Focus 2. + { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]). + set_evars. + lazymatch goal with |- _ = option_rect _ _ ?false ?dec => + symmetry; etransitivity; [|eapply (option_rect_option_map (@proj1_sig _ _) _ false dec)] + end. + eapply option_rect_Proper_nd; [intro|reflexivity..]. + match goal with + | [ |- ?RHS = ?e ?v ] + => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in + unify e RHS' + end. + reflexivity. + } Unfocus. + + cbv [dec PointEncoding point_encoding]. + etransitivity. + Focus 2. + { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]). + etransitivity. + Focus 2. + { apply f_equal. + symmetry. + apply point_dec_coordinates_correct. } + Unfocus. + reflexivity. } + Unfocus. + + cbv iota beta delta [point_dec_coordinates sign_bit dec FqEncoding modular_word_encoding E.solve_for_x2 sqrt_mod_q]. - cbv [scalarMultM1 iter_op]. - cbv iota zeta delta [test_and_op]. - Local Arguments funexp {_} _ {_} {_}. (* do not display the initializer and iteration bound for now *) + etransitivity. + Focus 2. { + do 1 (eapply option_rect_Proper_nd; [|reflexivity..]). cbv beta delta [pointwise_relation]. intro. + etransitivity. + Focus 2. + { apply f_equal. + lazymatch goal with + | [ |- _ = ?term :> ?T ] + => lazymatch term with (match ?a with None => ?N | Some x => @?S x end) + => let term' := constr:((option_rect (fun _ => T) S N) a) in + replace term with term' by reflexivity + end + end. + reflexivity. } Unfocus. reflexivity. } Unfocus. + + etransitivity. + Focus 2. { + do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). + do 1 (eapply option_rect_Proper_nd; [ intro; reflexivity | reflexivity | ]). + eapply option_rect_Proper_nd; [ cbv beta delta [pointwise_relation]; intro | reflexivity.. ]. + replace_let_in_with_Let_In. + reflexivity. + } Unfocus. - Axiom rep : forall {m}, list Z -> F m -> Prop. - Axiom decode_point_limbs : word (S (b-1)) -> option (list Z * list Z). - Axiom point_dec_rep : forall P_ P lx ly, dec P_ = Some P -> decode_point_limbs P_ = Some (lx, ly) -> rep lx (fst (proj1_sig P)) /\ rep ly (fst (proj1_sig P)). -Admitted.
\ No newline at end of file + etransitivity. + Focus 2. { + do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). + set_evars. + rewrite option_rect_function. (* turn the two option_rects into one *) + subst_evars. + simpl_option_rect. + do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). + (* push the [option_rect] inside until it hits a [Some] or a [None] *) + repeat match goal with + | _ => commute_option_rect_Let_In + | [ |- _ = Let_In _ _ ] + => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ] + | [ |- ?LHS = option_rect ?P ?S ?N (if ?b then ?t else ?f) ] + => transitivity (if b then option_rect P S N t else option_rect P S N f); + [ + | destruct b; reflexivity ] + | [ |- _ = if ?b then ?t else ?f ] + => apply (f_equal2 (fun x y => if b then x else y)) + | [ |- _ = false ] => reflexivity + | _ => progress simpl_option_rect + end. + reflexivity. + } Unfocus. + + cbv iota beta delta [q d a]. + + rewrite wire2FRep_correct. + + etransitivity. + Focus 2. { + eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro. + rewrite <-!(option_rect_option_map rep2F). + eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro. + autorewrite with EdDSA_opts. + rewrite <-(rcFOK 1%F). + pattern Ed25519.d at 1. rewrite <-(rcFOK Ed25519.d) at 1. + pattern Ed25519.a at 1. rewrite <-(rcFOK Ed25519.a) at 1. + rewrite <- (rcSOK (Z.to_N (Ed25519.q / 8 + 1))). + autorewrite with EdDSA_opts. + (Let_In_unRep). + eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. + etransitivity. Focus 2. eapply Let_In_Proper_nd; [|cbv beta delta [pointwise_relation]; intro;reflexivity]. { + rewrite FSRepPow_correct by (rewrite rcSOK; cbv; omega). + (Let_In_unRep). + etransitivity. Focus 2. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. { + set_evars. + rewrite <-(rcFOK sqrt_minus1). + autorewrite with EdDSA_opts. + subst_evars. + reflexivity. } Unfocus. + rewrite pull_Let_In. + reflexivity. } Unfocus. + set_evars. + (Let_In_unRep). + + subst_evars. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. set_evars. + + autorewrite with EdDSA_opts. + + subst_evars. + lazymatch goal with |- _ = if ?b then ?t else ?f => apply (f_equal2 (fun x y => if b then x else y)) end; [|reflexivity]. + eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. + set_evars. + + unfold twistedToExtended. + autorewrite with EdDSA_opts. + progress cbv beta delta [erep2trep]. + + subst_evars. + reflexivity. } Unfocus. + reflexivity. + Defined. +End Ed25519Frep.
\ No newline at end of file diff --git a/src/Specific/GF1305.v b/src/Specific/GF1305.v new file mode 100644 index 000000000..b004a60d1 --- /dev/null +++ b/src/Specific/GF1305.v @@ -0,0 +1,74 @@ +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.BaseSystem. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN PseudoMersenneBaseParams instance construction. *) + +Definition modulus : Z := 2^130 - 5. +Lemma prime_modulus : prime modulus. Admitted. + +Instance params1305 : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 5%nat 130. +Defined. + +Definition mul2modulus := Eval compute in (construct_mul2modulus params1305). + +Instance subCoeff : SubtractionCoefficient modulus params1305. + apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto. +Defined. + +(* END PseudoMersenneBaseParams instance construction. *) + +(* Precompute k and c *) +Definition k_ := Eval compute in k. +Definition c_ := Eval compute in c. + +(* Makes Qed not take forever *) +Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land + Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul + Let_In digits Z.add Pos.add Z.pos_sub. + +Local Open Scope nat_scope. +Lemma GF1305Base26_mul_reduce_formula : + forall f0 f1 f2 f3 f4 g0 g1 g2 g3 g4, + {ls | forall f g, rep [f0;f1;f2;f3;f4] f + -> rep [g0;g1;g2;g3;g4] g + -> rep ls (f*g)%F}. +Proof. + eexists; intros ? ? Hf Hg. + pose proof (carry_mul_opt_correct k_ c_ (eq_refl k) (eq_refl c_) [0;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg. + compute_formula. +Defined. + +Lemma GF1305Base26_add_formula : + forall f0 f1 f2 f3 f4 g0 g1 g2 g3 g4, + {ls | forall f g, rep [f0;f1;f2;f3;f4] f + -> rep [g0;g1;g2;g3;g4] g + -> rep ls (f + g)%F}. +Proof. + eexists; intros ? ? Hf Hg. + pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg. + compute_formula. +Defined. + +Lemma GF25519Base25Point5_sub_formula : + forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, + {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f + -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g + -> rep ls (f - g)%F}. +Proof. + eexists. + intros f g Hf Hg. + pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg. + compute_formula. +Defined.
\ No newline at end of file diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v index 0d3923945..8aaf8caf6 100644 --- a/src/Specific/GF25519.v +++ b/src/Specific/GF25519.v @@ -1,529 +1,180 @@ -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep. Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.BaseSystem. +Require Import Crypto.Rep. Import ListNotations. Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Import Coq.QArith.QArith Coq.QArith.Qround. -Require Import Crypto.Tactics.VerdiTactics. -Close Scope Q. - -Ltac twoIndices i j base := - intros; - assert (In i (seq 0 (length base))) by nth_tac; - assert (In j (seq 0 (length base))) by nth_tac; - repeat match goal with [ x := _ |- _ ] => subst x end; - simpl in *; repeat break_or_hyp; try omega; vm_compute; reflexivity. - -Module Base25Point5_10limbs <: BaseCoefs. - Local Open Scope Z_scope. - Definition log_base := Eval compute in map (fun i => (Qceiling (Z_of_nat i *255 # 10))) (seq 0 10). - Definition base := map (fun x => 2 ^ x) log_base. - - Lemma base_positive : forall b, In b base -> b > 0. - Proof. - compute; intuition; subst; intuition. - Qed. - - Lemma b0_1 : forall x, nth_default x base 0 = 1. - Proof. - auto. - Qed. - - Lemma base_good : - forall i j, (i+j < length base)%nat -> - let b := nth_default 0 base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof. - twoIndices i j base. - Qed. -End Base25Point5_10limbs. - -Module Modulus25519 <: PrimeModulus. - Local Open Scope Z_scope. - Definition modulus : Z := 2^255 - 19. - Lemma prime_modulus : prime modulus. Admitted. -End Modulus25519. - -Module F25519Base25Point5Params <: PseudoMersenneBaseParams Base25Point5_10limbs Modulus25519. - Import Base25Point5_10limbs. - Import Modulus25519. - Local Open Scope Z_scope. - (* TODO: do we actually want B and M "up there" in the type parameters? I was - * imagining writing something like "Paramter Module M : Modulus". *) +Local Open Scope Z. - Definition k := 255. - Definition c := 19. - Lemma modulus_pseudomersenne : - modulus = 2^k - c. - Proof. - auto. - Qed. +(* BEGIN PseudoMersenneBaseParams instance construction. *) - Lemma base_matches_modulus : - forall i j, - (i < length base)%nat -> - (j < length base)%nat -> - (i+j >= length base)%nat -> - let b := nth_default 0 base in - let r := (b i * b j) / (2^k * b (i+j-length base)%nat) in - b i * b j = r * (2^k * b (i+j-length base)%nat). - Proof. - twoIndices i j base. - Qed. +Definition modulus : Z := 2^255 - 19. +Lemma prime_modulus : prime modulus. Admitted. - Lemma base_succ : forall i, ((S i) < length base)%nat -> - let b := nth_default 0 base in - b (S i) mod b i = 0. - Proof. - intros; twoIndices i (S i) base. - Qed. - - Lemma base_tail_matches_modulus: - 2^k mod nth_default 0 base (pred (length base)) = 0. - Proof. - auto. - Qed. - - Lemma b0_1 : forall x, nth_default x base 0 = 1. - Proof. - auto. - Qed. - - Lemma k_nonneg : 0 <= k. - Proof. - rewrite Zle_is_le_bool; auto. - Qed. - - Lemma base_range : forall i, 0 <= nth_default 0 log_base i <= k. - Proof. - intros i. - destruct (lt_dec i (length log_base)) as [H|H]. - { repeat (destruct i as [|i]; [cbv; intuition; discriminate|simpl in H; try omega]). } - { rewrite nth_default_eq, nth_overflow by omega. cbv. intuition; discriminate. } - Qed. - - Lemma base_monotonic: forall i : nat, (i < pred (length log_base))%nat -> - (0 <= nth_default 0 log_base i <= nth_default 0 log_base (S i)). - Proof. - intros i H. - repeat (destruct i; [cbv; intuition; congruence|]); - contradict H; cbv; firstorder. - Qed. -End F25519Base25Point5Params. +Instance params25519 : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 10%nat 255. +Defined. -Module F25519Base25Point5 := PseudoMersenneBase Base25Point5_10limbs Modulus25519 F25519Base25Point5Params. +Definition mul2modulus := Eval compute in (construct_mul2modulus params25519). -Section F25519Base25Point5Formula. - Import F25519Base25Point5 Base25Point5_10limbs F25519Base25Point5 F25519Base25Point5Params. +Instance subCoeff : SubtractionCoefficient modulus params25519. + apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto. +Defined. -Definition Z_add_opt := Eval compute in Z.add. -Definition Z_sub_opt := Eval compute in Z.sub. -Definition Z_mul_opt := Eval compute in Z.mul. -Definition Z_div_opt := Eval compute in Z.div. -Definition Z_pow_opt := Eval compute in Z.pow. +(* END PseudoMersenneBaseParams instance construction. *) -Definition nth_default_opt {A} := Eval compute in @nth_default A. -Definition map_opt {A B} := Eval compute in @map A B. +(* Precompute k and c *) +Definition k_ := Eval compute in k. +Definition c_ := Eval compute in c. -Ltac opt_step := - match goal with - | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] - => refine (_ : match e with nil => _ | _ => _ end = _); - destruct e - end. +(* Makes Qed not take forever *) +Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land + Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul + Let_In digits Z.add Pos.add Z.pos_sub. -Definition E_mul_bi'_step - (mul_bi' : nat -> E.digits -> list Z) - (i : nat) (vsr : E.digits) - : list Z - := match vsr with - | [] => [] - | v :: vsr' => (v * E.crosscoef i (length vsr'))%Z :: mul_bi' i vsr' - end. - -Definition E_mul_bi'_opt_step_sig - (mul_bi' : nat -> E.digits -> list Z) - (i : nat) (vsr : E.digits) - : { l : list Z | l = E_mul_bi'_step mul_bi' i vsr }. +Local Open Scope nat_scope. +Lemma GF25519Base25Point5_mul_reduce_formula : + forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, + {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f + -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g + -> rep ls (f*g)%F}. Proof. - eexists. - cbv [E_mul_bi'_step]. - opt_step. - { reflexivity. } - { cbv [E.crosscoef EC.base Base25Point5_10limbs.base]. - change Z.div with Z_div_opt. - change Z.pow with Z_pow_opt. - change Z.mul with Z_mul_opt at 2 3 4 5. - change @nth_default with @nth_default_opt. - change @map with @map_opt. - reflexivity. } -Defined. - -Definition E_mul_bi'_opt_step - (mul_bi' : nat -> E.digits -> list Z) - (i : nat) (vsr : E.digits) - : list Z - := Eval cbv [proj1_sig E_mul_bi'_opt_step_sig] in - proj1_sig (E_mul_bi'_opt_step_sig mul_bi' i vsr). - -Fixpoint E_mul_bi'_opt - (i : nat) (vsr : E.digits) {struct vsr} - : list Z - := E_mul_bi'_opt_step E_mul_bi'_opt i vsr. + eexists; intros ? ? Hf Hg. + pose proof (carry_mul_opt_correct k_ c_ (eq_refl k_) (eq_refl c_) [0;9;8;7;6;5;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg. + compute_formula. +Time Defined. -Definition E_mul_bi'_opt_correct - (i : nat) (vsr : E.digits) - : E_mul_bi'_opt i vsr = E.mul_bi' i vsr. -Proof. - revert i; induction vsr as [|vsr vsrs IHvsr]; intros. - { reflexivity. } - { simpl E.mul_bi'. - rewrite <- IHvsr; clear IHvsr. - unfold E_mul_bi'_opt, E_mul_bi'_opt_step. - apply f_equal2; [ | reflexivity ]. - cbv [E.crosscoef EC.base Base25Point5_10limbs.base]. - change Z.div with Z_div_opt. - change Z.pow with Z_pow_opt. - change Z.mul with Z_mul_opt at 2. - change @nth_default with @nth_default_opt. - change @map with @map_opt. - reflexivity. } -Qed. +Extraction "/tmp/test.ml" GF25519Base25Point5_mul_reduce_formula. +(* It's easy enough to use extraction to get the proper nice-looking formula. + * More Ltac acrobatics will be needed to get out that formula for further use in Coq. + * The easiest fix will be to make the proof script above fully automated, + * using [abstract] to contain the proof part. *) -Definition E_mul'_step - (mul' : E.digits -> E.digits -> E.digits) - (usr vs : E.digits) - : E.digits - := match usr with - | [] => [] - | u :: usr' => E.add (E.mul_each u (E.mul_bi (length usr') vs)) (mul' usr' vs) - end. -Definition E_mul'_opt_step_sig - (mul' : E.digits -> E.digits -> E.digits) - (usr vs : E.digits) - : { d : E.digits | d = E_mul'_step mul' usr vs }. +Lemma GF25519Base25Point5_add_formula : + forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, + {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f + -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g + -> rep ls (f + g)%F}. Proof. eexists. - cbv [E_mul'_step]. - match goal with - | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] - => refine (_ : match e with nil => _ | _ => _ end = _); - destruct e - end. - { reflexivity. } - { cbv [E.mul_each E.mul_bi]. - rewrite <- E_mul_bi'_opt_correct. - reflexivity. } + intros f g Hf Hg. + pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg. + compute_formula. Defined. -Definition E_mul'_opt_step - (mul' : E.digits -> E.digits -> E.digits) - (usr vs : E.digits) - : E.digits - := Eval cbv [proj1_sig E_mul'_opt_step_sig] in proj1_sig (E_mul'_opt_step_sig mul' usr vs). - -Fixpoint E_mul'_opt - (usr vs : E.digits) - : E.digits - := E_mul'_opt_step E_mul'_opt usr vs. - -Definition E_mul'_opt_correct - (usr vs : E.digits) - : E_mul'_opt usr vs = E.mul' usr vs. -Proof. - revert vs; induction usr as [|usr usrs IHusr]; intros. - { reflexivity. } - { simpl. - rewrite <- IHusr; clear IHusr. - apply f_equal2; [ | reflexivity ]. - cbv [E.mul_each E.mul_bi]. - rewrite <- E_mul_bi'_opt_correct. - reflexivity. } -Qed. - -Definition mul_opt_sig (us vs : T) : { b : B.digits | b = mul us vs }. +Lemma GF25519Base25Point5_sub_formula : + forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, + {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f + -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g + -> rep ls (f - g)%F}. Proof. eexists. - cbv [mul E.mul E.mul_each E.mul_bi E.mul_bi' E.zeros EC.base reduce]. - rewrite <- E_mul'_opt_correct. - reflexivity. + intros f g Hf Hg. + pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg. + compute_formula. Defined. -Definition mul_opt (us vs : T) : B.digits - := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs). +Definition F25519Rep := (Z * Z * Z * Z * Z * Z * Z * Z * Z * Z)%type. -Definition mul_opt_correct us vs - : mul_opt us vs = mul us vs - := proj2_sig (mul_opt_sig us vs). +Definition F25519toRep (x:F (2^255 - 19)) : F25519Rep := (0, 0, 0, 0, 0, 0, 0, 0, 0, FieldToZ x)%Z. +Definition F25519unRep (rx:F25519Rep) := + let '(x9, x8, x7, x6, x5, x4, x3, x2, x1, x0) := rx in + ModularBaseSystem.decode [x0;x1;x2;x3;x4;x5;x6;x7;x8;x9]. -Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R) - : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b). -Proof. - destruct (eq_nat_dec x y) as [H|H]; - [ rewrite (proj2 (@beq_nat_true_iff _ _) H); reflexivity - | rewrite (proj2 (@beq_nat_false_iff _ _) H); reflexivity ]. -Qed. - -Lemma pull_app_if_sumbool {A B X Y} (b : sumbool X Y) (f g : A -> B) (x : A) - : (if b then f x else g x) = (if b then f else g) x. -Proof. - destruct b; reflexivity. -Qed. +Global Instance F25519RepConversions : RepConversions (F (2^255 - 19)) F25519Rep := + { + toRep := F25519toRep; + unRep := F25519unRep + }. -Lemma map_nth_default_always {A B} (f : A -> B) (n : nat) (x : A) (l : list A) - : nth_default (f x) (map f l) n = f (nth_default x l n). +Lemma F25519RepConversionsOK : RepConversionsOK F25519RepConversions. Proof. - revert n; induction l; simpl; intro n; destruct n; [ try reflexivity.. ]. - nth_tac. + unfold F25519RepConversions, RepConversionsOK, unRep, toRep, F25519toRep, F25519unRep; intros. + change (ModularBaseSystem.decode (ModularBaseSystem.encode x) = x). + eauto using ModularBaseSystemProofs.rep_decode, ModularBaseSystemProofs.encode_rep. Qed. -Definition log_cap_opt_sig - (i : nat) - : { z : Z | i < length (Base25Point5_10limbs.log_base) -> (2^z)%Z = cap i }. -Proof. - eexists. - cbv [cap Base25Point5_10limbs.base]. - intros. - rewrite map_length in *. - erewrite map_nth_default; [|assumption]. - instantiate (2 := 0%Z). - (** For the division of maps of (2 ^ _) over lists, replace it with 2 ^ (_ - _) *) - - lazymatch goal with - | [ |- _ = (if eq_nat_dec ?a ?b then (2^?x/2^?y)%Z else (nth_default 0 (map (fun x => (2^x)%Z) ?ls) ?si / 2^?d)%Z) ] - => transitivity (2^if eq_nat_dec a b then (x-y)%Z else nth_default 0 ls si - d)%Z; - [ apply f_equal | break_if ] +Definition F25519Rep_mul (f g:F25519Rep) : F25519Rep. + refine ( + let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in + let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). + (* FIXME: the r should not be present in generated code *) + pose (r := proj1_sig (GF25519Base25Point5_mul_reduce_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9)). + simpl in r. + unfold F25519Rep. + repeat let t' := (eval cbv beta delta [r] in r) in + lazymatch t' with Let_In ?arg ?f => + let x := fresh "x" in + refine (let x := arg in _); + let t'' := (eval cbv beta in (f x)) in + change (Let_In arg f) with t'' in r + end. + let t' := (eval cbv beta delta [r] in r) in + lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => + clear r; + exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) end. - - Focus 2. - apply Z.pow_sub_r; [clear;firstorder|apply base_range]. - Focus 2. - erewrite map_nth_default by (omega); instantiate (1 := 0%Z). - rewrite <- Z.pow_sub_r; [ reflexivity | .. ]. - { clear; abstract firstorder. } - { apply base_monotonic. omega. } - Unfocus. - rewrite <-beq_nat_eq_nat_dec. - change Z.sub with Z_sub_opt. - change @nth_default with @nth_default_opt. - change @map with @map_opt. - reflexivity. -Defined. - -Definition log_cap_opt (i : nat) - := Eval cbv [proj1_sig log_cap_opt_sig] in proj1_sig (log_cap_opt_sig i). - -Definition log_cap_opt_correct (i : nat) - : i < length Base25Point5_10limbs.log_base -> (2^log_cap_opt i = cap i)%Z - := proj2_sig (log_cap_opt_sig i). +Time Defined. + +Lemma F25519_mul_OK : RepBinOpOK F25519RepConversions ModularArithmetic.mul F25519Rep_mul. + cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_mul toRep unRep F25519toRep F25519unRep]. + destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. + destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. + let E := constr:(GF25519Base25Point5_mul_reduce_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in + transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; + destruct E as [? r]; cbv [proj1_sig]. + cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. +Qed. -Definition carry_opt_sig - (i : nat) (b : B.digits) - : { d : B.digits | i < length Base25Point5_10limbs.log_base -> d = carry i b }. -Proof. - eexists ; intros. - cbv [carry]. - rewrite <- pull_app_if_sumbool. - cbv beta delta [carry_and_reduce carry_simple add_to_nth Base25Point5_10limbs.base]. - rewrite map_length. - repeat lazymatch goal with - | [ |- context[cap ?i] ] - => replace (cap i) with (2^log_cap_opt i)%Z by (apply log_cap_opt_correct; assumption) - end. - lazymatch goal with - | [ |- _ = (if ?br then ?c else ?d) ] - => let x := fresh "x" in let y := fresh "y" in evar (x:B.digits); evar (y:B.digits); transitivity (if br then x else y); subst x; subst y +Definition F25519Rep_add (f g:F25519Rep) : F25519Rep. + refine ( + let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in + let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). + let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_add_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in + lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => + exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) end. - Focus 2. - cbv zeta. - break_if; - rewrite <- Z.land_ones, <- Z.shiftr_div_pow2 by ( - pose proof (base_range i); pose proof (base_monotonic i); - change @nth_default with @nth_default_opt in *; - cbv beta delta [log_cap_opt]; rewrite beq_nat_eq_nat_dec; break_if; change Z_sub_opt with Z.sub; omega); - reflexivity. - change @nth_default with @nth_default_opt. - change @map with @map_opt. - rewrite <- @beq_nat_eq_nat_dec. - reflexivity. Defined. -Definition carry_opt i b - := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b). - -Definition carry_opt_correct i b : i < length Base25Point5_10limbs.log_base -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b). - -Definition carry_sequence_opt_sig (is : list nat) (us : B.digits) - : { b : B.digits | (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> b = carry_sequence is us }. -Proof. - eexists. intros H. - cbv [carry_sequence]. - transitivity (fold_right carry_opt us is). - Focus 2. - { induction is; [ reflexivity | ]. - simpl; rewrite IHis, carry_opt_correct. - - reflexivity. - - apply H; apply in_eq. - - intros. apply H. right. auto. - } - Unfocus. - reflexivity. -Defined. - -Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in - proj1_sig (carry_sequence_opt_sig is us). - -Definition carry_sequence_opt_correct is us - : (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> carry_sequence_opt is us = carry_sequence is us - := proj2_sig (carry_sequence_opt_sig is us). - -Definition Let_In {A P} (x : A) (f : forall y : A, P y) - := let y := x in f y. - -Definition carry_opt_cps_sig - {T} - (i : nat) - (f : B.digits -> T) - (b : B.digits) - : { d : T | i < length Base25Point5_10limbs.log_base -> d = f (carry i b) }. -Proof. - eexists. intros H. - rewrite <- carry_opt_correct by assumption. - cbv beta iota delta [carry_opt]. - (* TODO: how to match the goal here? Alternatively, rewrite under let binders in carry_opt_sig and remove cbv zeta and restore original match from jgross's commit *) - lazymatch goal with [ |- ?LHS = _ ] => - change (LHS = Let_In (nth_default_opt 0%Z b i) (fun di => (f (if beq_nat i (pred (length Base25Point5_10limbs.log_base)) - then - set_nth 0 - (c * - Z.shiftr (di) (log_cap_opt i) + - nth_default_opt 0 - (set_nth i (Z.land di (Z.ones (log_cap_opt i))) - b) 0)%Z - (set_nth i (Z.land (nth_default_opt 0%Z b i) (Z.ones (log_cap_opt i))) b) - else - set_nth (S i) - (Z.shiftr (di) (log_cap_opt i) + - nth_default_opt 0 - (set_nth i (Z.land (di) (Z.ones (log_cap_opt i))) - b) (S i))%Z - (set_nth i (Z.land (nth_default_opt 0%Z b i) (Z.ones (log_cap_opt i))) b))))) +Definition F25519Rep_sub (f g:F25519Rep) : F25519Rep. + refine ( + let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in + let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). + let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_sub_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in + lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => + exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) end. - reflexivity. Defined. -Definition carry_opt_cps {T} i f b - := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b). - -Definition carry_opt_cps_correct {T} i f b : - i < length Base25Point5_10limbs.log_base -> - @carry_opt_cps T i f b = f (carry i b) - := proj2_sig (carry_opt_cps_sig i f b). - -Definition carry_sequence_opt_cps_sig (is : list nat) (us : B.digits) - : { b : B.digits | (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> b = carry_sequence is us }. -Proof. - eexists. - cbv [carry_sequence]. - transitivity (fold_right carry_opt_cps id (List.rev is) us). - Focus 2. - { - assert (forall i, In i (rev is) -> i < length Base25Point5_10limbs.log_base) as Hr. { - subst. intros. rewrite <- in_rev in *. auto. } - remember (rev is) as ris eqn:Heq. - rewrite <- (rev_involutive is), <- Heq. - clear H Heq is. - rewrite fold_left_rev_right. - revert us; induction ris; [ reflexivity | ]; intros. - { simpl. - rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. - rewrite carry_opt_cps_correct; [reflexivity|]. - apply Hr; left; reflexivity. - } } - Unfocus. - reflexivity. -Defined. - -Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in - proj1_sig (carry_sequence_opt_cps_sig is us). - -Definition carry_sequence_opt_cps_correct is us - : (forall i, In i is -> i < length Base25Point5_10limbs.log_base) -> carry_sequence_opt_cps is us = carry_sequence is us - := proj2_sig (carry_sequence_opt_cps_sig is us). - -Lemma mul_opt_rep: - forall (u v : T) (x y : F Modulus25519.modulus), rep u x -> rep v y -> rep (mul_opt u v) (x * y)%F. -Proof. - intros. - rewrite mul_opt_correct. - auto using mul_rep. -Qed. - -Lemma carry_sequence_opt_cps_rep - : forall (is : list nat) (us : list Z) (x : F Modulus25519.modulus), - (forall i : nat, In i is -> i < length Base25Point5_10limbs.base) -> - length us = length Base25Point5_10limbs.base -> - rep us x -> rep (carry_sequence_opt_cps is us) x. -Proof. - intros. - rewrite carry_sequence_opt_cps_correct by assumption. - apply carry_sequence_rep; assumption. -Qed. - -Definition carry_mul_opt - (is : list nat) - (us vs : list Z) - : list Z - := Eval cbv [B.add - E.add E.mul E.mul' E.mul_bi E.mul_bi' E.mul_each E.zeros EC.base E_mul'_opt - E_mul'_opt_step E_mul_bi'_opt E_mul_bi'_opt_step - List.app List.rev Z_div_opt Z_mul_opt Z_pow_opt - Z_sub_opt app beq_nat log_cap_opt carry_opt_cps carry_sequence_opt_cps error firstn - fold_left fold_right id length map map_opt mul mul_opt nth_default nth_default_opt - nth_error plus pred reduce rev seq set_nth skipn value base] in - carry_sequence_opt_cps is (mul_opt us vs). - -Lemma carry_mul_opt_correct - : forall (is : list nat) (us vs : list Z) (x y: F Modulus25519.modulus), - rep us x -> rep vs y -> - (forall i : nat, In i is -> i < length Base25Point5_10limbs.base) -> - length (mul_opt us vs) = length base -> - rep (carry_mul_opt is us vs) (x*y)%F. -Proof. - intros is us vs x y; intros. - change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps is (mul_opt us vs)). - apply carry_sequence_opt_cps_rep, mul_opt_rep; auto. +Lemma F25519_add_OK : RepBinOpOK F25519RepConversions ModularArithmetic.add F25519Rep_add. + cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_add toRep unRep F25519toRep F25519unRep]. + destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. + destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. + let E := constr:(GF25519Base25Point5_add_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in + transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; + destruct E as [? r]; cbv [proj1_sig]. + cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. Qed. - - - Lemma GF25519Base25Point5_mul_reduce_formula : - forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 - g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, - {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f - -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g - -> rep ls (f*g)%F}. - Proof. - eexists. - intros f g Hf Hg. - - pose proof (carry_mul_opt_correct [0;9;8;7;6;5;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg. - forward Hfg; [abstract (clear; cbv; intros; repeat break_or_hyp; intuition)|]. - specialize (Hfg H); clear H. - forward Hfg; [exact eq_refl|]. - specialize (Hfg H); clear H. - - cbv [log_base map k c carry_mul_opt] in Hfg. - cbv beta iota delta [Let_In] in Hfg. - rewrite ?Z.mul_0_l, ?Z.mul_0_r, ?Z.mul_1_l, ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_0_r in Hfg. - rewrite ?Z.mul_assoc, ?Z.add_assoc in Hfg. - exact Hfg. - Time Defined. -End F25519Base25Point5Formula. -Extraction "/tmp/test.ml" GF25519Base25Point5_mul_reduce_formula. -(* It's easy enough to use extraction to get the proper nice-looking formula. - * More Ltac acrobatics will be needed to get out that formula for further use in Coq. - * The easiest fix will be to make the proof script above fully automated, - * using [abstract] to contain the proof part. *) +Lemma F25519_sub_OK : RepBinOpOK F25519RepConversions ModularArithmetic.sub F25519Rep_sub. + cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_sub toRep unRep F25519toRep F25519unRep]. + destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. + destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. + let E := constr:(GF25519Base25Point5_sub_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in + transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; + destruct E as [? r]; cbv [proj1_sig]. + cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. +Qed.
\ No newline at end of file diff --git a/src/Testbit.v b/src/Testbit.v new file mode 100644 index 000000000..264069587 --- /dev/null +++ b/src/Testbit.v @@ -0,0 +1,212 @@ +Require Import Coq.Lists.List. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.BaseSystem Crypto.BaseSystemProofs. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. +Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. + +Local Open Scope Z. + +Definition testbit (limb_width n : nat) (us : list Z) := + Z.testbit (nth_default 0 us (n / limb_width)) (Z.of_nat (n mod limb_width)). + +(* identical limb widths *) +Definition uniform_base (l : list Z) r := forall n d, (n < length l)%nat -> + nth n l d = r ^ (Z.of_nat n). + +Definition successive_powers (l : list Z) r := forall n d, (S n < length l)%nat -> + nth (S n) l d = r * nth n l d. + +Fixpoint unfold_bits (limb_width : Z) (us : list Z) := + match us with + | nil => 0 + | u0 :: us' => Z.land u0 (Z.ones limb_width) + Z.shiftl (unfold_bits limb_width us') limb_width + end. + +Lemma uniform_base_first : forall b0 bs r, + uniform_base (b0 :: bs) r -> b0 = 1. +Proof. + boring. + match goal with + | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H; + specialize (H 0%nat 0); simpl in H; eapply H; omega + end. +Qed. + +Lemma uniform_base_second : forall b0 b1 bs r, + uniform_base (b0 :: b1 :: bs) r -> b1 = r. +Proof. + boring. + match goal with + | [ H : uniform_base _ _ |- _ ] => unfold uniform_base in H; + specialize (H 1%nat 0); cbv [nth length] in H; + rewrite Z.pow_1_r in H; apply H; omega + end. +Qed. + + +Lemma successive_powers_second : forall b0 b1 bs r, + successive_powers (b0 :: b1 :: bs) r -> b1 = r * b0. +Proof. + boring. + match goal with + | [ H : successive_powers _ _ |- _ ] => unfold uniform_base in H; + specialize (H 0%nat 0); cbv [nth length] in H; apply H; omega + end. +Qed. + +Ltac uniform_base_subst := + match goal with + | [H : uniform_base (?b0 :: ?b1 :: _) _ |- _ ]=> + erewrite (uniform_base_first b0); eauto; + erewrite (uniform_base_second b0 b1); eauto + end. + +Lemma successive_powers_tail : forall x0 xs r, successive_powers (x0 :: xs) r -> + successive_powers xs r. +Proof. + unfold successive_powers; boring. +Qed. + +Lemma decode_uniform_shift : forall us base limb_width, (S (length us) <= length base)%nat -> + successive_powers base (2 ^ limb_width) -> + decode base (mul_each (2 ^ limb_width) us) = decode base (0 :: us). +Proof. + unfold decode, decode', accumulate, mul_each; + induction us; induction base; try solve [boring]. + intros; simpl in *. + destruct base; [ boring; omega | ]. + simpl; f_equal. + + erewrite (successive_powers_second _ z); eauto. + ring. + + apply IHus; [ omega | ]. + eapply successive_powers_tail; eassumption. +Qed. + +Lemma uniform_base_successive_powers : forall xs r, uniform_base xs r -> + successive_powers xs r. +Proof. + unfold uniform_base, successive_powers; intros ? ? G n ? ?. + do 2 rewrite G by omega. + rewrite Nat2Z.inj_succ. + apply Z.pow_succ_r. + apply Nat2Z.is_nonneg. +Qed. + +Lemma uniform_base_BaseVector : forall base r, (r > 0) -> (0 < length base)%nat -> + uniform_base base r -> BaseVector base. +Proof. + unfold uniform_base. + intros ? ? r_gt_0 base_nonempty uniform. + constructor. + + intros b In_b_base. + apply In_nth_error_value in In_b_base. + destruct In_b_base as [x nth_error_x]. + pose proof (nth_error_value_length _ _ _ _ nth_error_x) as index_bound. + specialize (uniform x 0 index_bound). + rewrite <- nth_default_eq in uniform. + erewrite nth_error_value_eq_nth_default in uniform; eauto. + subst. + destruct r; [ | apply pos_pow_nat_pos | pose proof (Zlt_neg_0 p) ] ; omega. + + intros. + rewrite nth_default_eq. + rewrite uniform; auto. + + intros. + subst b. + subst r0. + repeat rewrite nth_default_eq. + repeat rewrite uniform by omega; auto. + rewrite <- Z.pow_add_r by apply Nat2Z.is_nonneg. + rewrite Nat2Z.inj_add. + rewrite <- Z.pow_sub_r, <- Z.pow_add_r by omega. + f_equal. + omega. +Qed. + +Definition no_overflow us limb_width := forall n, + Z.land (nth_default 0 us n) (Z.ones limb_width) = nth_default 0 us n. + +Lemma no_overflow_cons : forall u0 us limb_width, + no_overflow (u0 :: us) limb_width -> Z.land u0 (Z.ones limb_width) = u0. +Proof. + unfold no_overflow; intros ? ? ? no_overflow_u0_us. + specialize (no_overflow_u0_us 0%nat). + rewrite nth_default_cons in no_overflow_u0_us. + assumption. +Qed. + +Lemma no_overflow_tail : forall u0 us limb_width, + no_overflow (u0 :: us) limb_width -> no_overflow us limb_width. +Proof. + unfold no_overflow; intros. + erewrite <- nth_default_cons_S; eauto. +Qed. + +Lemma unfold_bits_decode : forall limb_width us base, (0 <= limb_width) -> + (length us <= length base)%nat -> (0 < length base)%nat -> + no_overflow us limb_width -> + uniform_base base (2 ^ limb_width) -> + BaseSystem.decode base us = unfold_bits limb_width us. +Proof. + induction us; boring. + rewrite <- (IHus base) by (omega || eauto using no_overflow_tail). + rewrite decode_cons by (eapply uniform_base_BaseVector; eauto; + rewrite gt_lt_symmetry; apply Z_pow_gt0; omega). + simpl. + f_equal. + + symmetry. eapply no_overflow_cons; eauto. + + rewrite Z.shiftl_mul_pow2 by assumption. + erewrite <- decode_uniform_shift; eauto using uniform_base_successive_powers. + rewrite mul_each_rep. + unfold decode. + apply Z.mul_comm. +Qed. + + +Lemma unfold_bits_indexing : forall us i limb_width, (0 < limb_width)%nat -> + no_overflow us (Z.of_nat limb_width) -> + nth_default 0 us i = + Z.land (Z.shiftr (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat (i * limb_width))) (Z.ones (Z.of_nat limb_width)). +Proof. + induction us; intros. + + rewrite nth_default_nil. + rewrite Z.shiftr_0_l. + auto using Z.land_0_l. + + destruct i; simpl. + - rewrite nth_default_cons. + rewrite Z.shiftr_0_r, Z_land_add_land by omega. + symmetry; eapply no_overflow_cons; eauto. + - rewrite nth_default_cons_S. + erewrite IHus; eauto using no_overflow_tail. + remember (i * limb_width)%nat as k. + rewrite Z_shiftr_add_land by omega. + replace (limb_width + k - limb_width)%nat with k by omega. + reflexivity. +Qed. + +Lemma unfold_bits_testbit : forall limb_width us n, (0 < limb_width)%nat -> + no_overflow us (Z.of_nat limb_width) -> + testbit limb_width n us = Z.testbit (unfold_bits (Z.of_nat limb_width) us) (Z.of_nat n). +Proof. + unfold testbit; intros. + erewrite unfold_bits_indexing; eauto. + rewrite <- Z_testbit_low by + (split; try apply Nat2Z.inj_lt; pose proof (mod_bound_pos n limb_width); omega). + rewrite Z.shiftr_spec by apply Nat2Z.is_nonneg. + f_equal. + rewrite <- Nat2Z.inj_add. + apply Z2Nat.inj; try apply Nat2Z.is_nonneg. + rewrite !Nat2Z.id. + symmetry. + rewrite Nat.add_comm, Nat.mul_comm. + apply Nat.div_mod; omega. +Qed. + +Lemma testbit_spec : forall n us base limb_width, (0 < limb_width)%nat -> + (0 < length base)%nat -> (length us <= length base)%nat -> + no_overflow us (Z.of_nat limb_width) -> + uniform_base base (2 ^ (Z.of_nat limb_width)) -> + testbit limb_width n us = Z.testbit (BaseSystem.decode base us) (Z.of_nat n). +Proof. + intros. + erewrite unfold_bits_testbit, unfold_bits_decode; eauto; omega. +Qed.
\ No newline at end of file diff --git a/src/Util/CaseUtil.v b/src/Util/CaseUtil.v index af04a1e49..cf3ebf29c 100644 --- a/src/Util/CaseUtil.v +++ b/src/Util/CaseUtil.v @@ -10,3 +10,9 @@ Ltac case_max := (exact (le_Sn_le _ _ (not_le _ _ H))) end end. + +Lemma pull_app_if_sumbool {A B X Y} (b : sumbool X Y) (f g : A -> B) (x : A) + : (if b then f x else g x) = (if b then f else g) x. +Proof. + destruct b; reflexivity. +Qed. diff --git a/src/Util/IterAssocOp.v b/src/Util/IterAssocOp.v index 016a4f7bd..6116312e1 100644 --- a/src/Util/IterAssocOp.v +++ b/src/Util/IterAssocOp.v @@ -5,11 +5,15 @@ Local Open Scope equiv_scope. Generalizable All Variables. Section IterAssocOp. Context `{Equivalence T} + {scalar : Type} (op:T->T->T) {op_proper:Proper (equiv==>equiv==>equiv) op} (assoc: forall a b c, op a (op b c) === op (op a b) c) (neutral:T) (neutral_l : forall a, op neutral a === a) - (neutral_r : forall a, op a neutral === a). + (neutral_r : forall a, op a neutral === a) + (testbit : scalar -> nat -> bool) + (scToN : scalar -> N) + (testbit_spec : forall x i, testbit x i = N.testbit_nat (scToN x) i). Existing Instance op_proper. Fixpoint nat_iter_op n a := @@ -51,19 +55,19 @@ Section IterAssocOp. | S exp' => f (funexp f a exp') end. - Definition test_and_op n a (state : nat * T) := + Definition test_and_op sc a (state : nat * T) := let '(i, acc) := state in let acc2 := op acc acc in match i with | O => (0, acc) - | S i' => (i', if N.testbit_nat n i' then op a acc2 else acc2) + | S i' => (i', if testbit sc i' then op a acc2 else acc2) end. - Definition iter_op n a : T := - snd (funexp (test_and_op n a) (N.size_nat n, neutral) (N.size_nat n)). + Definition iter_op sc a bound : T := + snd (funexp (test_and_op sc a) (bound, neutral) bound). - Definition test_and_op_inv n a (s : nat * T) := - snd s === nat_iter_op (N.to_nat (N.shiftr_nat n (fst s))) a. + Definition test_and_op_inv sc a (s : nat * T) := + snd s === nat_iter_op (N.to_nat (N.shiftr_nat (scToN sc) (fst s))) a. Hint Rewrite N.succ_double_spec @@ -91,7 +95,7 @@ Section IterAssocOp. reflexivity. Qed. - Lemma shiftr_succ : forall n i, + Lemma Nshiftr_succ : forall n i, N.to_nat (N.shiftr_nat n i) = if N.testbit_nat n i then S (2 * N.to_nat (N.shiftr_nat n (S i))) @@ -115,21 +119,23 @@ Section IterAssocOp. apply N2Nat.id. Qed. - Lemma test_and_op_inv_step : forall n a s, - test_and_op_inv n a s -> - test_and_op_inv n a (test_and_op n a s). + Lemma test_and_op_inv_step : forall sc a s, + test_and_op_inv sc a s -> + test_and_op_inv sc a (test_and_op sc a s). Proof. destruct s as [i acc]. unfold test_and_op_inv, test_and_op; simpl; intro Hpre. destruct i; [ apply Hpre | ]. simpl. - rewrite shiftr_succ. - case_eq (N.testbit_nat n i); intro; simpl; rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity. + rewrite Nshiftr_succ. + case_eq (testbit sc i); intro testbit_eq; simpl; + rewrite testbit_spec in testbit_eq; rewrite testbit_eq; + rewrite Hpre, <- plus_n_O, nat_iter_op_plus; reflexivity. Qed. - Lemma test_and_op_inv_holds : forall n a i s, - test_and_op_inv n a s -> - test_and_op_inv n a (funexp (test_and_op n a) s i). + Lemma test_and_op_inv_holds : forall sc a i s, + test_and_op_inv sc a s -> + test_and_op_inv sc a (funexp (test_and_op sc a) s i). Proof. induction i; intros; auto; simpl; apply test_and_op_inv_step; auto. Qed. @@ -144,14 +150,14 @@ Section IterAssocOp. destruct i; rewrite NPeano.Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity. Qed. - Lemma iter_op_termination : forall n a, - test_and_op_inv n a - (funexp (test_and_op n a) (N.size_nat n, neutral) (N.size_nat n)) -> - iter_op n a === nat_iter_op (N.to_nat n) a. + Lemma iter_op_termination : forall sc a bound, + N.size_nat (scToN sc) <= bound -> + test_and_op_inv sc a + (funexp (test_and_op sc a) (bound, neutral) bound) -> + iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a. Proof. - unfold test_and_op_inv, iter_op; simpl; intros ? ? Hinv. + unfold test_and_op_inv, iter_op; simpl; intros ? ? ? ? Hinv. rewrite Hinv, funexp_test_and_op_index, Minus.minus_diag. - replace (N.shiftr_nat n 0) with n by auto. reflexivity. Qed. @@ -160,25 +166,33 @@ Section IterAssocOp. destruct n; auto; simpl; induction p; simpl; auto; rewrite IHp, Pnat.Pos2Nat.inj_succ; reflexivity. Qed. - Lemma Nshiftr_size : forall n, N.shiftr_nat n (N.size_nat n) = 0%N. + Lemma Nshiftr_size : forall n bound, N.size_nat n <= bound -> + N.shiftr_nat n bound = 0%N. Proof. intros. - rewrite Nsize_nat_equiv. + rewrite <- (Nat2N.id bound). rewrite Nshiftr_nat_equiv. - destruct (N.eq_dec n 0); subst; auto. + destruct (N.eq_dec n 0); subst; [apply N.shiftr_0_l|]. apply N.shiftr_eq_0. - rewrite N.size_log2 by auto. - apply N.lt_succ_diag_r. + rewrite Nsize_nat_equiv in *. + rewrite N.size_log2 in * by auto. + apply N.le_succ_l. + rewrite <- N.compare_le_iff. + rewrite N2Nat.inj_compare. + rewrite <- Compare_dec.nat_compare_le. + rewrite Nat2N.id. + auto. Qed. - - Lemma iter_op_spec : forall n a, iter_op n a === nat_iter_op (N.to_nat n) a. + + Lemma iter_op_spec : forall sc a bound, N.size_nat (scToN sc) <= bound -> + iter_op sc a bound === nat_iter_op (N.to_nat (scToN sc)) a. Proof. intros. - apply iter_op_termination. + apply iter_op_termination; auto. apply test_and_op_inv_holds. unfold test_and_op_inv. simpl. - rewrite Nshiftr_size. + rewrite Nshiftr_size by auto. reflexivity. Qed. diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v index 1f9a62457..36d8a3ad3 100644 --- a/src/Util/ListUtil.v +++ b/src/Util/ListUtil.v @@ -433,6 +433,22 @@ Proof. auto. Qed. +Lemma nth_default_cons_S : forall {A} us (u0 : A) n d, + nth_default d (u0 :: us) (S n) = nth_default d us n. +Proof. + boring. +Qed. + +Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat -> + nth_error l i = Some (nth_default x l i). +Proof. + intros ? ? ? ? i_lt_length. + destruct (nth_error_length_exists_value _ _ i_lt_length) as [k nth_err_k]. + unfold nth_default. + rewrite nth_err_k. + reflexivity. +Qed. + Lemma set_nth_cons : forall {T} (x u0 : T) us, set_nth 0 x (u0 :: us) = x :: us. Proof. auto. @@ -533,3 +549,36 @@ Lemma cons_eq_tail : forall {T} (x y:T) xs ys, x::xs = y::ys -> xs=ys. Proof. intros; solve_by_inversion. Qed. + +Lemma map_nth_default_always {A B} (f : A -> B) (n : nat) (x : A) (l : list A) + : nth_default (f x) (map f l) n = f (nth_default x l n). +Proof. + revert n; induction l; simpl; intro n; destruct n; [ try reflexivity.. ]. + nth_tac. +Qed. + +Lemma fold_right_and_True_forall_In_iff : forall {T} (l : list T) (P : T -> Prop), + (forall x, In x l -> P x) <-> fold_right and True (map P l). +Proof. + induction l; intros; simpl; try tauto. + rewrite <- IHl. + intuition (subst; auto). +Qed. + +Lemma fold_right_invariant : forall {A} P (f: A -> A -> A) l x, + P x -> (forall y, In y l -> forall z, P z -> P (f y z)) -> + P (fold_right f x l). +Proof. + induction l; intros ? ? step; auto. + simpl. + apply step; try apply in_eq. + apply IHl; auto. + intros y in_y_l. + apply (in_cons a) in in_y_l. + auto. +Qed. + +Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l. +Proof. + induction n; destruct l; boring. +Qed. diff --git a/src/Util/NatUtil.v b/src/Util/NatUtil.v index 6a62d6c22..1f69b04d2 100644 --- a/src/Util/NatUtil.v +++ b/src/Util/NatUtil.v @@ -56,3 +56,14 @@ Proof. rewrite <- nat_compare_lt; auto. } Qed. + + + +Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R) + : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b). +Proof. + destruct (eq_nat_dec x y) as [H|H]; + [ rewrite (proj2 (@beq_nat_true_iff _ _) H); reflexivity + | rewrite (proj2 (@beq_nat_false_iff _ _) H); reflexivity ]. +Qed. + diff --git a/src/Util/Tactics.v b/src/Util/Tactics.v new file mode 100644 index 000000000..08ebfe13e --- /dev/null +++ b/src/Util/Tactics.v @@ -0,0 +1,25 @@ +(** * Generic Tactics *) + +(* [pose proof defn], but only if no hypothesis of the same type exists. + most useful for proofs of a proposition *) +Tactic Notation "unique" "pose" "proof" constr(defn) := + let T := type of defn in + match goal with + | [ H : T |- _ ] => fail 1 + | _ => pose proof defn + end. +(* [assert T], but only if no hypothesis of the same type exists. + most useful for proofs of a proposition *) +Tactic Notation "unique" "assert" constr(T) := + match goal with + | [ H : T |- _ ] => fail 1 + | _ => assert T + end. + +(* [assert T], but only if no hypothesis of the same type exists. + most useful for proofs of a proposition *) +Tactic Notation "unique" "assert" constr(T) "by" tactic3(tac) := + match goal with + | [ H : T |- _ ] => fail 1 + | _ => assert T by tac + end. diff --git a/src/Util/WordUtil.v b/src/Util/WordUtil.v index 17d04c60a..6a8831b14 100644 --- a/src/Util/WordUtil.v +++ b/src/Util/WordUtil.v @@ -1,5 +1,6 @@ Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.NatUtil. Require Import Bedrock.Word. Local Open Scope nat_scope. @@ -21,6 +22,39 @@ Proof. auto. Qed. +Lemma wordToN_NToWord_idempotent : forall sz n, (n < Npow2 sz)%N -> + wordToN (NToWord sz n) = n. +Proof. + intros. + rewrite wordToN_nat, NToWord_nat. + rewrite wordToNat_natToWord_idempotent; rewrite Nnat.N2Nat.id; auto. +Qed. + +Lemma NToWord_wordToN : forall sz w, NToWord sz (wordToN w) = w. +Proof. + intros. + rewrite wordToN_nat, NToWord_nat, Nnat.Nat2N.id. + apply natToWord_wordToNat. +Qed. + +Lemma bound_check_nat_N : forall x n, (Z.to_nat x < 2 ^ n)%nat -> (Z.to_N x < Npow2 n)%N. +Proof. + intros x n bound_nat. + rewrite <- Nnat.N2Nat.id, Npow2_nat. + replace (Z.to_N x) with (N.of_nat (Z.to_nat x)) by apply Z_nat_N. + apply (Nat2N_inj_lt _ (pow2 n)). + rewrite pow2_id; assumption. +Qed. + +Lemma weqb_false_iff : forall sz (x y : word sz), weqb x y = false <-> x <> y. +Proof. + split; intros. + + intro eq_xy; apply weqb_true_iff in eq_xy; congruence. + + case_eq (weqb x y); intros weqb_xy; auto. + apply weqb_true_iff in weqb_xy. + congruence. +Qed. + Definition wfirstn n {m} (w : Word.word m) {H : n <= m} : Word.word n. refine (Word.split1 n (m - n) (match _ in _ = N return Word.word N with | eq_refl => w diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index db3d84b2d..1b7cfafdc 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -1,6 +1,7 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.Znumtheory Coq.ZArith.ZArith Coq.ZArith.Zdiv. Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. Require Import Crypto.Util.NatUtil. +Require Import Coq.Lists.List. Local Open Scope Z. Lemma gt_lt_symmetry: forall n m, n > m <-> m < n. @@ -168,6 +169,222 @@ Proof. intros; omega. Qed. + +Lemma Z_testbit_low : forall n x i, (0 <= i < n) -> + Z.testbit x i = Z.testbit (Z.land x (Z.ones n)) i. +Proof. + intros. + rewrite Z.land_ones by omega. + symmetry. + apply Z.mod_pow2_bits_low. + omega. +Qed. + + +Lemma Z_testbit_shiftl : forall i, (0 <= i) -> forall a b n, (i < n) -> + Z.testbit (a + Z.shiftl b n) i = Z.testbit a i. +Proof. + intros. + erewrite Z_testbit_low; eauto. + rewrite Z.land_ones, Z.shiftl_mul_pow2 by omega. + rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 n); omega). + auto using Z.mod_pow2_bits_low. +Qed. + +Lemma Z_mod_div_eq0 : forall a b, 0 < b -> (a mod b) / b = 0. +Proof. + intros. + apply Z.div_small. + auto using Z.mod_pos_bound. +Qed. + +Lemma Z_shiftr_add_land : forall n m a b, (n <= m)%nat -> + Z.shiftr ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.of_nat m) = Z.shiftr b (Z.of_nat (m - n)). +Proof. + intros. + rewrite Z.land_ones by apply Nat2Z.is_nonneg. + rewrite !Z.shiftr_div_pow2 by apply Nat2Z.is_nonneg. + rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. + rewrite (le_plus_minus n m) at 1 by assumption. + rewrite Nat2Z.inj_add. + rewrite Z.pow_add_r by apply Nat2Z.is_nonneg. + rewrite <- Z.div_div by first + [ pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega + | apply Z.pow_pos_nonneg; omega ]. + rewrite Z.div_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega). + rewrite Z_mod_div_eq0 by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega); auto. +Qed. + +Lemma Z_land_add_land : forall n m a b, (m <= n)%nat -> + Z.land ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.ones (Z.of_nat m)) = Z.land a (Z.ones (Z.of_nat m)). +Proof. + intros. + rewrite !Z.land_ones by apply Nat2Z.is_nonneg. + rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. + replace (b * 2 ^ Z.of_nat n) with + ((b * 2 ^ Z.of_nat (n - m)) * 2 ^ Z.of_nat m) by + (rewrite (le_plus_minus m n) at 2; try assumption; + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg; ring). + rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat m)); omega). + symmetry. apply Znumtheory.Zmod_div_mod; try (apply Z.pow_pos_nonneg; omega). + rewrite (le_plus_minus m n) by assumption. + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg. + apply Z.divide_factor_l. +Qed. + +Lemma Z_pow_gt0 : forall a, 0 < a -> forall b, 0 <= b -> 0 < a ^ b. +Proof. + intros until 1. + apply natlike_ind; try (simpl; omega). + intros. + rewrite Z.pow_succ_r by assumption. + apply Z.mul_pos_pos; assumption. +Qed. + +Lemma div_pow2succ : forall n x, (0 <= x) -> + n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x). +Proof. + intros. + rewrite Z.pow_succ_r, Z.mul_comm by auto. + rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega). + rewrite Zdiv2_div. + reflexivity. +Qed. + +Lemma shiftr_succ : forall n x, + Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1. +Proof. + intros. + rewrite Z.shiftr_shiftr by omega. + reflexivity. +Qed. + + +Definition Z_shiftl_by n a := Z.shiftl a n. + +Lemma Z_shiftl_by_mul_pow2 : forall n a, 0 <= n -> Z.mul (2 ^ n) a = Z_shiftl_by n a. +Proof. + intros. + unfold Z_shiftl_by. + rewrite Z.shiftl_mul_pow2 by assumption. + apply Z.mul_comm. +Qed. + +Lemma map_shiftl : forall n l, 0 <= n -> map (Z.mul (2 ^ n)) l = map (Z_shiftl_by n) l. +Proof. + intros; induction l; auto using Z_shiftl_by_mul_pow2. + simpl. + rewrite IHl. + f_equal. + apply Z_shiftl_by_mul_pow2. + assumption. +Qed. + +Lemma Z_odd_mod : forall a b, (b <> 0)%Z -> + Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a. +Proof. + intros. + rewrite Zmod_eq_full by assumption. + rewrite <-Z.add_opp_r, Z.odd_add, Z.odd_opp, Z.odd_mul. + case_eq (Z.odd b); intros; rewrite ?Bool.andb_true_r, ?Bool.andb_false_r; auto using Bool.xorb_false_r. +Qed. + +Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0. +Proof. + intros. + replace b with (b - c + c) by ring. + rewrite Z.pow_add_r by omega. + apply Z_mod_mult. +Qed. + + Lemma Z_ones_succ : forall x, (0 <= x) -> + Z.ones (Z.succ x) = 2 ^ x + Z.ones x. + Proof. + unfold Z.ones; intros. + rewrite !Z.shiftl_1_l. + rewrite Z.add_pred_r. + apply Z.succ_inj. + rewrite !Z.succ_pred. + rewrite Z.pow_succ_r; omega. + Qed. + + Lemma Z_div_floor : forall a b c, 0 < b -> a < b * (Z.succ c) -> a / b <= c. + Proof. + intros. + apply Z.lt_succ_r. + apply Z.div_lt_upper_bound; try omega. + Qed. + + Lemma Z_shiftr_1_r_le : forall a b, a <= b -> + Z.shiftr a 1 <= Z.shiftr b 1. + Proof. + intros. + rewrite !Z.shiftr_div_pow2, Z.pow_1_r by omega. + apply Z.div_le_mono; omega. + Qed. + + Lemma Z_ones_pred : forall i, 0 < i -> Z.ones (Z.pred i) = Z.shiftr (Z.ones i) 1. + Proof. + induction i; [ | | pose proof (Pos2Z.neg_is_neg p) ]; try omega. + intros. + unfold Z.ones. + rewrite !Z.shiftl_1_l, Z.shiftr_div_pow2, <-!Z.sub_1_r, Z.pow_1_r, <-!Z.add_opp_r by omega. + replace (2 ^ (Z.pos p)) with (2 ^ (Z.pos p - 1)* 2). + rewrite Z.div_add_l by omega. + reflexivity. + replace 2 with (2 ^ 1) at 2 by auto. + rewrite <-Z.pow_add_r by (pose proof (Pos2Z.is_pos p); omega). + f_equal. omega. + Qed. + + Lemma Z_shiftr_ones' : forall a n, 0 <= a < 2 ^ n -> forall i, (0 <= i) -> + Z.shiftr a i <= Z.ones (n - i) \/ n <= i. + Proof. + intros until 1. + apply natlike_ind. + + unfold Z.ones. + rewrite Z.shiftr_0_r, Z.shiftl_1_l, Z.sub_0_r. + omega. + + intros. + destruct (Z_lt_le_dec x n); try omega. + intuition. + left. + rewrite shiftr_succ. + replace (n - Z.succ x) with (Z.pred (n - x)) by omega. + rewrite Z_ones_pred by omega. + apply Z_shiftr_1_r_le. + assumption. + Qed. + + Lemma Z_shiftr_ones : forall a n i, 0 <= a < 2 ^ n -> (0 <= i) -> (i <= n) -> + Z.shiftr a i <= Z.ones (n - i) . + Proof. + intros a n i G G0 G1. + destruct (Z_le_lt_eq_dec i n G1). + + destruct (Z_shiftr_ones' a n G i G0); omega. + + subst; rewrite Z.sub_diag. + destruct (Z_eq_dec a 0). + - subst; rewrite Z.shiftr_0_l; reflexivity. + - rewrite Z.shiftr_eq_0; try omega; try reflexivity. + apply Z.log2_lt_pow2; omega. + Qed. + + Lemma Z_shiftr_upper_bound : forall a n, 0 <= n -> 0 <= a <= 2 ^ n -> Z.shiftr a n <= 1. + Proof. + intros a ? ? [a_nonneg a_upper_bound]. + apply Z_le_lt_eq_dec in a_upper_bound. + destruct a_upper_bound. + + destruct (Z_eq_dec 0 a). + - subst; rewrite Z.shiftr_0_l; omega. + - rewrite Z.shiftr_eq_0; auto; try omega. + apply Z.log2_lt_pow2; auto; omega. + + subst. + rewrite Z.shiftr_div_pow2 by assumption. + rewrite Z.div_same; try omega. + assert (0 < 2 ^ n) by (apply Z.pow_pos_nonneg; omega). + omega. + Qed. + (* prove that known nonnegative numbers are nonnegative *) Ltac zero_bounds' := repeat match goal with @@ -188,3 +405,14 @@ Ltac zero_bounds' := end; try omega; try prime_bound; auto. Ltac zero_bounds := try omega; try prime_bound; zero_bounds'. + + Lemma Z_ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i. + Proof. + apply natlike_ind. + + unfold Z.ones. simpl; omega. + + intros. + rewrite Z_ones_succ by assumption. + zero_bounds. + apply Z.pow_nonneg; omega. + Qed. + diff --git a/to_gallina.md b/to_gallina.md new file mode 100644 index 000000000..1ac5075ef --- /dev/null +++ b/to_gallina.md @@ -0,0 +1,7 @@ +Remaining work needed for Gallina verification code +--------------------------------------------------- ++ efficient GF exponentiation ++ efficient GF inverse ++ make EdDSA point addition use ModularBaseSystem ++ represent scalars (Fl) in ModularBaseSystem (large c) ++ canonical representations of field elements |