diff options
author | Robert Sloan <varomodt@gmail.com> | 2016-06-22 14:06:28 -0400 |
---|---|---|
committer | Robert Sloan <varomodt@gmail.com> | 2016-06-22 14:06:28 -0400 |
commit | 2f44fe53e1a598b524e11cda3dc9ce7a04534247 (patch) | |
tree | 47a77eb4ed8fea3ac5ec99c5bf5ad9131ba44fd9 /src | |
parent | e101fc5dd8783d029d7a4933c7ccca4a67ed3874 (diff) | |
parent | 3d8afe1c9bd905e3a62523e87a2aa7e5d9f5093d (diff) |
Merge with plv/master
Diffstat (limited to 'src')
45 files changed, 4003 insertions, 2737 deletions
diff --git a/src/Algebra.v b/src/Algebra.v new file mode 100644 index 000000000..6dc188e2c --- /dev/null +++ b/src/Algebra.v @@ -0,0 +1,598 @@ +Require Import Coq.Classes.Morphisms. Require Coq.Setoids.Setoid. +Require Import Crypto.Util.Tactics Crypto.Tactics.Nsatz. +Local Close Scope nat_scope. Local Close Scope type_scope. Local Close Scope core_scope. + +Section Algebra. + Context {T:Type} {eq:T->T->Prop}. + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + + Class is_eq_dec := { eq_dec : forall x y : T, {x=y} + {x<>y} }. + + Section SingleOperation. + Context {op:T->T->T}. + + Class is_associative := { associative : forall x y z, op x (op y z) = op (op x y) z }. + + Context {id:T}. + + Class is_left_identity := { left_identity : forall x, op id x = x }. + Class is_right_identity := { right_identity : forall x, op x id = x }. + + Class monoid := + { + monoid_is_associative : is_associative; + monoid_is_left_identity : is_left_identity; + monoid_is_right_identity : is_right_identity; + + monoid_op_Proper: Proper (respectful eq (respectful eq eq)) op; + monoid_Equivalence : Equivalence eq; + monoid_is_eq_dec : is_eq_dec + }. + Global Existing Instance monoid_is_associative. + Global Existing Instance monoid_is_left_identity. + Global Existing Instance monoid_is_right_identity. + Global Existing Instance monoid_Equivalence. + Global Existing Instance monoid_is_eq_dec. + Global Existing Instance monoid_op_Proper. + + Context {inv:T->T}. + Class is_left_inverse := { left_inverse : forall x, op (inv x) x = id }. + Class is_right_inverse := { right_inverse : forall x, op x (inv x) = id }. + + Class group := + { + group_monoid : monoid; + group_is_left_inverse : is_left_inverse; + group_is_right_inverse : is_right_inverse; + + group_inv_Proper: Proper (respectful eq eq) inv + }. + Global Existing Instance group_monoid. + Global Existing Instance group_is_left_inverse. + Global Existing Instance group_is_right_inverse. + Global Existing Instance group_inv_Proper. + + Class is_commutative := { commutative : forall x y, op x y = op y x }. + + Record abelian_group := + { + abelian_group_group : group; + abelian_group_is_commutative : is_commutative + }. + Existing Class abelian_group. + Global Existing Instance abelian_group_group. + Global Existing Instance abelian_group_is_commutative. + End SingleOperation. + + Section AddMul. + Context {zero one:T}. Local Notation "0" := zero. Local Notation "1" := one. + Context {opp:T->T}. Local Notation "- x" := (opp x). + Context {add:T->T->T} {sub:T->T->T} {mul:T->T->T}. + Local Infix "+" := add. Local Infix "-" := sub. Local Infix "*" := mul. + + Class is_left_distributive := { left_distributive : forall a b c, a * (b + c) = a * b + a * c }. + Class is_right_distributive := { right_distributive : forall a b c, (b + c) * a = b * a + c * a }. + + + Class ring := + { + ring_abelian_group_add : abelian_group (op:=add) (id:=zero) (inv:=opp); + ring_monoid_mul : monoid (op:=mul) (id:=one); + ring_is_left_distributive : is_left_distributive; + ring_is_right_distributive : is_right_distributive; + + ring_sub_definition : forall x y, x - y = x + opp y; + + ring_mul_Proper : Proper (respectful eq (respectful eq eq)) mul; + ring_sub_Proper : Proper(respectful eq (respectful eq eq)) sub + }. + Global Existing Instance ring_abelian_group_add. + Global Existing Instance ring_monoid_mul. + Global Existing Instance ring_is_left_distributive. + Global Existing Instance ring_is_right_distributive. + Global Existing Instance ring_mul_Proper. + Global Existing Instance ring_sub_Proper. + + Class commutative_ring := + { + commutative_ring_ring : ring; + commutative_ring_is_commutative : is_commutative (op:=mul) + }. + Global Existing Instance commutative_ring_ring. + Global Existing Instance commutative_ring_is_commutative. + + Class is_mul_nonzero_nonzero := { mul_nonzero_nonzero : forall x y, x<>0 -> y<>0 -> x*y<>0 }. + + Class is_zero_neq_one := { zero_neq_one : zero <> one }. + + Class integral_domain := + { + integral_domain_commutative_ring : commutative_ring; + integral_domain_is_mul_nonzero_nonzero : is_mul_nonzero_nonzero; + integral_domain_is_zero_neq_one : is_zero_neq_one + }. + Global Existing Instance integral_domain_commutative_ring. + Global Existing Instance integral_domain_is_mul_nonzero_nonzero. + Global Existing Instance integral_domain_is_zero_neq_one. + + Context {inv:T->T} {div:T->T->T}. + Class is_left_multiplicative_inverse := { left_multiplicative_inverse : forall x, x<>0 -> (inv x) * x = 1 }. + + Class field := + { + field_commutative_ring : commutative_ring; + field_is_left_multiplicative_inverse : is_left_multiplicative_inverse; + field_domain_is_zero_neq_one : is_zero_neq_one; + + field_div_definition : forall x y , div x y = x * inv y; + + field_inv_Proper : Proper (respectful eq eq) inv; + field_div_Proper : Proper (respectful eq (respectful eq eq)) div + }. + Global Existing Instance field_commutative_ring. + Global Existing Instance field_is_left_multiplicative_inverse. + Global Existing Instance field_domain_is_zero_neq_one. + Global Existing Instance field_inv_Proper. + Global Existing Instance field_div_Proper. + End AddMul. +End Algebra. + + +Module Monoid. + Section Monoid. + Context {T eq op id} {monoid:@monoid T eq op id}. + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Infix "*" := op. + Local Infix "=" := eq : eq_scope. + Local Open Scope eq_scope. + + Lemma cancel_right z iz (Hinv:op z iz = id) : + forall x y, x * z = y * z <-> x = y. + Proof. + split; intros. + { assert (op (op x z) iz = op (op y z) iz) as Hcut by (f_equiv; assumption). + rewrite <-associative in Hcut. + rewrite <-!associative, !Hinv, !right_identity in Hcut; exact Hcut. } + { f_equiv; assumption. } + Qed. + + Lemma cancel_left z iz (Hinv:op iz z = id) : + forall x y, z * x = z * y <-> x = y. + Proof. + split; intros. + { assert (op iz (op z x) = op iz (op z y)) as Hcut by (f_equiv; assumption). + rewrite !associative, !Hinv, !left_identity in Hcut; exact Hcut. } + { f_equiv; assumption. } + Qed. + + Lemma inv_inv x ix iix : ix*x = id -> iix*ix = id -> iix = x. + Proof. + intros Hi Hii. + assert (H:op iix id = op iix (op ix x)) by (rewrite Hi; reflexivity). + rewrite associative, Hii, left_identity, right_identity in H; exact H. + Qed. + + Lemma inv_op x y ix iy : ix*x = id -> iy*y = id -> (iy*ix)*(x*y) =id. + Proof. + intros Hx Hy. + cut (iy * (ix*x) * y = id); try intro H. + { rewrite <-!associative; rewrite <-!associative in H; exact H. } + rewrite Hx, right_identity, Hy. reflexivity. + Qed. + + End Monoid. +End Monoid. + +Section ZeroNeqOne. + Context {T eq zero one} `{@is_zero_neq_one T eq zero one} `{Equivalence T eq}. + + Lemma one_neq_zero : not (eq one zero). + Proof. + intro HH; symmetry in HH. auto using zero_neq_one. + Qed. +End ZeroNeqOne. + +Module Group. + Section BasicProperties. + Context {T eq op id inv} `{@group T eq op id inv}. + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Infix "*" := op. + Local Infix "=" := eq : eq_scope. + Local Open Scope eq_scope. + + Lemma cancel_left : forall z x y, z*x = z*y <-> x = y. + Proof. eauto using Monoid.cancel_left, left_inverse. Qed. + Lemma cancel_right : forall z x y, x*z = y*z <-> x = y. + Proof. eauto using Monoid.cancel_right, right_inverse. Qed. + Lemma inv_inv x : inv(inv(x)) = x. + Proof. eauto using Monoid.inv_inv, left_inverse. Qed. + Lemma inv_op x y : (inv y*inv x)*(x*y) =id. + Proof. eauto using Monoid.inv_op, left_inverse. Qed. + + Lemma inv_unique x ix : ix * x = id -> ix = inv x. + Proof. + intro Hix. + cut (ix*x*inv x = inv x). + - rewrite <-associative, right_inverse, right_identity; trivial. + - rewrite Hix, left_identity; reflexivity. + Qed. + + Lemma inv_id : inv id = id. + Proof. symmetry. eapply inv_unique, left_identity. Qed. + + Lemma inv_nonzero_nonzero : forall x, x <> id -> inv x <> id. + Proof. + intros ? Hx Ho. + assert (Hxo: x * inv x = id) by (rewrite right_inverse; reflexivity). + rewrite Ho, right_identity in Hxo. intuition. + Qed. + + Section ZeroNeqOne. + Context {one} `{is_zero_neq_one T eq id one}. + Lemma opp_one_neq_zero : inv one <> id. + Proof. apply inv_nonzero_nonzero, one_neq_zero. Qed. + Lemma zero_neq_opp_one : id <> inv one. + Proof. intro Hx. symmetry in Hx. eauto using opp_one_neq_zero. Qed. + End ZeroNeqOne. + End BasicProperties. + + Section Homomorphism. + Context {G EQ OP ID INV} {groupG:@group G EQ OP ID INV}. + Context {H eq op id inv} {groupH:@group H eq op id inv}. + Context {phi:G->H}. + Local Infix "=" := eq. Local Infix "=" := eq : type_scope. + + Class is_homomorphism := + { + homomorphism : forall a b, phi (OP a b) = op (phi a) (phi b); + + is_homomorphism_phi_proper : Proper (respectful EQ eq) phi + }. + Global Existing Instance is_homomorphism_phi_proper. + Context `{is_homomorphism}. + + Lemma homomorphism_id : phi ID = id. + Proof. + assert (Hii: op (phi ID) (phi ID) = op (phi ID) id) by + (rewrite <- homomorphism, left_identity, right_identity; reflexivity). + rewrite cancel_left in Hii; exact Hii. + Qed. + + Lemma homomorphism_inv : forall x, phi (INV x) = inv (phi x). + Proof. + Admitted. + End Homomorphism. +End Group. + +Require Coq.nsatz.Nsatz. + +Ltac dropAlgebraSyntax := + cbv beta delta [ + Algebra_syntax.zero + Algebra_syntax.one + Algebra_syntax.addition + Algebra_syntax.multiplication + Algebra_syntax.subtraction + Algebra_syntax.opposite + Algebra_syntax.equality + Algebra_syntax.bracket + Algebra_syntax.power + ] in *. + +Ltac dropRingSyntax := + dropAlgebraSyntax; + cbv beta delta [ + Ncring.zero_notation + Ncring.one_notation + Ncring.add_notation + Ncring.mul_notation + Ncring.sub_notation + Ncring.opp_notation + Ncring.eq_notation + ] in *. + +Module Ring. + Section Ring. + Context {T eq zero one opp add sub mul} `{@ring T eq zero one opp add sub mul}. + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := zero. Local Notation "1" := one. + Local Infix "+" := add. Local Infix "-" := sub. Local Infix "*" := mul. + + Lemma mul_0_r : forall x, 0 * x = 0. + Proof. + intros. + assert (0*x = 0*x) as Hx by reflexivity. + rewrite <-(right_identity 0), right_distributive in Hx at 1. + assert (0*x + 0*x - 0*x = 0*x - 0*x) as Hxx by (f_equiv; exact Hx). + rewrite !ring_sub_definition, <-associative, right_inverse, right_identity in Hxx; exact Hxx. + Qed. + + Lemma mul_0_l : forall x, x * 0 = 0. + Proof. + intros. + assert (x*0 = x*0) as Hx by reflexivity. + rewrite <-(left_identity 0), left_distributive in Hx at 1. + assert (opp (x*0) + (x*0 + x*0) = opp (x*0) + x*0) as Hxx by (f_equiv; exact Hx). + rewrite associative, left_inverse, left_identity in Hxx; exact Hxx. + Qed. + + Lemma sub_0_l x : 0 - x = opp x. + Proof. rewrite ring_sub_definition. rewrite left_identity. reflexivity. Qed. + + Lemma mul_opp_r x y : x * opp y = opp (x * y). + Proof. + assert (Ho:x*(opp y) + x*y = 0) + by (rewrite <-left_distributive, left_inverse, mul_0_l; reflexivity). + rewrite <-(left_identity (opp (x*y))), <-Ho; clear Ho. + rewrite <-!associative, right_inverse, right_identity; reflexivity. + Qed. + + Lemma mul_opp_l x y : opp x * y = opp (x * y). + Proof. + assert (Ho:opp x*y + x*y = 0) + by (rewrite <-right_distributive, left_inverse, mul_0_r; reflexivity). + rewrite <-(left_identity (opp (x*y))), <-Ho; clear Ho. + rewrite <-!associative, right_inverse, right_identity; reflexivity. + Qed. + + Definition opp_nonzero_nonzero : forall x, x <> 0 -> opp x <> 0 := Group.inv_nonzero_nonzero. + + Global Instance is_left_distributive_sub : is_left_distributive (eq:=eq)(add:=sub)(mul:=mul). + Proof. + split; intros. rewrite !ring_sub_definition, left_distributive. + eapply Group.cancel_left, mul_opp_r. + Qed. + + Global Instance is_right_distributive_sub : is_right_distributive (eq:=eq)(add:=sub)(mul:=mul). + Proof. + split; intros. rewrite !ring_sub_definition, right_distributive. + eapply Group.cancel_left, mul_opp_l. + Qed. + + Global Instance Ncring_Ring_ops : @Ncring.Ring_ops T zero one add mul sub opp eq. + Global Instance Ncring_Ring : @Ncring.Ring T zero one add mul sub opp eq Ncring_Ring_ops. + Proof. + split; dropRingSyntax; eauto using left_identity, right_identity, commutative, associative, right_inverse, left_distributive, right_distributive, ring_sub_definition with core typeclass_instances. + - (* TODO: why does [eauto using @left_identity with typeclass_instances] not work? *) + eapply @left_identity; eauto with typeclass_instances. + - eapply @right_identity; eauto with typeclass_instances. + - eapply associative. + - intros; eapply right_distributive. + - intros; eapply left_distributive. + Qed. + End Ring. + + Section Homomorphism. + Context {R EQ ZERO ONE OPP ADD SUB MUL} `{@ring R EQ ZERO ONE OPP ADD SUB MUL}. + Context {S eq zero one opp add sub mul} `{@ring S eq zero one opp add sub mul}. + Context {phi:R->S}. + Local Infix "=" := eq. Local Infix "=" := eq : type_scope. + + Class is_homomorphism := + { + homomorphism_is_homomorphism : Group.is_homomorphism (phi:=phi) (OP:=ADD) (op:=add) (EQ:=EQ) (eq:=eq); + homomorphism_mul : forall x y, phi (MUL x y) = mul (phi x) (phi y); + homomorphism_one : phi ONE = one + }. + Global Existing Instance homomorphism_is_homomorphism. + + Context `{is_homomorphism}. + + Lemma homomorphism_add : forall x y, phi (ADD x y) = add (phi x) (phi y). + Proof. apply Group.homomorphism. Qed. + + Definition homomorphism_opp : forall x, phi (OPP x) = opp (phi x) := + (Group.homomorphism_inv (INV:=OPP) (inv:=opp)). + + Lemma homomorphism_sub : forall x y, phi (SUB x y) = sub (phi x) (phi y). + Proof. + intros. + rewrite !ring_sub_definition, Group.homomorphism, homomorphism_opp. reflexivity. + Qed. + + End Homomorphism. + + Section TacticSupportCommutative. + Context {T eq zero one opp add sub mul} `{@commutative_ring T eq zero one opp add sub mul}. + + Global Instance Cring_Cring_commutative_ring : + @Cring.Cring T zero one add mul sub opp eq Ring.Ncring_Ring_ops Ring.Ncring_Ring. + Proof. unfold Cring.Cring; intros; dropRingSyntax. eapply commutative. Qed. + + Lemma ring_theory_for_stdlib_tactic : Ring_theory.ring_theory zero one add mul sub opp eq. + Proof. + constructor; intros. (* TODO(automation): make [auto] do this? *) + - apply left_identity. + - apply commutative. + - apply associative. + - apply left_identity. + - apply commutative. + - apply associative. + - apply right_distributive. + - apply ring_sub_definition. + - apply right_inverse. + Qed. + End TacticSupportCommutative. +End Ring. + +Module IntegralDomain. + Section IntegralDomain. + Context {T eq zero one opp add sub mul} `{@integral_domain T eq zero one opp add sub mul}. + + Lemma mul_nonzero_nonzero_cases (x y : T) + : eq (mul x y) zero -> eq x zero \/ eq y zero. + Proof. + pose proof mul_nonzero_nonzero x y. + destruct (eq_dec x zero); destruct (eq_dec y zero); intuition. + Qed. + + Global Instance Integral_domain : + @Integral_domain.Integral_domain T zero one add mul sub opp eq Ring.Ncring_Ring_ops + Ring.Ncring_Ring Ring.Cring_Cring_commutative_ring. + Proof. + split; dropRingSyntax. + - auto using mul_nonzero_nonzero_cases. + - intro bad; symmetry in bad; auto using zero_neq_one. + Qed. + End IntegralDomain. +End IntegralDomain. + +Module Field. + Section Field. + Context {T eq zero one opp add mul sub inv div} `{@field T eq zero one opp add sub mul inv div}. + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := zero. Local Notation "1" := one. + Local Infix "+" := add. Local Infix "*" := mul. + + Global Instance is_mul_nonzero_nonzero : @is_mul_nonzero_nonzero T eq 0 mul. + Proof. + constructor. intros x y Hx Hy Hxy. + assert (0 = (inv y * (inv x * x)) * y) as H00. (rewrite <-!associative, Hxy, !Ring.mul_0_l; reflexivity). + rewrite left_multiplicative_inverse in H00 by assumption. + rewrite right_identity in H00. + rewrite left_multiplicative_inverse in H00 by assumption. + auto using zero_neq_one. + Qed. + + Global Instance integral_domain : @integral_domain T eq zero one opp add sub mul. + Proof. + split; auto using field_commutative_ring, field_domain_is_zero_neq_one, is_mul_nonzero_nonzero. + Qed. + + Require Coq.setoid_ring.Field_theory. + Lemma field_theory_for_stdlib_tactic : Field_theory.field_theory 0 1 add mul sub opp div inv eq. + Proof. + constructor. + { apply Ring.ring_theory_for_stdlib_tactic. } + { intro H01. symmetry in H01. auto using zero_neq_one. } + { apply field_div_definition. } + { apply left_multiplicative_inverse. } + Qed. + + End Field. + + Section Homomorphism. + Context {F EQ ZERO ONE OPP ADD MUL SUB INV DIV} `{@field F EQ ZERO ONE OPP ADD SUB MUL INV DIV}. + Context {K eq zero one opp add mul sub inv div} `{@field K eq zero one opp add sub mul inv div}. + Context {phi:F->K}. + Local Infix "=" := eq. Local Infix "=" := eq : type_scope. + Context `{@Ring.is_homomorphism F EQ ONE ADD MUL K eq one add mul phi}. + + Lemma homomorphism_multiplicative_inverse : forall x, phi (INV x) = inv (phi x). Admitted. + + Lemma homomorphism_div : forall x y, phi (DIV x y) = div (phi x) (phi y). + Proof. + intros. rewrite !field_div_definition. + rewrite Ring.homomorphism_mul, homomorphism_multiplicative_inverse. reflexivity. + Qed. + End Homomorphism. +End Field. + +(*** Tactics for manipulating field equations *) +Require Import Coq.setoid_ring.Field_tac. + +Ltac guess_field := + match goal with + | |- ?eq _ _ => constr:(_:field (eq:=eq)) + | |- not (?eq _ _) => constr:(_:field (eq:=eq)) + | [H: ?eq _ _ |- _ ] => constr:(_:field (eq:=eq)) + | [H: not (?eq _ _) |- _] => constr:(_:field (eq:=eq)) + end. + +Ltac common_denominator := + let fld := guess_field in + lazymatch type of fld with + field (div:=?div) => + lazymatch goal with + | |- appcontext[div] => field_simplify_eq + | |- _ => idtac + end + end. + +Ltac common_denominator_in H := + let fld := guess_field in + lazymatch type of fld with + field (div:=?div) => + lazymatch type of H with + | appcontext[div] => field_simplify_eq in H + | _ => idtac + end + end. + +Ltac common_denominator_all := + common_denominator; + repeat match goal with [H: _ |- _ _ _ ] => progress common_denominator_in H end. + +Inductive field_simplify_done {T} : T -> Type := + Field_simplify_done : forall H, field_simplify_done H. + +Ltac field_simplify_eq_hyps := + repeat match goal with + [ H: _ |- _ ] => + match goal with + | [ Ha : field_simplify_done H |- _ ] => fail + | _ => idtac + end; + field_simplify_eq in H; + unique pose proof (Field_simplify_done H) + end; + repeat match goal with [ H: field_simplify_done _ |- _] => clear H end. + +Ltac field_simplify_eq_all := field_simplify_eq_hyps; try field_simplify_eq. + + +(*** Polynomial equations over fields *) + +Ltac neq01 := + try solve + [apply zero_neq_one + |apply Group.zero_neq_opp_one + |apply one_neq_zero + |apply Group.opp_one_neq_zero]. + +Ltac field_algebra := + intros; + common_denominator_all; + try (nsatz; dropRingSyntax); + repeat (apply conj); + try solve + [neq01 + |trivial + |apply Ring.opp_nonzero_nonzero;trivial]. + +Section Example. + Context {F zero one opp add sub mul inv div} `{F_field:field F eq zero one opp add sub mul inv div}. + Local Infix "+" := add. Local Infix "*" := mul. Local Infix "-" := sub. Local Infix "/" := div. + Local Notation "0" := zero. Local Notation "1" := one. + + Add Field _ExampleField : (Field.field_theory_for_stdlib_tactic (T:=F)). + + Example _example_nsatz x y : 1+1 <> 0 -> x + y = 0 -> x - y = 0 -> x = 0. + Proof. field_algebra. Qed. + + Example _example_field_nsatz x y z : y <> 0 -> x/y = z -> z*y + y = x + y. + Proof. intros; subst; field_algebra. Qed. + + Example _example_nonzero_nsatz_contradict x y : x * y = 1 -> not (x = 0). + Proof. intros. intro. nsatz_contradict. Qed. +End Example. + +Section Z. + Require Import ZArith. + Global Instance ring_Z : @ring Z Logic.eq 0%Z 1%Z Z.opp Z.add Z.sub Z.mul. + Proof. repeat split; auto using Z.eq_dec with zarith typeclass_instances. Qed. + + Global Instance commutative_ring_Z : @commutative_ring Z Logic.eq 0%Z 1%Z Z.opp Z.add Z.sub Z.mul. + Proof. eauto using @commutative_ring, @is_commutative, ring_Z with zarith. Qed. + + Global Instance integral_domain_Z : @integral_domain Z Logic.eq 0%Z 1%Z Z.opp Z.add Z.sub Z.mul. + Proof. + split. + { apply commutative_ring_Z. } + { constructor. intros. apply Z.neq_mul_0; auto. } + { constructor. discriminate. } + Qed. + + Example _example_nonzero_nsatz_contradict_Z x y : Z.mul x y = (Zpos xH) -> not (x = Z0). + Proof. intros. intro. nsatz_contradict. Qed. +End Z. diff --git a/src/BaseSystem.v b/src/BaseSystem.v index e6ad55f18..c07aad759 100644 --- a/src/BaseSystem.v +++ b/src/BaseSystem.v @@ -1,7 +1,8 @@ Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Import Nat. Local Open Scope Z. @@ -39,7 +40,7 @@ Section BaseSystem. Proof. unfold decode'; intros; f_equal; apply combine_truncate_l. Qed. - + Fixpoint add (us vs:digits) : digits := match us,vs with | u::us', v::vs' => u+v :: add us' vs' @@ -58,26 +59,26 @@ Section BaseSystem. | nil, v::vs' => (0-v)::sub nil vs' end. - Definition crosscoef i j : Z := + Definition crosscoef i j : Z := let b := nth_default 0 base in (b(i) * b(j)) / b(i+j)%nat. Hint Unfold crosscoef. Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end. - + (* mul' is multiplication with the SECOND ARGUMENT REVERSED and OUTPUT REVERSED *) - Fixpoint mul_bi' (i:nat) (vsr:digits) := + Fixpoint mul_bi' (i:nat) (vsr:digits) := match vsr with | v::vsr' => v * crosscoef i (length vsr') :: mul_bi' i vsr' | nil => nil end. Definition mul_bi (i:nat) (vs:digits) : digits := zeros i ++ rev (mul_bi' i (rev vs)). - + (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) Fixpoint mul' (usr vs:digits) : digits := match usr with - | u::usr' => + | u::usr' => mul_each u (mul_bi (length usr') vs) .+ mul' usr' vs | _ => nil end. @@ -87,7 +88,7 @@ End BaseSystem. (* Example : polynomial base system *) Section PolynomialBaseCoefs. - Context (b1 : positive) (baseLength : nat) (baseLengthNonzero : NPeano.ltb 0 baseLength = true). + Context (b1 : positive) (baseLength : nat) (baseLengthNonzero : ltb 0 baseLength = true). (** PolynomialBaseCoefs generates base vectors for [BaseSystem]. *) Definition bi i := (Zpos b1)^(Z.of_nat i). Definition poly_base := map bi (seq 0 baseLength). @@ -96,7 +97,7 @@ Section PolynomialBaseCoefs. unfold poly_base, bi, nth_default. case_eq baseLength; intros. { assert ((0 < baseLength)%nat) by - (rewrite <-NPeano.ltb_lt; apply baseLengthNonzero). + (rewrite <-ltb_lt; apply baseLengthNonzero). subst; omega. } auto. @@ -119,7 +120,7 @@ Section PolynomialBaseCoefs. Qed. Lemma poly_base_succ : - forall i, ((S i) < length poly_base)%nat -> + forall i, ((S i) < length poly_base)%nat -> let b := nth_default 0 poly_base in let r := (b (S i) / b i) in b (S i) = r * b i. @@ -127,7 +128,7 @@ Section PolynomialBaseCoefs. intros; subst b; subst r. repeat rewrite poly_base_defn in * by omega. unfold bi. - replace (Z.pos b1 ^ Z.of_nat (S i)) + replace (Z.pos b1 ^ Z.of_nat (S i)) with (Z.pos b1 * (Z.pos b1 ^ Z.of_nat i)) by (rewrite Nat2Z.inj_succ; rewrite <- Z.pow_succ_r; intuition). replace (Z.pos b1 * Z.pos b1 ^ Z.of_nat i / Z.pos b1 ^ Z.of_nat i) @@ -166,7 +167,7 @@ Import ListNotations. Section BaseSystemExample. Definition baseLength := 32%nat. - Lemma baseLengthNonzero : NPeano.ltb 0 baseLength = true. + Lemma baseLengthNonzero : ltb 0 baseLength = true. compute; reflexivity. Qed. Definition base2 := poly_base 2 baseLength. diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v index ab56cb711..4414877b4 100644 --- a/src/BaseSystemProofs.v +++ b/src/BaseSystemProofs.v @@ -18,7 +18,7 @@ Section BaseSystemProofs. Qed. Lemma decode'_splice : forall xs ys bs, - decode' bs (xs ++ ys) = + decode' bs (xs ++ ys) = decode' (firstn (length xs) bs) xs + decode' (skipn (length xs) bs) ys. Proof. unfold decode'. @@ -83,7 +83,7 @@ Section BaseSystemProofs. unfold decode, encode; destruct z; boring. Qed. - Lemma mul_each_base : forall us bs c, + Lemma mul_each_base : forall us bs c, decode' bs (mul_each c us) = decode' (mul_each c bs) us. Proof. induction us; destruct bs; boring; ring. @@ -99,8 +99,8 @@ Section BaseSystemProofs. induction us; destruct low; boring. Qed. - Lemma base_mul_app : forall low c us, - decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + + Lemma base_mul_app : forall low c us, + decode' (low ++ mul_each c low) us = decode' low (firstn (length low) us) + c * decode' low (skipn (length low) us). Proof. intros. @@ -118,7 +118,7 @@ Section BaseSystemProofs. Qed. Hint Rewrite length_zeros. - Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m). + Lemma app_zeros_zeros : forall n m, zeros n ++ zeros m = zeros (n + m)%nat. Proof. induction n; boring. Qed. @@ -130,6 +130,18 @@ Section BaseSystemProofs. Qed. Hint Rewrite zeros_app0. + Lemma nth_default_zeros : forall n i, nth_default 0 (BaseSystem.zeros n) i = 0. + Proof. + induction n; intros; [ cbv [BaseSystem.zeros]; apply nth_default_nil | ]. + rewrite <-zeros_app0, nth_default_app. + rewrite length_zeros. + destruct (lt_dec i n); auto. + destruct (eq_nat_dec i n); subst. + + rewrite Nat.sub_diag; apply nth_default_cons. + + apply nth_default_out_of_bounds. + cbv [length]; omega. + Qed. + Lemma rev_zeros : forall n, rev (zeros n) = zeros n. Proof. induction n; boring. @@ -225,7 +237,7 @@ Section BaseSystemProofs. Lemma zeros_plus_zeros : forall n, zeros n = zeros n .+ zeros n. induction n; auto. - simpl; f_equal; auto. + simpl; f_equal; auto. Qed. Lemma mul_bi'_n_nil : forall n, mul_bi' base n nil = nil. @@ -243,13 +255,13 @@ Section BaseSystemProofs. induction us; auto. Qed. Hint Rewrite add_nil_r. - + Lemma add_first_terms : forall us vs a b, (a :: us) .+ (b :: vs) = (a + b) :: (us .+ vs). auto. Qed. Hint Rewrite add_first_terms. - + Lemma mul_bi'_cons : forall n x us, mul_bi' base n (x :: us) = x * crosscoef base n (length us) :: mul_bi' base n us. Proof. @@ -266,7 +278,7 @@ Section BaseSystemProofs. Hint Rewrite app_nil_l. Hint Rewrite app_nil_r. - Lemma add_snoc_same_length : forall l us vs a b, + Lemma add_snoc_same_length : forall l us vs a b, (length us = l) -> (length vs = l) -> (us ++ a :: nil) .+ (vs ++ b :: nil) = (us .+ vs) ++ (a + b) :: nil. Proof. @@ -276,7 +288,7 @@ Section BaseSystemProofs. Lemma mul_bi'_add : forall us n vs l (Hlus: length us = l) (Hlvs: length vs = l), - mul_bi' base n (rev (us .+ vs)) = + mul_bi' base n (rev (us .+ vs)) = mul_bi' base n (rev us) .+ mul_bi' base n (rev vs). Proof. (* TODO(adamc): please help prettify this *) @@ -310,7 +322,7 @@ Section BaseSystemProofs. Proof. induction n; boring. Qed. - + Lemma rev_add_rev : forall us vs l, (length us = l) -> (length vs = l) -> (rev us) .+ (rev vs) = rev (us .+ vs). Proof. @@ -352,7 +364,7 @@ Section BaseSystemProofs. Hint Rewrite minus_diag. Lemma add_trailing_zeros : forall us vs, (length us >= length vs)%nat -> - us .+ vs = us .+ (vs ++ (zeros (length us - length vs))). + us .+ vs = us .+ (vs ++ (zeros (length us - length vs)%nat)). Proof. induction us, vs; boring; f_equal; boring. Qed. @@ -377,8 +389,8 @@ Section BaseSystemProofs. induction us; boring. Qed. - Lemma sub_length_le_max : forall us vs, - (length (sub us vs) <= max (length us) (length vs))%nat. + Lemma sub_length : forall us vs, + (length (sub us vs) = max (length us) (length vs))%nat. Proof. induction us, vs; boring. rewrite sub_nil_length; auto. @@ -450,7 +462,7 @@ Section BaseSystemProofs. (* mul' is multiplication with the FIRST ARGUMENT REVERSED *) Fixpoint mul' (usr vs:digits) : digits := match usr with - | u::usr' => + | u::usr' => mul_each u (mul_bi base (length usr') vs) .+ mul' usr' vs | _ => nil end. @@ -499,5 +511,37 @@ Section BaseSystemProofs. rewrite rev_length; omega. Qed. + Lemma add_length_exact : forall us vs, + length (us .+ vs) = max (length us) (length vs). + Proof. + induction us; destruct vs; boring. + Qed. + + Lemma mul'_length_exact: forall us vs, + (length us <= length vs)%nat -> us <> nil -> + (length (mul' us vs) = pred (length us + length vs))%nat. + Proof. + induction us; intros; try solve [boring]. + unfold mul'; fold mul'. + unfold mul_each. + rewrite add_length_exact, map_length, mul_bi_length, length_cons. + destruct us. + + rewrite Max.max_0_r. simpl; omega. + + rewrite Max.max_l; [ omega | ]. + rewrite IHus by ( congruence || simpl in *; omega). + omega. + Qed. + + Lemma mul_length_exact: forall us vs, + (length us <= length vs)%nat -> us <> nil -> + (length (mul us vs) = pred (length us + length vs))%nat. + Proof. + intros; unfold mul. + rewrite mul'_length_exact; rewrite ?rev_length; try omega. + intro rev_nil. + match goal with H : us <> nil |- _ => apply H end. + apply length0_nil; rewrite <-rev_length, rev_nil. + reflexivity. + Qed. End BaseSystemProofs. diff --git a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v index f70479c3a..683addd5d 100644 --- a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v +++ b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v @@ -1,47 +1,33 @@ Require Export Crypto.Spec.CompleteEdwardsCurve. -Require Import Crypto.ModularArithmetic.FField. -Require Import Crypto.ModularArithmetic.FNsatz. +Require Import Crypto.Algebra Crypto.Tactics.Nsatz. Require Import Crypto.CompleteEdwardsCurve.Pre. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Tactics.VerdiTactics. +Require Import Coq.Classes.Morphisms. +Require Import Relation_Definitions. +Require Import Crypto.Util.Tuple. Module E. + Import Group Ring Field CompleteEdwardsCurve.E. Section CompleteEdwardsCurveTheorems. - Context {prm:TwistedEdwardsParams}. - Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *) - Existing Instance prime_q. - - Add Field Ffield_p' : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). - - Add Field Ffield_notConstant : (OpaqueFieldTheory q) - (constants [notConstant]). - - Ltac clear_prm := - generalize dependent a; intro a; intros; - generalize dependent d; intro d; intros; - generalize dependent prime_q; intro prime_q; intros; - generalize dependent q; intro q; intros; - clear prm. - - Lemma point_eq : forall xy1 xy2 pf1 pf2, - xy1 = xy2 -> exist E.onCurve xy1 pf1 = exist E.onCurve xy2 pf2. - Proof. - destruct xy1, xy2; intros; find_injection; intros; subst. apply f_equal. - apply UIP_dec, F_eq_dec. (* this is a hack. We actually don't care about the equality of the proofs. However, we *can* prove it, and knowing it lets us use the universal equality instead of a type-specific equivalence, which makes many things nicer. *) - Qed. Hint Resolve point_eq. - - Definition point_eqb (p1 p2:E.point) : bool := andb - (F_eqb (fst (proj1_sig p1)) (fst (proj1_sig p2))) - (F_eqb (snd (proj1_sig p1)) (snd (proj1_sig p2))). - + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a d} + {field:@field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} + {prm:@twisted_edwards_params F Feq Fzero Fone Fadd Fmul a d}. + Local Infix "=" := Feq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := Fzero. Local Notation "1" := Fone. + Local Infix "+" := Fadd. Local Infix "*" := Fmul. + Local Infix "-" := Fsub. Local Infix "/" := Fdiv. + Local Notation "x ^2" := (x*x) (at level 30). + Local Notation point := (@point F Feq Fone Fadd Fmul a d). + Local Notation onCurve := (@onCurve F Feq Fone Fadd Fmul a d). + + Add Field _edwards_curve_theorems_field : (field_theory_for_stdlib_tactic (H:=field)). + + Definition eq (P Q:point) := fieldwise (n:=2) Feq (coordinates P) (coordinates Q). + Infix "=" := eq : E_scope. + + (* TODO: decide whether we still want something like this, then port Local Ltac t := unfold point_eqb; repeat match goal with @@ -55,246 +41,190 @@ Module E. | [H: _ |- _ ] => apply F_eqb_eq in H | _ => rewrite F_eqb_refl end; eauto. - + Lemma point_eqb_sound : forall p1 p2, point_eqb p1 p2 = true -> p1 = p2. Proof. t. Qed. - + Lemma point_eqb_complete : forall p1 p2, p1 = p2 -> point_eqb p1 p2 = true. Proof. t. Qed. - + Lemma point_eqb_neq : forall p1 p2, point_eqb p1 p2 = false -> p1 <> p2. Proof. intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. apply point_eqb_complete in H0; congruence. Qed. - + Lemma point_eqb_neq_complete : forall p1 p2, p1 <> p2 -> point_eqb p1 p2 = false. Proof. intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. apply point_eqb_sound in Hneq. congruence. Qed. - + Lemma point_eqb_refl : forall p, point_eqb p p = true. Proof. t. Qed. - + Definition point_eq_dec (p1 p2:E.point) : {p1 = p2} + {p1 <> p2}. destruct (point_eqb p1 p2) eqn:H; match goal with | [ H: _ |- _ ] => apply point_eqb_sound in H | [ H: _ |- _ ] => apply point_eqb_neq in H end; eauto. Qed. - + Lemma point_eqb_correct : forall p1 p2, point_eqb p1 p2 = if point_eq_dec p1 p2 then true else false. Proof. intros. destruct (point_eq_dec p1 p2); eauto using point_eqb_complete, point_eqb_neq_complete. Qed. - - Ltac Edefn := unfold E.add, E.add', E.zero; intros; - repeat match goal with - | [ p : E.point |- _ ] => - let x := fresh "x" p in - let y := fresh "y" p in - let pf := fresh "pf" p in - destruct p as [[x y] pf]; unfold E.onCurve in pf - | _ => eapply point_eq, (f_equal2 pair) - | _ => eapply point_eq - end. - Lemma add_comm : forall A B, (A+B = B+A)%E. - Proof. - Edefn; apply (f_equal2 div); ring. - Qed. - - Ltac unifiedAdd_nonzero := match goal with - | [ |- (?op 1 (d * _ * _ * _ * _ * - inv (1 - d * ?xA * ?xB * ?yA * ?yB) * inv (1 + d * ?xA * ?xB * ?yA * ?yB)))%F <> 0%F] - => let Hadd := fresh "Hadd" in - pose proof (@unifiedAdd'_onCurve _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d (xA, yA) (xB, yB)) as Hadd; - simpl in Hadd; - match goal with - | [H : (1 - d * ?xC * xB * ?yC * yB)%F <> 0%F |- (?op 1 ?other)%F <> 0%F] => - replace other with - (d * xC * ((xA * yB + yA * xB) / (1 + d * xA * xB * yA * yB)) - * yC * ((yA * yB - a * xA * xB) / (1 - d * xA * xB * yA * yB)))%F by (subst; unfold div; ring); - auto - end - end. - - Lemma add_assoc : forall A B C, (A+(B+C) = (A+B)+C)%E. - Proof. - Edefn; F_field_simplify_eq; try abstract (rewrite ?@F_pow_2_r in *; clear_prm; F_nsatz); - pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d); - pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d); - cbv beta iota in *; - repeat split; field_nonzero idtac; unifiedAdd_nonzero. - Qed. - - Lemma add_0_r : forall P, (P + E.zero = P)%E. - Proof. - Edefn; repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r, - ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; exact eq_refl. - Qed. + *) - Lemma add_0_l : forall P, (E.zero + P)%E = P. - Proof. - intros; rewrite add_comm. apply add_0_r. - Qed. + (* TODO: move to util *) + Lemma decide_and : forall P Q, {P}+{not P} -> {Q}+{not Q} -> {P/\Q}+{not(P/\Q)}. + Proof. intros; repeat match goal with [H:{_}+{_}|-_] => destruct H end; intuition. Qed. - Lemma mul_0_l : forall P, (0 * P = E.zero)%E. - Proof. - auto. - Qed. + Ltac destruct_points := + repeat match goal with + | [ p : point |- _ ] => + let x := fresh "x" p in + let y := fresh "y" p in + let pf := fresh "pf" p in + destruct p as [[x y] pf] + end. - Lemma mul_S_l : forall n P, (S n * P)%E = (P + n * P)%E. - Proof. - auto. - Qed. + Ltac expand_opp := + rewrite ?mul_opp_r, ?mul_opp_l, ?ring_sub_definition, ?inv_inv, <-?ring_sub_definition. - Lemma mul_add_l : forall a b P, ((a + b)%nat * P)%E = E.add (a * P)%E (b * P)%E. - Proof. - induction a; intros; rewrite ?plus_Sn_m, ?plus_O_n, ?mul_S_l, ?mul_0_l, ?add_0_l, ?mul_S_, ?IHa, ?add_assoc; auto. - Qed. + Local Hint Resolve char_gt_2. + Local Hint Resolve nonzero_a. + Local Hint Resolve square_a. + Local Hint Resolve nonsquare_d. + Local Hint Resolve @edwardsAddCompletePlus. + Local Hint Resolve @edwardsAddCompleteMinus. - Lemma mul_assoc : forall (n m : nat) P, (n * (m * P) = (n * m)%nat * P)%E. - Proof. - induction n; intros; auto. - rewrite ?mul_S_l, ?Mult.mult_succ_l, ?mul_add_l, ?IHn, add_comm. reflexivity. - Qed. + Local Obligation Tactic := intros; destruct_points; simpl; field_algebra. + Program Definition opp (P:point) : point := + exist _ (let '(x, y) := coordinates P in (Fopp x, y) ) _. - Lemma mul_zero_r : forall m, (m * E.zero = E.zero)%E. - Proof. - induction m; rewrite ?mul_S_l, ?add_0_l; auto. - Qed. - - (* solve for x ^ 2 *) - Definition solve_for_x2 (y : F q) := ((y ^ 2 - 1) / (d * (y ^ 2) - a))%F. - - Lemma d_y2_a_nonzero : (forall y, 0 <> d * y ^ 2 - a)%F. - intros ? eq_zero. - pose proof prime_q. - destruct square_a as [sqrt_a sqrt_a_id]. - rewrite <- sqrt_a_id in eq_zero. - destruct (Fq_square_mul_sub _ _ _ eq_zero) as [ [sqrt_d sqrt_d_id] | a_zero]. - + pose proof (nonsquare_d sqrt_d); auto. - + subst. - rewrite Fq_pow_zero in sqrt_a_id by congruence. - auto using nonzero_a. - Qed. - - Lemma a_d_y2_nonzero : (forall y, a - d * y ^ 2 <> 0)%F. - Proof. - intros y eq_zero. - pose proof prime_q. - eapply F_minus_swap in eq_zero. - eauto using (d_y2_a_nonzero y). - Qed. - - Lemma solve_correct : forall x y, E.onCurve (x, y) <-> - (x ^ 2 = solve_for_x2 y)%F. - Proof. - split. - + intro onCurve_x_y. - pose proof prime_q. - unfold E.onCurve in onCurve_x_y. - eapply F_div_mul; auto using (d_y2_a_nonzero y). - replace (x ^ 2 * (d * y ^ 2 - a))%F with ((d * x ^ 2 * y ^ 2) - (a * x ^ 2))%F by ring. - rewrite F_sub_add_swap. - replace (y ^ 2 + a * x ^ 2)%F with (a * x ^ 2 + y ^ 2)%F by ring. - rewrite onCurve_x_y. - ring. - + intro x2_eq. - unfold E.onCurve, solve_for_x2 in *. - rewrite x2_eq. - field. - auto using d_y2_a_nonzero. - Qed. - - - Program Definition opp (P:E.point) : E.point := let '(x, y) := proj1_sig P in (opp x, y). - Next Obligation. Proof. - pose (proj2_sig P) as H; rewrite <-Heq_anonymous in H; simpl in H. - rewrite F_square_opp; trivial. - Qed. - - Definition sub P Q := (P + opp Q)%E. - - Lemma opp_zero : opp E.zero = E.zero. - Proof. - pose proof @F_opp_0. - unfold opp, E.zero; eapply point_eq; congruence. - Qed. - - Lemma add_opp_r : forall P, (P + opp P = E.zero)%E. - Proof. - unfold opp; Edefn; rewrite ?@F_pow_2_r in *; (F_field_simplify_eq; [clear_prm; F_nsatz|..]); - rewrite <-?@F_pow_2_r in *; - pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP); - pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d _ _ _ _ pfP pfP); - field_nonzero idtac. - Qed. - - Lemma add_opp_l : forall P, (opp P + P = E.zero)%E. - Proof. - intros. rewrite add_comm. eapply add_opp_r. - Qed. - - Lemma add_cancel_r : forall A B C, (B+A = C+A -> B = C)%E. - Proof. - intros. - assert ((B + A) + opp A = (C + A) + opp A)%E as Hc by congruence. - rewrite <-!add_assoc, !add_opp_r, !add_0_r in Hc; exact Hc. - Qed. - - Lemma add_cancel_l : forall A B C, (A+B = A+C -> B = C)%E. - Proof. - intros. - rewrite (add_comm A C) in H. - rewrite (add_comm A B) in H. - eauto using add_cancel_r. - Qed. - - Lemma shuffle_eq_add_opp : forall P Q R, (P + Q = R <-> Q = opp P + R)%E. - Proof. - split; intros. - { assert (opp P + (P + Q) = opp P + R)%E as Hc by congruence. - rewrite add_assoc, add_opp_l, add_comm, add_0_r in Hc; exact Hc. } - { subst. rewrite add_assoc, add_opp_r, add_comm, add_0_r; reflexivity. } - Qed. - - Lemma opp_opp : forall P, opp (opp P) = P. - Proof. - intros. - pose proof (add_opp_r P%E) as H. - rewrite add_comm in H. - rewrite shuffle_eq_add_opp in H. - rewrite add_0_r in H. - congruence. - Qed. - - Lemma opp_add : forall P Q, opp (P + Q)%E = (opp P + opp Q)%E. + Ltac bash_step := + match goal with + | |- _ => progress intros + | [H: _ /\ _ |- _ ] => destruct H + | |- _ => progress destruct_points + | |- _ => progress cbv [fst snd coordinates proj1_sig eq fieldwise fieldwise' add zero opp] in * + | |- _ => split + | |- Feq _ _ => field_algebra + | |- _ <> 0 => expand_opp; solve [nsatz_nonzero|eauto 6] + | |- {_}+{_} => eauto 15 using decide_and, @eq_dec with typeclass_instances + end. + + Ltac bash := repeat bash_step. + + Global Instance Proper_add : Proper (eq==>eq==>eq) add. Proof. bash. Qed. + Global Instance Proper_opp : Proper (eq==>eq) opp. Proof. bash. Qed. + Global Instance Proper_coordinates : Proper (eq==>fieldwise (n:=2) Feq) coordinates. Proof. bash. Qed. + + Global Instance edwards_acurve_abelian_group : abelian_group (eq:=eq)(op:=add)(id:=zero)(inv:=opp). + Proof. + bash. + (* TODO: port denominator-nonzero proofs for associativity *) + match goal with | |- _ <> 0 => admit end. + match goal with | |- _ <> 0 => admit end. + match goal with | |- _ <> 0 => admit end. + match goal with | |- _ <> 0 => admit end. + Admitted. + + (* TODO: move to [Group] and [AbelianGroup] as appropriate *) + Lemma mul_0_l : forall P, (0 * P = zero)%E. + Proof. intros; reflexivity. Qed. + Lemma mul_S_l : forall n P, (S n * P = P + n * P)%E. + Proof. intros; reflexivity. Qed. + Lemma mul_add_l : forall (n m:nat) (P:point), ((n + m)%nat * P = n * P + m * P)%E. Proof. - intros. - pose proof (add_opp_r (P+Q)%E) as H. - rewrite <-!add_assoc in H. - rewrite add_comm in H. - rewrite <-!add_assoc in H. - rewrite shuffle_eq_add_opp in H. - rewrite add_comm in H. - rewrite shuffle_eq_add_opp in H. - rewrite add_0_r in H. - assumption. + induction n; intros; + rewrite ?plus_Sn_m, ?plus_O_n, ?mul_S_l, ?left_identity, <-?associative, <-?IHn; reflexivity. Qed. - - Lemma opp_mul : forall n P, opp (E.mul n P) = E.mul n (opp P). + Lemma mul_assoc : forall (n m : nat) P, (n * (m * P) = (n * m)%nat * P)%E. Proof. - pose proof opp_add; pose proof opp_zero. - induction n; simpl; intros; congruence. + induction n; intros; [reflexivity|]. + rewrite ?mul_S_l, ?Mult.mult_succ_l, ?mul_add_l, ?IHn, commutative; reflexivity. Qed. + Lemma mul_zero_r : forall m, (m * E.zero = E.zero)%E. + Proof. induction m; rewrite ?mul_S_l, ?left_identity, ?IHm; try reflexivity. Qed. + Lemma opp_mul : forall n P, (opp (n * P) = n * (opp P))%E. + Admitted. + + Section PointCompression. + Local Notation "x ^2" := (x*x). + Definition solve_for_x2 (y : F) := ((y^2 - 1) / (d * (y^2) - a)). + + Lemma a_d_y2_nonzero : forall y, d * y^2 - a <> 0. + Proof. + intros ? eq_zero. + destruct square_a as [sqrt_a sqrt_a_id]; rewrite <- sqrt_a_id in eq_zero. + destruct (eq_dec y 0); [apply nonzero_a|apply nonsquare_d with (sqrt_a/y)]; field_algebra. + Qed. + + Lemma solve_correct : forall x y, onCurve (x, y) <-> (x^2 = solve_for_x2 y). + Proof. + unfold solve_for_x2; simpl; split; intros; field_algebra; auto using a_d_y2_nonzero. + Qed. + End PointCompression. End CompleteEdwardsCurveTheorems. + + Section Homomorphism. + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv Fa Fd} + {fieldF:@field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} + {Fprm:@twisted_edwards_params F Feq Fzero Fone Fadd Fmul Fa Fd}. + Context {K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv Ka Kd} + {fieldK:@field K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv} + {Kprm:@twisted_edwards_params K Keq Kzero Kone Kadd Kmul Ka Kd}. + Context {phi:F->K} {Hphi:@Ring.is_homomorphism F Feq Fone Fadd Fmul + K Keq Kone Kadd Kmul phi}. + Context {Ha:Keq (phi Fa) Ka} {Hd:Keq (phi Fd) Kd}. + Local Notation Fpoint := (@point F Feq Fone Fadd Fmul Fa Fd). + Local Notation Kpoint := (@point K Keq Kone Kadd Kmul Ka Kd). + + Create HintDb field_homomorphism discriminated. + Hint Rewrite <- + homomorphism_one + homomorphism_add + homomorphism_sub + homomorphism_mul + homomorphism_div + Ha + Hd + : field_homomorphism. + + Program Definition ref_phi (P:Fpoint) : Kpoint := exist _ ( + let (x, y) := coordinates P in (phi x, phi y)) _. + Next Obligation. + destruct P as [[? ?] ?]; simpl. + rewrite_strat bottomup hints field_homomorphism. + eauto using is_homomorphism_phi_proper; assumption. + Qed. + + Context {point_phi:Fpoint->Kpoint} + {point_phi_Proper:Proper (eq==>eq) point_phi} + {point_phi_correct: forall (P:Fpoint), eq (point_phi P) (ref_phi P)}. + + Lemma lift_homomorphism : @Group.is_homomorphism Fpoint eq add Kpoint eq add point_phi. + Proof. + repeat match goal with + | |- Group.is_homomorphism => split + | |- _ => intro + | |- _ /\ _ => split + | [H: _ /\ _ |- _ ] => destruct H + | [p: point |- _ ] => destruct p as [[??]?] + | |- context[point_phi] => setoid_rewrite point_phi_correct + | |- _ => progress cbv [fst snd coordinates proj1_sig eq fieldwise fieldwise' add zero opp ref_phi] in * + | |- Keq ?x ?x => reflexivity + | |- Keq ?x ?y => rewrite_strat bottomup hints field_homomorphism + | [ H : Feq _ _ |- Keq (phi _) (phi _)] => solve [f_equiv; intuition] + end. + Qed. + End Homomorphism. End E. -Infix "-" := E.sub : E_scope.
\ No newline at end of file diff --git a/src/CompleteEdwardsCurve/DoubleAndAdd.v b/src/CompleteEdwardsCurve/DoubleAndAdd.v deleted file mode 100644 index 50027349d..000000000 --- a/src/CompleteEdwardsCurve/DoubleAndAdd.v +++ /dev/null @@ -1,30 +0,0 @@ -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Util.IterAssocOp. -Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Coq.Numbers.BinNums Coq.NArith.NArith Coq.NArith.Nnat Coq.ZArith.ZArith. - -Section EdwardsDoubleAndAdd. - Context {prm:TwistedEdwardsParams}. - Definition doubleAndAdd (bound n : nat) (P : E.point) : E.point := - iter_op E.add E.zero N.testbit_nat (N.of_nat n) P bound. - - Lemma scalarMult_double : forall n P, E.mul (n + n) P = E.mul n (P + P)%E. - Proof. - intros. - replace (n + n)%nat with (n * 2)%nat by omega. - induction n; simpl; auto. - rewrite E.add_assoc. - f_equal; auto. - Qed. - - Lemma doubleAndAdd_spec : forall bound n P, N.size_nat (N.of_nat n) <= bound -> - E.mul n P = doubleAndAdd bound n P. - Proof. - induction n; auto; intros; unfold doubleAndAdd; - rewrite iter_op_spec with (scToN := fun x => x); ( - unfold Morphisms.Proper, Morphisms.respectful, Equivalence.equiv; - intros; subst; try rewrite Nat2N.id; - reflexivity || assumption || apply E.add_assoc - || rewrite E.add_comm; apply E.add_0_r). - Qed. -End EdwardsDoubleAndAdd.
\ No newline at end of file diff --git a/src/CompleteEdwardsCurve/ExtendedCoordinates.v b/src/CompleteEdwardsCurve/ExtendedCoordinates.v index e91bc084b..25af83a0a 100644 --- a/src/CompleteEdwardsCurve/ExtendedCoordinates.v +++ b/src/CompleteEdwardsCurve/ExtendedCoordinates.v @@ -1,194 +1,154 @@ -Require Import Crypto.CompleteEdwardsCurve.Pre. -Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.FField. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Util.IterAssocOp BinNat NArith. -Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence. -Local Open Scope equiv_scope. -Local Open Scope F_scope. - -Section ExtendedCoordinates. - Context {prm:TwistedEdwardsParams}. - Local Opaque q a d prime_q two_lt_q nonzero_a square_a nonsquare_d. (* [F_field] calls [compute] *) - Existing Instance prime_q. - - Add Field Ffield_p' : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). - - Add Field Ffield_notConstant : (OpaqueFieldTheory q) - (constants [notConstant]). - - (** [extended] represents a point on an elliptic curve using extended projective - * Edwards coordinates with twist a=-1 (see <https://eprint.iacr.org/2008/522.pdf>). *) - Record extended := mkExtended {extendedX : F q; - extendedY : F q; - extendedZ : F q; - extendedT : F q}. - Local Notation "'(' X ',' Y ',' Z ',' T ')'" := (mkExtended X Y Z T). - - Definition twistedToExtended (P : (F q*F q)) : extended := - let '(x, y) := P in (x, y, 1, x*y). - Definition extendedToTwisted (P : extended) : F q * F q := - let '(X, Y, Z, T) := P in ((X/Z), (Y/Z)). - Definition rep (P:extended) (rP:(F q*F q)) : Prop := - let '(X, Y, Z, T) := P in - extendedToTwisted P = rP /\ - Z <> 0 /\ - T = X*Y/Z. - Local Hint Unfold twistedToExtended extendedToTwisted rep. - Local Notation "P '~=' rP" := (rep P rP) (at level 70). - - Ltac unfoldExtended := - repeat progress (autounfold; unfold E.onCurve, E.add, E.add', rep in *; intros); - repeat match goal with - | [ p : (F q*F q)%type |- _ ] => - let x := fresh "x" p in - let y := fresh "y" p in - destruct p as [x y] - | [ p : extended |- _ ] => - let X := fresh "X" p in - let Y := fresh "Y" p in - let Z := fresh "Z" p in - let T := fresh "T" p in - destruct p as [X Y Z T] - | [ H: _ /\ _ |- _ ] => destruct H - | [ H: @eq (F q * F q)%type _ _ |- _ ] => invcs H - | [ H: @eq F q ?x _ |- _ ] => isVar x; rewrite H; clear H - end. - - Ltac solveExtended := unfoldExtended; - repeat match goal with - | [ |- _ /\ _ ] => split - | [ |- @eq (F q * F q)%type _ _] => apply f_equal2 - | _ => progress rewrite ?@F_add_0_r, ?@F_add_0_l, ?@F_sub_0_l, ?@F_sub_0_r, - ?@F_mul_0_r, ?@F_mul_0_l, ?@F_mul_1_l, ?@F_mul_1_r, ?@F_div_1_r - | _ => solve [eapply @Fq_1_neq_0; eauto with typeclass_instances] - | _ => solve [eauto with typeclass_instances] - | [ H: a = _ |- _ ] => rewrite H - end. - - Lemma twistedToExtended_rep : forall P, twistedToExtended P ~= P. - Proof. - solveExtended. - Qed. - - Lemma extendedToTwisted_rep : forall P rP, P ~= rP -> extendedToTwisted P = rP. - Proof. - solveExtended. - Qed. - - Definition extendedPoint := { P:extended | rep P (extendedToTwisted P) /\ E.onCurve (extendedToTwisted P) }. - - Program Definition mkExtendedPoint : E.point -> extendedPoint := twistedToExtended. - Next Obligation. - destruct x; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep. - Qed. - - Program Definition unExtendedPoint : extendedPoint -> E.point := extendedToTwisted. - Next Obligation. - destruct x; simpl; intuition. - Qed. - - Definition extendedPoint_eq P Q := unExtendedPoint P = unExtendedPoint Q. - Global Instance Equivalence_extendedPoint_eq : Equivalence extendedPoint_eq. - Proof. - repeat (econstructor || intro); unfold extendedPoint_eq in *; congruence. - Qed. - - Lemma unExtendedPoint_mkExtendedPoint : forall P, unExtendedPoint (mkExtendedPoint P) = P. - Proof. - destruct P; eapply E.point_eq; simpl; erewrite extendedToTwisted_rep; eauto using twistedToExtended_rep. - Qed. - - Global Instance Proper_mkExtendedPoint : Proper (eq==>equiv) mkExtendedPoint. - Proof. - repeat (econstructor || intro); unfold extendedPoint_eq in *; congruence. - Qed. - - Global Instance Proper_unExtendedPoint : Proper (equiv==>eq) unExtendedPoint. - Proof. - repeat (econstructor || intro); unfold extendedPoint_eq in *; congruence. - Qed. - - Definition twice_d := d + d. - - Section TwistMinus1. - Context (a_eq_minus1 : a = opp 1). - (** Second equation from <http://eprint.iacr.org/2008/522.pdf> section 3.1, also <https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#addition-add-2008-hwcd-3> and <https://tools.ietf.org/html/draft-josefsson-eddsa-ed25519-03> *) - Definition unifiedAddM1' (P1 P2 : extended) : extended := - let '(X1, Y1, Z1, T1) := P1 in - let '(X2, Y2, Z2, T2) := P2 in - let A := (Y1-X1)*(Y2-X2) in - let B := (Y1+X1)*(Y2+X2) in - let C := T1*twice_d*T2 in - let D := Z1*(Z2+Z2) in - let E := B-A in - let F := D-C in - let G := D+C in - let H := B+A in - let X3 := E*F in - let Y3 := G*H in - let T3 := E*H in - let Z3 := F*G in - (X3, Y3, Z3, T3). - Local Hint Unfold E.add. - - Local Ltac tnz := repeat apply Fq_mul_nonzero_nonzero; auto using (@char_gt_2 q two_lt_q). - - Lemma F_mul_2_l : forall x : F q, ZToField 2 * x = x + x. - intros. ring. - Qed. - - Lemma unifiedAddM1'_rep: forall P Q rP rQ, E.onCurve rP -> E.onCurve rQ -> - P ~= rP -> Q ~= rQ -> (unifiedAddM1' P Q) ~= (E.add' rP rQ). - Proof. - intros P Q rP rQ HoP HoQ HrP HrQ. - pose proof (@edwardsAddCompletePlus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d). - pose proof (@edwardsAddCompleteMinus _ _ _ _ two_lt_q nonzero_a square_a nonsquare_d). - unfoldExtended; unfold twice_d; rewrite a_eq_minus1 in *; simpl in *. repeat rewrite <-F_mul_2_l. - repeat split; repeat apply (f_equal2 pair); try F_field; repeat split; auto; - repeat rewrite ?F_add_0_r, ?F_add_0_l, ?F_sub_0_l, ?F_sub_0_r, - ?F_mul_0_r, ?F_mul_0_l, ?F_mul_1_l, ?F_mul_1_r, ?F_div_1_r; - field_nonzero tnz. - Qed. - - Lemma unifiedAdd'_onCurve : forall P Q, E.onCurve P -> E.onCurve Q -> E.onCurve (E.add' P Q). - Proof. - intros; pose proof (proj2_sig (E.add (exist _ _ H) (exist _ _ H0))); eauto. - Qed. - - Program Definition unifiedAddM1 : extendedPoint -> extendedPoint -> extendedPoint := unifiedAddM1'. - Next Obligation. - destruct x, x0; simpl; intuition. - - erewrite extendedToTwisted_rep; eauto using unifiedAddM1'_rep. - - erewrite extendedToTwisted_rep. - (* It would be nice if I could use eauto here, but it gets evars wrong :( *) - 2: eapply unifiedAddM1'_rep. 5:apply H1. 4:apply H. 3:auto. 2:auto. - eauto using unifiedAdd'_onCurve. - Qed. - - Lemma unifiedAddM1_rep : forall P Q, E.add (unExtendedPoint P) (unExtendedPoint Q) = unExtendedPoint (unifiedAddM1 P Q). - Proof. - destruct P, Q; unfold unExtendedPoint, E.add, unifiedAddM1; eapply E.point_eq; simpl in *; intuition. - pose proof (unifiedAddM1'_rep x x0 (extendedToTwisted x) (extendedToTwisted x0)); - destruct (unifiedAddM1' x x0); - unfold rep in *; intuition. - Qed. - - Global Instance Proper_unifiedAddM1 : Proper (equiv==>equiv==>equiv) unifiedAddM1. - Proof. - repeat (econstructor || intro). - repeat match goal with [H: _ === _ |- _ ] => inversion H; clear H end; unfold equiv, extendedPoint_eq. - rewrite <-!unifiedAddM1_rep. - destruct x, y, x0, y0; simpl in *; eapply E.point_eq; congruence. - Qed. +Require Export Crypto.Spec.CompleteEdwardsCurve. +Require Import Crypto.Algebra Crypto.Tactics.Nsatz. +Require Import Crypto.CompleteEdwardsCurve.Pre Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Coq.Logic.Eqdep_dec. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Coq.Classes.Morphisms. +Require Import Relation_Definitions. +Require Import Crypto.Util.Tuple. + +Module Extended. + Section ExtendedCoordinates. + Import Group Ring Field. + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a d} + {field:@field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} + {prm:@E.twisted_edwards_params F Feq Fzero Fone Fadd Fmul a d}. + Local Infix "=" := Feq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := Fzero. Local Notation "1" := Fone. + Local Infix "+" := Fadd. Local Infix "*" := Fmul. + Local Infix "-" := Fsub. Local Infix "/" := Fdiv. + Local Notation "x ^2" := (x*x) (at level 30). + Local Notation Epoint := (@E.point F Feq Fone Fadd Fmul a d). + Local Notation onCurve := (@Pre.onCurve F Feq Fone Fadd Fmul a d). + + Add Field _edwards_curve_extended_field : (field_theory_for_stdlib_tactic (H:=field)). + + (** [Extended.point] represents a point on an elliptic curve using extended projective + * Edwards coordinates with twist a=-1 (see <https://eprint.iacr.org/2008/522.pdf>). *) + Definition point := { P | let '(X,Y,Z,T) := P in onCurve((X/Z), (Y/Z)) /\ Z<>0 /\ Z*T=X*Y }. + Definition coordinates (P:point) : F*F*F*F := proj1_sig P. + + Create HintDb bash discriminated. + Local Hint Unfold E.eq fst snd fieldwise fieldwise' coordinates E.coordinates proj1_sig Pre.onCurve : bash. + Ltac bash := + repeat match goal with + | |- Proper _ _ => intro + | _ => progress intros + | [ H: _ /\ _ |- _ ] => destruct H + | [ p:E.point |- _ ] => destruct p as [[??]?] + | [ p:point |- _ ] => destruct p as [[[[??]?]?]?] + | _ => progress autounfold with bash in * + | |- _ /\ _ => split + | _ => solve [neq01] + | _ => solve [eauto] + | _ => solve [intuition] + | _ => solve [etransitivity; eauto] + | |- Feq _ _ => field_algebra + | |- _ <> 0 => apply mul_nonzero_nonzero + | [ H : _ <> 0 |- _ <> 0 ] => + intro; apply H; + field_algebra; + solve [ apply Ring.opp_nonzero_nonzero, E.char_gt_2 + | apply E.char_gt_2] + end. + + Obligation Tactic := bash. + + Program Definition from_twisted (P:Epoint) : point := exist _ + (let (x,y) := E.coordinates P in (x, y, 1, x*y)) _. + + Program Definition to_twisted (P:point) : Epoint := exist _ + (let '(X,Y,Z,T) := coordinates P in ((X/Z), (Y/Z))) _. + + Definition eq (P Q:point) := E.eq (to_twisted P) (to_twisted Q). + + Local Hint Unfold from_twisted to_twisted eq : bash. + + Global Instance Equivalence_eq : Equivalence eq. Proof. split; split; bash. Qed. + Global Instance Proper_from_twisted : Proper (E.eq==>eq) from_twisted. Proof. bash. Qed. + Global Instance Proper_to_twisted : Proper (eq==>E.eq) to_twisted. Proof. bash. Qed. + Lemma to_twisted_from_twisted P : E.eq (to_twisted (from_twisted P)) P. Proof. bash. Qed. + + Section TwistMinus1. + Context {a_eq_minus1 : a = Fopp 1}. + Context {twice_d:F} {Htwice_d:twice_d = d + d}. + (** Second equation from <http://eprint.iacr.org/2008/522.pdf> section 3.1, also <https://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html#addition-add-2008-hwcd-3> and <https://tools.ietf.org/html/draft-josefsson-eddsa-ed25519-03> *) + Definition add_coordinates P1 P2 : F*F*F*F := + let '(X1, Y1, Z1, T1) := P1 in + let '(X2, Y2, Z2, T2) := P2 in + let A := (Y1-X1)*(Y2-X2) in + let B := (Y1+X1)*(Y2+X2) in + let C := T1*twice_d*T2 in + let D := Z1*(Z2+Z2) in + let E := B-A in + let F := D-C in + let G := D+C in + let H := B+A in + let X3 := E*F in + let Y3 := G*H in + let T3 := E*H in + let Z3 := F*G in + (X3, Y3, Z3, T3). + + Local Hint Unfold E.add E.coordinates add_coordinates : bash. + + Lemma add_coordinates_correct (P Q:point) : + let '(X,Y,Z,T) := add_coordinates (coordinates P) (coordinates Q) in + let (x, y) := E.coordinates (E.add (to_twisted P) (to_twisted Q)) in + (fieldwise (n:=2) Feq) (x, y) (X/Z, Y/Z). + Proof. + destruct P as [[[[]?]?][HP []]]; destruct Q as [[[[]?]?][HQ []]]. + pose proof edwardsAddCompletePlus (a_nonzero:=E.nonzero_a)(a_square:=E.square_a)(d_nonsquare:=E.nonsquare_d)(char_gt_2:=E.char_gt_2) _ _ _ _ HP HQ. + pose proof edwardsAddCompleteMinus (a_nonzero:=E.nonzero_a)(a_square:=E.square_a)(d_nonsquare:=E.nonsquare_d)(char_gt_2:=E.char_gt_2) _ _ _ _ HP HQ. + bash. + Qed. + + Obligation Tactic := idtac. + Program Definition add (P Q:point) : point := add_coordinates (coordinates P) (coordinates Q). + Next Obligation. + intros. + pose proof (add_coordinates_correct P Q) as Hrep. + pose proof Pre.unifiedAdd'_onCurve(a_nonzero:=E.nonzero_a)(a_square:=E.square_a)(d_nonsquare:=E.nonsquare_d)(char_gt_2:=E.char_gt_2) (E.coordinates (to_twisted P)) (E.coordinates (to_twisted Q)) as Hon. + destruct P as [[[[]?]?][HP []]]; destruct Q as [[[[]?]?][HQ []]]. + pose proof edwardsAddCompletePlus (a_nonzero:=E.nonzero_a)(a_square:=E.square_a)(d_nonsquare:=E.nonsquare_d)(char_gt_2:=E.char_gt_2) _ _ _ _ HP HQ as Hnz1. + pose proof edwardsAddCompleteMinus (a_nonzero:=E.nonzero_a)(a_square:=E.square_a)(d_nonsquare:=E.nonsquare_d)(char_gt_2:=E.char_gt_2) _ _ _ _ HP HQ as Hnz2. + autounfold with bash in *; simpl in *. + destruct Hrep as [HA HB]. rewrite <-!HA, <-!HB; clear HA HB. + bash. + Qed. + Local Hint Unfold add : bash. + + Lemma to_twisted_add P Q : E.eq (to_twisted (add P Q)) (E.add (to_twisted P) (to_twisted Q)). + Proof. + pose proof (add_coordinates_correct P Q) as Hrep. + destruct P as [[[[]?]?][HP []]]; destruct Q as [[[[]?]?][HQ []]]. + autounfold with bash in *; simpl in *. + destruct Hrep as [HA HB]. rewrite <-!HA, <-!HB; clear HA HB. + split; reflexivity. + Qed. + + Global Instance Proper_add : Proper (eq==>eq==>eq) add. + Proof. + unfold eq. intros x y H x0 y0 H0. + transitivity (to_twisted x + to_twisted x0)%E; rewrite to_twisted_add, ?H, ?H0; reflexivity. + Qed. + + Lemma homomorphism_to_twisted : @Group.is_homomorphism point eq add Epoint E.eq E.add to_twisted. + Proof. split; trivial using Proper_to_twisted, to_twisted_add. Qed. + + Lemma add_from_twisted P Q : eq (from_twisted (P + Q)%E) (add (from_twisted P) (from_twisted Q)). + Proof. + pose proof (to_twisted_add (from_twisted P) (from_twisted Q)). + unfold eq; rewrite !to_twisted_from_twisted in *. + symmetry; assumption. + Qed. + + Lemma homomorphism_from_twisted : @Group.is_homomorphism Epoint E.eq E.add point eq add from_twisted. + Proof. split; trivial using Proper_from_twisted, add_from_twisted. Qed. + + (* TODO: decide whether we still need those, then port *) + (* Lemma unifiedAddM1_0_r : forall P, unifiedAddM1 P (mkExtendedPoint E.zero) === P. unfold equiv, extendedPoint_eq; intros. rewrite <-!unifiedAddM1_rep, unExtendedPoint_mkExtendedPoint, E.add_0_r; auto. @@ -210,30 +170,75 @@ Section ExtendedCoordinates. trivial. Qed. - Definition scalarMultM1 := iter_op unifiedAddM1 (mkExtendedPoint E.zero) N.testbit_nat. - Definition scalarMultM1_spec := - iter_op_spec unifiedAddM1 unifiedAddM1_assoc (mkExtendedPoint E.zero) unifiedAddM1_0_l - N.testbit_nat (fun x => x) testbit_conversion_identity. - Lemma scalarMultM1_rep : forall n P, unExtendedPoint (scalarMultM1 (N.of_nat n) P (N.size_nat (N.of_nat n))) = E.mul n (unExtendedPoint P). - intros; rewrite scalarMultM1_spec, Nat2N.id; auto. - induction n; [simpl; rewrite !unExtendedPoint_mkExtendedPoint; reflexivity|]. + Lemma scalarMultM1_rep : forall n P, unExtendedPoint (nat_iter_op unifiedAddM1 (mkExtendedPoint E.zero) n P) = E.mul n (unExtendedPoint P). + induction n; [simpl; rewrite !unExtendedPoint_mkExtendedPoint; reflexivity|]; intros. unfold E.mul; fold E.mul. rewrite <-IHn, unifiedAddM1_rep; auto. Qed. + *) + End TwistMinus1. + End ExtendedCoordinates. + + Section Homomorphism. + Import Group Ring Field. + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv Fa Fd} + {fieldF:@field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} + {Fprm:@E.twisted_edwards_params F Feq Fzero Fone Fadd Fmul Fa Fd}. + Context {K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv Ka Kd} + {fieldK:@field K Keq Kzero Kone Kopp Kadd Ksub Kmul Kinv Kdiv} + {Kprm:@E.twisted_edwards_params K Keq Kzero Kone Kadd Kmul Ka Kd}. + Context {phi:F->K} {Hphi:@Ring.is_homomorphism F Feq Fone Fadd Fmul + K Keq Kone Kadd Kmul phi}. + Context {phi_nonzero : forall x, ~ Feq x Fzero -> ~ Keq (phi x) Kzero}. + Context {HFa: Feq Fa (Fopp Fone)} {HKa:Keq Ka (Kopp Kone)}. + Context {Hd:Keq (phi Fd) Kd} {Kdd Fdd} {HKdd:Keq Kdd (Kadd Kd Kd)} {HFdd:Feq Fdd (Fadd Fd Fd)}. + Local Notation Fpoint := (@point F Feq Fzero Fone Fadd Fmul Fdiv Fa Fd). + Local Notation Kpoint := (@point K Keq Kzero Kone Kadd Kmul Kdiv Ka Kd). + + Lemma Ha : Keq (phi Fa) Ka. + Proof. rewrite HFa, HKa, <-homomorphism_one. eapply homomorphism_opp. Qed. + + Lemma Hdd : Keq (phi Fdd) Kdd. + Proof. rewrite HFdd, HKdd. rewrite homomorphism_add. repeat f_equiv; auto. Qed. + + Create HintDb field_homomorphism discriminated. + Hint Rewrite <- + homomorphism_one + homomorphism_add + homomorphism_sub + homomorphism_mul + homomorphism_div + Ha + Hd + Hdd + : field_homomorphism. + + Program Definition ref_phi (P:Fpoint) : Kpoint := exist _ ( + let '(X, Y, Z, T) := coordinates P in (phi X, phi Y, phi Z, phi T)) _. + Next Obligation. + destruct P as [[[[] ?] ?] [? [? ?]]]; unfold onCurve in *; simpl. + rewrite_strat bottomup hints field_homomorphism. + eauto 10 using is_homomorphism_phi_proper, phi_nonzero. + Qed. + + Context {point_phi:Fpoint->Kpoint} + {point_phi_Proper:Proper (eq==>eq) point_phi} + {point_phi_correct: forall (P:Fpoint), eq (point_phi P) (ref_phi P)}. - End TwistMinus1. - - Definition negateExtended' P := let '(X, Y, Z, T) := P in (opp X, Y, Z, opp T). - Program Definition negateExtended (P:extendedPoint) : extendedPoint := negateExtended' (proj1_sig P). - Next Obligation. - Proof. - unfold negateExtended', rep; destruct P as [[X Y Z T] H]; simpl. destruct H as [[[] []] ?]; subst. - repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; repeat split; trivial. - Qed. - - Lemma negateExtended_correct : forall P, E.opp (unExtendedPoint P) = unExtendedPoint (negateExtended P). - Proof. - unfold E.opp, unExtendedPoint, negateExtended; destruct P as [[]]; simpl; intros. - eapply E.point_eq; repeat rewrite ?F_div_opp_1, ?F_mul_opp_l, ?F_square_opp; trivial. - Qed. -End ExtendedCoordinates. + Lemma lift_homomorphism : @Group.is_homomorphism Fpoint eq (add(a_eq_minus1:=HFa)(Htwice_d:=HFdd)) Kpoint eq (add(a_eq_minus1:=HKa)(Htwice_d:=HKdd)) point_phi. + Proof. + repeat match goal with + | |- Group.is_homomorphism => split + | |- _ => intro + | |- _ /\ _ => split + | [H: _ /\ _ |- _ ] => destruct H + | [p: point |- _ ] => destruct p as [[[[] ?] ?] [? [? ?]]] + | |- context[point_phi] => setoid_rewrite point_phi_correct + | |- _ => progress cbv [fst snd coordinates proj1_sig eq to_twisted E.eq E.coordinates fieldwise fieldwise' add add_coordinates ref_phi] in * + | |- Keq ?x ?x => reflexivity + | |- Keq ?x ?y => rewrite_strat bottomup hints field_homomorphism + | [ H : Feq _ _ |- Keq (phi _) (phi _)] => solve [f_equiv; intuition] + end. + Qed. + End Homomorphism. +End Extended. diff --git a/src/CompleteEdwardsCurve/Pre.v b/src/CompleteEdwardsCurve/Pre.v index fea4a99b3..5314ee37c 100644 --- a/src/CompleteEdwardsCurve/Pre.v +++ b/src/CompleteEdwardsCurve/Pre.v @@ -1,186 +1,100 @@ -Require Import Coq.ZArith.BinInt Coq.ZArith.Znumtheory Crypto.Tactics.VerdiTactics. +Require Import Coq.Classes.Morphisms. Require Coq.Setoids.Setoid. +Require Import Crypto.Algebra Crypto.Tactics.Nsatz. -Require Import Crypto.Spec.ModularArithmetic. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Local Open Scope F_scope. - +Generalizable All Variables. Section Pre. - Context {q : BinInt.Z}. - Context {a : F q}. - Context {d : F q}. - Context {prime_q : Znumtheory.prime q}. - Context {two_lt_q : 2 < q}. - Context {a_nonzero : a <> 0}. - Context {a_square : exists sqrt_a, sqrt_a^2 = a}. - Context {d_nonsquare : forall x, x^2 <> d}. - - Add Field Ffield_Z : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). - + Context {F eq zero one opp add sub mul inv div} `{field F eq zero one opp add sub mul inv div}. + Local Infix "=" := eq. Local Notation "a <> b" := (not (a = b)). + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := zero. Local Notation "1" := one. + Local Infix "+" := add. Local Infix "*" := mul. + Local Infix "-" := sub. Local Infix "/" := div. + Local Notation "x '^' 2" := (x*x) (at level 30). + + Add Field EdwardsCurveField : (Field.field_theory_for_stdlib_tactic (T:=F)). + + Context {a:F} {a_nonzero : a<>0} {a_square : exists sqrt_a, sqrt_a^2 = a}. + Context {d:F} {d_nonsquare : forall sqrt_d, sqrt_d^2 <> d}. + Context {char_gt_2 : 1+1 <> 0}. + (* the canonical definitions are in Spec *) - Local Notation onCurve P := (let '(x, y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2). - Local Notation unifiedAdd' P1' P2' := ( - let '(x1, y1) := P1' in - let '(x2, y2) := P2' in - (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))) - ). - - Lemma char_gt_2 : ZToField 2 <> (0: F q). - intro; find_injection. - pose proof two_lt_q. - rewrite (Z.mod_small 2 q), Z.mod_0_l in *; omega. - Qed. + Definition onCurve (P:F*F) := let (x, y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2. + Definition unifiedAdd' (P1' P2':F*F) : F*F := + let (x1, y1) := P1' in + let (x2, y2) := P2' in + pair (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2))) (((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))). - Ltac rewriteAny := match goal with [H: _ = _ |- _ ] => rewrite H end. - Ltac rewriteLeftAny := match goal with [H: _ = _ |- _ ] => rewrite <- H end. - - Ltac whatsNotZero := - repeat match goal with - | [H: ?lhs = ?rhs |- _ ] => - match goal with [Ha: lhs <> 0 |- _ ] => fail 1 | _ => idtac end; - assert (lhs <> 0) by (rewrite H; auto using Fq_1_neq_0) - | [H: ?lhs = ?rhs |- _ ] => - match goal with [Ha: rhs <> 0 |- _ ] => fail 1 | _ => idtac end; - assert (rhs <> 0) by (rewrite H; auto using Fq_1_neq_0) - | [H: (?a^?p)%F <> 0 |- _ ] => - match goal with [Ha: a <> 0 |- _ ] => fail 1 | _ => idtac end; - let Y:=fresh in let Z:=fresh in try ( - assert (p <> 0%N) as Z by (intro Y; inversion Y); - assert (a <> 0) by (eapply Fq_root_nonzero; eauto using Fq_1_neq_0); - clear Z) - | [H: (?a*?b)%F <> 0 |- _ ] => - match goal with [Ha: a <> 0 |- _ ] => fail 1 | _ => idtac end; - assert (a <> 0) by (eapply F_mul_nonzero_l; eauto using Fq_1_neq_0) - | [H: (?a*?b)%F <> 0 |- _ ] => - match goal with [Ha: b <> 0 |- _ ] => fail 1 | _ => idtac end; - assert (b <> 0) by (eapply F_mul_nonzero_r; eauto using Fq_1_neq_0) - end. + Ltac use_sqrt_a := destruct a_square as [sqrt_a a_square']; rewrite <-a_square' in *. Lemma edwardsAddComplete' x1 y1 x2 y2 : - onCurve (x1, y1) -> - onCurve (x2, y2) -> + onCurve (pair x1 y1) -> + onCurve (pair x2 y2) -> (d*x1*x2*y1*y2)^2 <> 1. Proof. - intros Hc1 Hc2 Hcontra; simpl in Hc1, Hc2; whatsNotZero. - - pose proof char_gt_2. pose proof a_nonzero as Ha_nonzero. - destruct a_square as [sqrt_a a_square']. - rewrite <-a_square' in *. - - (* Furthermore... *) - pose proof (eq_refl (d*x1^2*y1^2*(sqrt_a^2*x2^2 + y2^2))) as Heqt. - rewrite Hc2 in Heqt at 2. - replace (d * x1 ^ 2 * y1 ^ 2 * (1 + d * x2 ^ 2 * y2 ^ 2)) - with (d*x1^2*y1^2 + (d*x1*x2*y1*y2)^2) in Heqt by field. - rewrite Hcontra in Heqt. - replace (d * x1 ^ 2 * y1 ^ 2 + 1) with (1 + d * x1 ^ 2 * y1 ^ 2) in Heqt by field. - rewrite <-Hc1 in Heqt. - - (* main equation for both potentially nonzero denominators *) - destruct (F_eq_dec (sqrt_a*x2 + y2) 0); destruct (F_eq_dec (sqrt_a*x2 - y2) 0); - try lazymatch goal with [H: ?f (sqrt_a * x2) y2 <> 0 |- _ ] => - assert ((f (sqrt_a*x1) (d * x1 * x2 * y1 * y2*y1))^2 = - f ((sqrt_a^2)*x1^2 + (d * x1 * x2 * y1 * y2)^2*y1^2) - (d * x1 * x2 * y1 * y2*sqrt_a*(ZToField 2)*x1*y1)) as Heqw1 by field; - rewrite Hcontra in Heqw1; - replace (1 * y1^2) with (y1^2) in * by field; - rewrite <- Heqt in *; - assert (d = (f (sqrt_a * x1) (d * x1 * x2 * y1 * y2 * y1))^2 / - (x1 * y1 * (f (sqrt_a * x2) y2))^2) - by (rewriteAny; field; auto); - match goal with [H: d = (?n^2)/(?l^2) |- _ ] => - destruct (d_nonsquare (n/l)); (remember n; rewriteAny; field; auto) - end - end. - - assert (Hc: (sqrt_a * x2 + y2) + (sqrt_a * x2 - y2) = 0) by (repeat rewriteAny; field). - - replace (sqrt_a * x2 + y2 + (sqrt_a * x2 - y2)) with (ZToField 2 * sqrt_a * x2) in Hc by field. - - (* contradiction: product of nonzero things is zero *) - destruct (Fq_mul_zero_why _ _ Hc) as [Hcc|Hcc]; subst; intuition. - destruct (Fq_mul_zero_why _ _ Hcc) as [Hccc|Hccc]; subst; intuition. - apply Ha_nonzero; field. + unfold onCurve, not; use_sqrt_a; intros. + destruct (eq_dec (sqrt_a*x2 + y2) 0); destruct (eq_dec (sqrt_a*x2 - y2) 0); + lazymatch goal with + | [H: not (eq (?f (sqrt_a * x2) y2) 0) |- _ ] + => apply d_nonsquare with (sqrt_d:= (f (sqrt_a * x1) (d * x1 * x2 * y1 * y2 * y1)) + /(f (sqrt_a * x2) y2 * x1 * y1 )) + | _ => apply a_nonzero + end; field_algebra; auto using Ring.opp_nonzero_nonzero; intro; nsatz_contradict. Qed. Lemma edwardsAddCompletePlus x1 y1 x2 y2 : - onCurve (x1, y1) -> - onCurve (x2, y2) -> - (1 + d*x1*x2*y1*y2) <> 0. - Proof. - intros Hc1 Hc2; simpl in Hc1, Hc2. - intros; destruct (F_eq_dec (d*x1*x2*y1*y2) (0-1)) as [H|H]. - - assert ((d*x1*x2*y1*y2)^2 = 1) by (rewriteAny; field). destruct (edwardsAddComplete' x1 y1 x2 y2); auto. - - replace (d * x1 * x2 * y1 * y2) with (1+d * x1 * x2 * y1 * y2-1) in H by field. - intro Hz; rewrite Hz in H; intuition. - Qed. - + onCurve (x1, y1) -> onCurve (x2, y2) -> (1 + d*x1*x2*y1*y2) <> 0. + Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); field_algebra. Qed. + Lemma edwardsAddCompleteMinus x1 y1 x2 y2 : - onCurve (x1, y1) -> - onCurve (x2, y2) -> - (1 - d*x1*x2*y1*y2) <> 0. - Proof. - intros Hc1 Hc2. destruct (F_eq_dec (d*x1*x2*y1*y2) 1) as [H|H]. - - assert ((d*x1*x2*y1*y2)^2 = 1) by (rewriteAny; field). destruct (edwardsAddComplete' x1 y1 x2 y2); auto. - - replace (d * x1 * x2 * y1 * y2) with ((1-(1-d * x1 * x2 * y1 * y2))) in H by field. - intro Hz; rewrite Hz in H; apply H; field. - Qed. - - Definition zeroOnCurve : onCurve (0, 1). - simpl. field. - Qed. - - Lemma unifiedAdd'_onCurve' x1 y1 x2 y2 x3 y3 - (H: (x3, y3) = unifiedAdd' (x1, y1) (x2, y2)) : - onCurve (x1, y1) -> onCurve (x2, y2) -> onCurve (x3, y3). + onCurve (x1, y1) -> onCurve (x2, y2) -> (1 - d*x1*x2*y1*y2) <> 0. + Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); field_algebra. Qed. + + Lemma zeroOnCurve : onCurve (0, 1). Proof. simpl. field_algebra. Qed. + + Lemma unifiedAdd'_onCurve : forall P1 P2, + onCurve P1 -> onCurve P2 -> onCurve (unifiedAdd' P1 P2). Proof. - (* https://eprint.iacr.org/2007/286.pdf Theorem 3.1; - * c=1 and an extra a in front of x^2 *) - - injection H; clear H; intros. - - Ltac t x1 y1 x2 y2 := - assert ((a*x2^2 + y2^2)*d*x1^2*y1^2 - = (1 + d*x2^2*y2^2) * d*x1^2*y1^2) by (rewriteAny; auto); - assert (a*x1^2 + y1^2 - (a*x2^2 + y2^2)*d*x1^2*y1^2 - = 1 - d^2*x1^2*x2^2*y1^2*y2^2) by (repeat rewriteAny; field). - t x1 y1 x2 y2; t x2 y2 x1 y1. - - remember ((a*x1^2 + y1^2 - (a*x2^2+y2^2)*d*x1^2*y1^2)*(a*x2^2 + y2^2 - - (a*x1^2 + y1^2)*d*x2^2*y2^2)) as T. - assert (HT1: T = (1 - d^2*x1^2*x2^2*y1^2*y2^2)^2) by (repeat rewriteAny; field). - assert (HT2: T = (a * ((x1 * y2 + y1 * x2) * (1 - d * x1 * x2 * y1 * y2)) ^ 2 +( - (y1 * y2 - a * x1 * x2) * (1 + d * x1 * x2 * y1 * y2)) ^ 2 -d * ((x1 * - y2 + y1 * x2)* (y1 * y2 - a * x1 * x2))^2)) by (subst; field). - replace (1:F q) with (a*x3^2 + y3^2 -d*x3^2*y3^2); [field|]; subst x3 y3. - - match goal with [ |- ?x = 1 ] => replace x with - ((a * ((x1 * y2 + y1 * x2) * (1 - d * x1 * x2 * y1 * y2)) ^ 2 + - ((y1 * y2 - a * x1 * x2) * (1 + d * x1 * x2 * y1 * y2)) ^ 2 - - d*((x1 * y2 + y1 * x2) * (y1 * y2 - a * x1 * x2)) ^ 2)/ - ((1-d^2*x1^2*x2^2*y1^2*y2^2)^2)) end. - - rewrite <-HT1, <-HT2; field; rewrite HT1. - replace ((1 - d ^ 2 * x1 ^ 2 * x2 ^ 2 * y1 ^ 2 * y2 ^ 2)) - with ((1 - d*x1*x2*y1*y2)*(1 + d*x1*x2*y1*y2)) by field. - auto using Fq_pow_nonzero, Fq_mul_nonzero_nonzero, - edwardsAddCompleteMinus, edwardsAddCompletePlus. - - field; replace (1 - (d * x1 * x2 * y1 * y2) ^ 2) - with ((1 - d*x1*x2*y1*y2)*(1 + d*x1*x2*y1*y2)) - by field; - auto using Fq_pow_nonzero, Fq_mul_nonzero_nonzero, - edwardsAddCompleteMinus, edwardsAddCompletePlus. + unfold onCurve, unifiedAdd'; intros [x1 y1] [x2 y2] H1 H2. + field_algebra; auto using edwardsAddCompleteMinus, edwardsAddCompletePlus. Qed. - - Lemma unifiedAdd'_onCurve : forall P1 P2, onCurve P1 -> onCurve P2 -> - onCurve (unifiedAdd' P1 P2). +End Pre. + +Import Group Ring Field. + +(* TODO: move -- this does not need to be defined before [point] *) +Section RespectsFieldHomomorphism. + Context {F EQ ZERO ONE OPP ADD MUL SUB INV DIV} `{@field F EQ ZERO ONE OPP ADD SUB MUL INV DIV}. + Context {K eq zero one opp add mul sub inv div} `{@field K eq zero one opp add sub mul inv div}. + Local Infix "=" := eq. Local Infix "=" := eq : type_scope. + Context {phi:F->K} `{@is_homomorphism F EQ ONE ADD MUL K eq one add mul phi}. + Context {A D:F} {a d:K} {a_ok:phi A=a} {d_ok:phi D=d}. + + Let phip := fun (P':F*F) => let (x, y) := P' in (phi x, phi y). + + Let eqp := fun (P1' P2':K*K) => + let (x1, y1) := P1' in + let (x2, y2) := P2' in + and (eq x1 x2) (eq y1 y2). + + Create HintDb field_homomorphism discriminated. + Hint Rewrite + homomorphism_one + homomorphism_add + homomorphism_sub + homomorphism_mul + homomorphism_div + a_ok + d_ok + : field_homomorphism. + + Lemma morphism_unidiedAdd' : forall P Q:F*F, + eqp + (phip (unifiedAdd'(F:=F)(one:=ONE)(add:=ADD)(sub:=SUB)(mul:=MUL)(div:=DIV)(a:=A)(d:=D) P Q)) + (unifiedAdd'(F:=K)(one:=one)(add:=add)(sub:=sub)(mul:=mul)(div:=div)(a:=a)(d:=d) (phip P) (phip Q)). Proof. - intros; destruct P1, P2. - remember (unifiedAdd' (f, f0) (f1, f2)) as r; destruct r. - eapply unifiedAdd'_onCurve'; eauto. + intros [x1 y1] [x2 y2]. + cbv [unifiedAdd' phip eqp]; + apply conj; + (rewrite_strat topdown hints field_homomorphism); reflexivity. Qed. -End Pre.
\ No newline at end of file +End RespectsFieldHomomorphism.
\ No newline at end of file diff --git a/src/EdDSAProofs.v b/src/EdDSAProofs.v deleted file mode 100644 index dba71b49c..000000000 --- a/src/EdDSAProofs.v +++ /dev/null @@ -1,78 +0,0 @@ -Require Import Crypto.Spec.EdDSA Crypto.Spec.Encoding. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Bedrock.Word. -Require Import Coq.ZArith.Znumtheory Coq.ZArith.BinInt Coq.ZArith.ZArith. -Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems Crypto.ModularArithmetic.ModularArithmeticTheorems. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. -Require Import Crypto.Tactics.VerdiTactics. -Local Open Scope nat_scope. - -Section EdDSAProofs. - Context {prm:EdDSAParams}. - Existing Instance E. - Existing Instance PointEncoding. - Existing Instance FqEncoding. - Existing Instance FlEncoding. - Existing Instance n_le_b. - Hint Rewrite sign_spec split1_combine split2_combine. - Hint Rewrite Nat.mod_mod using omega. - - Ltac arith' := intros; autorewrite with core; try (omega || congruence). - - Ltac arith := arith'; - repeat match goal with - | [ H : _ |- _ ] => rewrite H; arith' - end. - - (* for signature (R_, S_), R_ = encode_point (r * B) *) - Lemma decode_sign_split1 : forall A_ sk {n} (M : word n), - split1 b b (sign A_ sk M) = enc (wordToNat (H (prngKey sk ++ M)) * B)%E. - Proof. - unfold sign; arith. - Qed. - Hint Rewrite decode_sign_split1. - - (* for signature (R_, S_), S_ = encode_scalar (r + H(R_, A_, M)s) *) - Lemma decode_sign_split2 : forall sk {n} (M : word n), - split2 b b (sign (public sk) sk M) = - let r : nat := H (prngKey sk ++ M) in (* secret nonce *) - let R : E.point := (r * B)%E in (* commitment to nonce *) - let s : nat := curveKey sk in (* secret scalar *) - let S : F (Z.of_nat l) := ZToField (Z.of_nat (r + H (enc R ++ public sk ++ M) * s)) in - enc S. - Proof. - unfold sign; arith. - Qed. - Hint Rewrite decode_sign_split2. - - Hint Rewrite E.add_0_r E.add_0_l E.add_assoc. - Hint Rewrite E.mul_assoc E.mul_add_l E.mul_0_l E.mul_zero_r. - Hint Rewrite plus_O_n plus_Sn_m mult_0_l mult_succ_l. - Hint Rewrite l_order_B. - Lemma l_order_B' : forall x, (l * x * B = E.zero)%E. - Proof. - intros; rewrite Mult.mult_comm. rewrite <- E.mul_assoc. arith. - Qed. Hint Rewrite l_order_B'. - - Lemma scalarMult_mod_l : forall n0, (n0 mod l * B = n0 * B)%E. - Proof. - intros. - rewrite (div_mod n0 l) at 2 by (generalize l_odd; omega). - arith. - Qed. Hint Rewrite scalarMult_mod_l. - - Hint Rewrite @encoding_valid. - Hint Rewrite @FieldToZ_ZToField. - Hint Rewrite <-mod_Zmod. - Hint Rewrite Nat2Z.id. - - Lemma l_nonzero : l <> O. pose l_odd; omega. Qed. - Hint Resolve l_nonzero. - - Lemma verify_valid_passes : forall sk {n} (M : word n), - verify (public sk) M (sign (public sk) sk M) = true. - Proof. - unfold verify, sign, public; arith; try break_if; intuition. - Qed. -End EdDSAProofs. diff --git a/src/Encoding/EncodingTheorems.v b/src/Encoding/EncodingTheorems.v index 52ac91ada..c6f48a0ab 100644 --- a/src/Encoding/EncodingTheorems.v +++ b/src/Encoding/EncodingTheorems.v @@ -2,7 +2,7 @@ Require Import Crypto.Spec.Encoding. Section EncodingTheorems. Context {A B : Type} {E : canonical encoding of A as B}. - + Lemma encoding_inj : forall x y, enc x = enc y -> x = y. Proof. intros. diff --git a/src/Encoding/ModularWordEncodingTheorems.v b/src/Encoding/ModularWordEncodingTheorems.v index 7251ac1e6..41b75e216 100644 --- a/src/Encoding/ModularWordEncodingTheorems.v +++ b/src/Encoding/ModularWordEncodingTheorems.v @@ -24,7 +24,7 @@ Section SignBit. assert (m < 1)%Z by (apply Z2Nat.inj_lt; try omega; assumption). omega. + assert (0 < m)%Z as m_pos by (pose proof prime_ge_2 m prime_m; omega). - pose proof (FieldToZ_range x m_pos). + pose proof (FieldToZ_range x m_pos). destruct (FieldToZ x); auto. - destruct p; auto. - pose proof (Pos2Z.neg_is_neg p); omega. diff --git a/src/Encoding/PointEncodingPre.v b/src/Encoding/PointEncodingPre.v deleted file mode 100644 index 73ced869b..000000000 --- a/src/Encoding/PointEncodingPre.v +++ /dev/null @@ -1,275 +0,0 @@ -Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Program.Equality. -Require Import Crypto.Encoding.EncodingTheorems. -Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Bedrock.Word. -Require Import Crypto.Encoding.ModularWordEncodingTheorems. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.ZUtil. - -Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic. - -Local Open Scope F_scope. - -Section PointEncoding. - Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat} - {bound_check : (Z.to_nat q < 2 ^ sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z} - {sqrt_minus1_valid : (@ZToField q 2 ^ Z.to_N (q / 4)) ^ 2 = opp 1} - {FqEncoding : canonical encoding of (F q) as (word sz)} - {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false} - {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}. - Existing Instance prime_q. - - Add Field Ffield : (@Ffield_theory q _) - (morphism (@Fring_morph q), - preprocess [Fpreprocess], - postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], - constants [Fconstant], - div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). - - Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F. - - Lemma solve_sqrt_valid : forall p, E.onCurve p -> - sqrt_valid (E.solve_for_x2 (snd p)). - Proof. - intros ? onCurve_xy. - destruct p as [x y]; simpl. - rewrite (E.solve_correct x y) in onCurve_xy. - rewrite <- onCurve_xy. - unfold sqrt_valid. - eapply sqrt_mod_q_valid; eauto. - unfold isSquare; eauto. - Grab Existential Variables. eauto. - Qed. - - Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> - E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y). - Proof. - intros. - unfold sqrt_valid in *. - apply E.solve_correct; auto. - Qed. - - Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> - E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y). - Proof. - intros y sqrt_valid_x2. - unfold sqrt_valid in *. - apply E.solve_correct. - rewrite <- sqrt_valid_x2 at 2. - ring. - Qed. - - Definition point_enc_coordinates (p : (F q * F q)) : Word.word (S sz) := let '(x,y) := p in - Word.WS (sign_bit x) (enc y). - - Let point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in - Word.WS (sign_bit x) (enc y). - - Definition point_dec_coordinates (sign_bit : F q -> bool) (w : Word.word (S sz)) : option (F q * F q) := - match dec (Word.wtl w) with - | None => None - | Some y => let x2 := E.solve_for_x2 y in - let x := sqrt_mod_q x2 in - if F_eq_dec (x ^ 2) x2 - then - let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in - if (andb (F_eqb x 0) (whd w)) - then None (* special case for 0, since its opposite has the same sign; if the sign bit of 0 is 1, produce None.*) - else Some p - else None - end. - - Ltac inversion_Some_eq := match goal with [H: Some ?x = Some ?y |- _] => inversion H; subst end. - - Lemma point_dec_coordinates_onCurve : forall w p, point_dec_coordinates sign_bit w = Some p -> E.onCurve p. - Proof. - unfold point_dec_coordinates; intros. - edestruct dec; [ | congruence]. - break_if; [ | congruence]. - break_if; [ congruence | ]. - break_if; inversion_Some_eq; auto using solve_onCurve, solve_opp_onCurve. - Qed. - - Lemma prod_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'}) - (x y : (A * A)), {x = y} + {x <> y}. - Proof. - decide equality. - Qed. - - Lemma option_eq_dec : forall {A} (A_eq_dec : forall a a' : A, {a = a'} + {a <> a'}) - (x y : option A), {x = y} + {x <> y}. - Proof. - decide equality. - Qed. - - Definition point_dec' w p : option E.point := - match (option_eq_dec (prod_eq_dec F_eq_dec) (point_dec_coordinates sign_bit w) (Some p)) with - | left EQ => Some (exist _ p (point_dec_coordinates_onCurve w p EQ)) - | right _ => None (* this case is never reached *) - end. - - Definition point_dec (w : word (S sz)) : option E.point := - match (point_dec_coordinates sign_bit w) with - | Some p => point_dec' w p - | None => None - end. - - Lemma point_coordinates_encoding_canonical : forall w p, - point_dec_coordinates sign_bit w = Some p -> point_enc_coordinates p = w. - Proof. - unfold point_dec_coordinates, point_enc_coordinates; intros ? ? coord_dec_Some. - case_eq (dec (wtl w)); [ intros ? dec_Some | intros dec_None; rewrite dec_None in *; congruence ]. - destruct p. - rewrite (shatter_word w). - f_equal; rewrite dec_Some in *; - do 2 (break_if; try congruence); inversion coord_dec_Some; subst. - + destruct (F_eq_dec (sqrt_mod_q (E.solve_for_x2 f1)) 0%F) as [sqrt_0 | ?]. - - rewrite sqrt_0 in *. - apply sqrt_mod_q_root_0 in sqrt_0; try assumption. - rewrite sqrt_0 in *. - break_if; [symmetry; auto using Bool.eqb_prop | ]. - rewrite sign_bit_zero in *. - simpl in Heqb; rewrite Heqb in *. - discriminate. - - break_if. - symmetry; auto using Bool.eqb_prop. - rewrite <- sign_bit_opp by assumption. - destruct (whd w); inversion Heqb0; break_if; auto. - + inversion coord_dec_Some; subst. - auto using encoding_canonical. -Qed. - - Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w. - Proof. - (* - unfold point_enc; intros. - unfold point_dec in *. - assert (point_dec_coordinates w = Some (proj1_sig x)). { - set (y := point_dec_coordinates w) in *. - revert H. - dependent destruction y. intros. - rewrite H0 in H. - *) - Admitted. - -Lemma point_dec_coordinates_correct w - : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates sign_bit w. -Proof. - unfold point_dec, option_map. - do 2 break_match; try congruence; unfold point_dec' in *; - break_match; try congruence. - inversion_Some_eq. - reflexivity. -Qed. - -Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p). -Proof. - intros. - destruct p as [x y]; simpl. - exact (encoding_valid y). -Qed. - -Lemma sign_bit_opp_eq_iff : forall x y, y <> 0 -> - (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)). -Proof. - split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y); - try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp in * by auto; - rewrite y_sign, x_sign in *; reflexivity || discriminate. -Qed. - -Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 -> - sign_bit x = sign_bit y -> x = y. -Proof. - intros ? ? y_nonzero squares_eq sign_match. - destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto. - assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto). - apply sign_bit_opp_eq_iff in sign_mismatch; auto. - congruence. -Qed. - -Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) -> - sign_bit x = sign_bit x' -> x = x'. -Proof. - intros ? ? ? onCurve_x onCurve_x' sign_match. - apply E.solve_correct in onCurve_x. - apply E.solve_correct in onCurve_x'. - destruct (F_eq_dec x' 0). - + subst. - rewrite Fq_pow_zero in onCurve_x' by congruence. - rewrite <- onCurve_x' in *. - eapply Fq_root_zero; eauto. - + apply sign_bit_squares; auto. - rewrite onCurve_x, onCurve_x'. - reflexivity. -Qed. - -Lemma point_encoding_coordinates_valid : forall p, E.onCurve p -> - point_dec_coordinates sign_bit (point_enc_coordinates p) = Some p. -Proof. - intros p onCurve_p. - unfold point_dec_coordinates. - rewrite y_decode. - pose proof (solve_sqrt_valid p onCurve_p) as solve_sqrt_valid_p. - destruct p as [x y]. - unfold sqrt_valid in *. - simpl. - replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption). - case_eq (F_eqb x 0); intro eqb_x_0. - + apply F_eqb_eq in eqb_x_0; rewrite eqb_x_0 in *. - rewrite !Fq_pow_zero, sqrt_mod_q_of_0, Fq_pow_zero by congruence. - rewrite if_F_eq_dec_if_F_eqb, sign_bit_zero. - reflexivity. - + assert (sqrt_mod_q (x ^ 2) <> 0) by (intro false_eq; apply sqrt_mod_q_root_0 in false_eq; try assumption; - apply Fq_root_zero in false_eq; rewrite false_eq, F_eqb_refl in eqb_x_0; congruence). - replace (F_eqb (sqrt_mod_q (x ^ 2)) 0) with false by (symmetry; - apply F_eqb_neq_complete; assumption). - break_if. - - simpl. - f_equal. - break_if. - * rewrite Bool.eqb_true_iff in Heqb. - pose proof (solve_onCurve y solve_sqrt_valid_p). - f_equal. - apply (sign_bit_match _ _ y); auto. - apply E.solve_correct in onCurve_p; rewrite onCurve_p in *. - assumption. - * rewrite Bool.eqb_false_iff in Heqb. - pose proof (solve_opp_onCurve y solve_sqrt_valid_p). - f_equal. - apply sign_bit_opp_eq_iff in Heqb; try assumption. - apply (sign_bit_match _ _ y); auto. - apply E.solve_correct in onCurve_p. - rewrite onCurve_p; auto. - - simpl in solve_sqrt_valid_p. - replace (E.solve_for_x2 y) with (x ^ 2 : F q) in * by (apply E.solve_correct; assumption). - congruence. -Qed. - -Lemma point_dec'_valid : forall p, - point_dec' (point_enc_coordinates (proj1_sig p)) (proj1_sig p) = Some p. -Proof. - unfold point_dec'; intros. - break_match. - + f_equal. - destruct p. - apply E.point_eq. - reflexivity. - + rewrite point_encoding_coordinates_valid in n by apply (proj2_sig p). - congruence. -Qed. - -Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p. -Proof. - intros. - unfold point_dec. - replace (point_enc p) with (point_enc_coordinates (proj1_sig p)) by reflexivity. - break_match; rewrite point_encoding_coordinates_valid in * by apply (proj2_sig p); try congruence. - inversion_Some_eq. - eapply point_dec'_valid. -Qed. - -End PointEncoding. diff --git a/src/Encoding/PointEncodingTheorems.v b/src/Encoding/PointEncodingTheorems.v deleted file mode 100644 index ccea1d81b..000000000 --- a/src/Encoding/PointEncodingTheorems.v +++ /dev/null @@ -1,207 +0,0 @@ -Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Program.Equality. -Require Import Crypto.Encoding.EncodingTheorems. -Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Bedrock.Word. -Require Import Crypto.Tactics.VerdiTactics. - -Require Import Crypto.Spec.Encoding Crypto.Spec.ModularArithmetic Crypto.Spec.CompleteEdwardsCurve. - -Local Open Scope F_scope. - -Section PointEncoding. - Context {prm: CompleteEdwardsCurve.TwistedEdwardsParams} {sz : nat} - {FqEncoding : canonical encoding of ModularArithmetic.F (CompleteEdwardsCurve.q) as Word.word sz} - {q_5mod8 : (CompleteEdwardsCurve.q mod 8 = 5)%Z} - {sqrt_minus1_valid : (@ZToField CompleteEdwardsCurve.q 2 ^ BinInt.Z.to_N (CompleteEdwardsCurve.q / 4)) ^ 2 = opp 1}. - Existing Instance CompleteEdwardsCurve.prime_q. - - Add Field Ffield : (@PrimeFieldTheorems.Ffield_theory CompleteEdwardsCurve.q _) - (morphism (@ModularArithmeticTheorems.Fring_morph CompleteEdwardsCurve.q), - preprocess [ModularArithmeticTheorems.Fpreprocess], - postprocess [ModularArithmeticTheorems.Fpostprocess; try exact PrimeFieldTheorems.Fq_1_neq_0; try assumption], - constants [ModularArithmeticTheorems.Fconstant], - div (@ModularArithmeticTheorems.Fmorph_div_theory CompleteEdwardsCurve.q), - power_tac (@ModularArithmeticTheorems.Fpower_theory CompleteEdwardsCurve.q) [ModularArithmeticTheorems.Fexp_tac]). - - Definition sqrt_valid (a : F q) := ((sqrt_mod_q a) ^ 2 = a)%F. - - Lemma solve_sqrt_valid : forall (p : E.point), - sqrt_valid (E.solve_for_x2 (snd (proj1_sig p))). - Proof. - intros. - destruct p as [[x y] onCurve_xy]; simpl. - rewrite (E.solve_correct x y) in onCurve_xy. - rewrite <- onCurve_xy. - unfold sqrt_valid. - eapply sqrt_mod_q_valid; eauto. - unfold isSquare; eauto. - Grab Existential Variables. eauto. - Qed. - - Lemma solve_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> - E.onCurve (sqrt_mod_q (E.solve_for_x2 y), y). - Proof. - intros. - unfold sqrt_valid in *. - apply E.solve_correct; auto. - Qed. - - Lemma solve_opp_onCurve: forall (y : F q), sqrt_valid (E.solve_for_x2 y) -> - E.onCurve (opp (sqrt_mod_q (E.solve_for_x2 y)), y). - Proof. - intros y sqrt_valid_x2. - unfold sqrt_valid in *. - apply E.solve_correct. - rewrite <- sqrt_valid_x2 at 2. - ring. - Qed. - -Definition sign_bit (x : F q) := (wordToN (enc (opp x)) <? wordToN (enc x))%N. -Definition point_enc (p : E.point) : word (S sz) := let '(x,y) := proj1_sig p in - WS (sign_bit x) (enc y). -Definition point_dec_coordinates (w : word (S sz)) : option (F q * F q) := - match dec (wtl w) with - | None => None - | Some y => let x2 := E.solve_for_x2 y in - let x := sqrt_mod_q x2 in - if F_eq_dec (x ^ 2) x2 - then - let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in - Some p - else None - end. - -Definition point_dec (w : word (S sz)) : option E.point := - match dec (wtl w) with - | None => None - | Some y => let x2 := E.solve_for_x2 y in - let x := sqrt_mod_q x2 in - match (F_eq_dec (x ^ 2) x2) with - | right _ => None - | left EQ => if Bool.eqb (whd w) (sign_bit x) - then Some (exist _ (x, y) (solve_onCurve y EQ)) - else Some (exist _ (opp x, y) (solve_opp_onCurve y EQ)) - end - end. - -Lemma point_dec_coordinates_correct w - : option_map (@proj1_sig _ _) (point_dec w) = point_dec_coordinates w. -Proof. - unfold point_dec, point_dec_coordinates. - edestruct dec; [ | reflexivity ]. - edestruct @F_eq_dec; [ | reflexivity ]. - edestruct @Bool.eqb; reflexivity. -Qed. - -Lemma y_decode : forall p, dec (wtl (point_enc p)) = Some (snd (proj1_sig p)). -Proof. - intros. - destruct p as [[x y] onCurve_p]; simpl. - exact (encoding_valid y). -Qed. - - -Lemma wordToN_enc_neq_opp : forall x, x <> 0 -> (wordToN (enc (opp x)) <> wordToN (enc x))%N. -Proof. - intros x x_nonzero. - intro false_eq. - apply x_nonzero. - apply F_eq_opp_zero; try apply two_lt_q. - apply wordToN_inj in false_eq. - apply encoding_inj in false_eq. - auto. -Qed. - -Lemma sign_bit_opp_negb : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x). -Proof. - intros x x_nonzero. - unfold sign_bit. - rewrite <- N.leb_antisym. - rewrite N.ltb_compare, N.leb_compare. - rewrite F_opp_involutive. - case_eq (wordToN (enc x) ?= wordToN (enc (opp x)))%N; auto. - intro wordToN_enc_eq. - pose proof (wordToN_enc_neq_opp x x_nonzero). - apply N.compare_eq_iff in wordToN_enc_eq. - congruence. -Qed. - -Lemma sign_bit_opp : forall x y, y <> 0 -> - (sign_bit x <> sign_bit y <-> sign_bit x = sign_bit (opp y)). -Proof. - split; intro sign_mismatch; case_eq (sign_bit x); case_eq (sign_bit y); - try congruence; intros y_sign x_sign; rewrite <- sign_bit_opp_negb in * by auto; - rewrite y_sign, x_sign in *; reflexivity || discriminate. -Qed. - -Lemma sign_bit_squares : forall x y, y <> 0 -> x ^ 2 = y ^ 2 -> - sign_bit x = sign_bit y -> x = y. -Proof. - intros ? ? y_nonzero squares_eq sign_match. - destruct (sqrt_solutions _ _ squares_eq) as [? | eq_opp]; auto. - assert (sign_bit x = sign_bit (opp y)) as sign_mismatch by (f_equal; auto). - apply sign_bit_opp in sign_mismatch; auto. - congruence. -Qed. - -Lemma sign_bit_match : forall x x' y : F q, E.onCurve (x, y) -> E.onCurve (x', y) -> - sign_bit x = sign_bit x' -> x = x'. -Proof. - intros ? ? ? onCurve_x onCurve_x' sign_match. - apply E.solve_correct in onCurve_x. - apply E.solve_correct in onCurve_x'. - destruct (F_eq_dec x' 0). - + subst. - rewrite Fq_pow_zero in onCurve_x' by congruence. - rewrite <- onCurve_x' in *. - eapply Fq_root_zero; eauto. - + apply sign_bit_squares; auto. - rewrite onCurve_x, onCurve_x'. - reflexivity. -Qed. - -Lemma point_encoding_valid : forall p, point_dec (point_enc p) = Some p. -Proof. - intros. - unfold point_dec. - rewrite y_decode. - pose proof solve_sqrt_valid p as solve_sqrt_valid_p. - unfold sqrt_valid in *. - destruct p as [[x y] onCurve_p]. - simpl in *. - destruct (F_eq_dec ((sqrt_mod_q (E.solve_for_x2 y)) ^ 2) (E.solve_for_x2 y)); intuition. - break_if; f_equal; apply E.point_eq. - + rewrite Bool.eqb_true_iff in Heqb. - pose proof (solve_onCurve y solve_sqrt_valid_p). - f_equal. - apply (sign_bit_match _ _ y); auto. - + rewrite Bool.eqb_false_iff in Heqb. - pose proof (solve_opp_onCurve y solve_sqrt_valid_p). - f_equal. - apply sign_bit_opp in Heqb. - apply (sign_bit_match _ _ y); auto. - intro eq_zero. - apply E.solve_correct in onCurve_p. - rewrite eq_zero in *. - rewrite Fq_pow_zero in solve_sqrt_valid_p by congruence. - rewrite <- solve_sqrt_valid_p in onCurve_p. - apply Fq_root_zero in onCurve_p. - rewrite onCurve_p in Heqb; auto. -Qed. - -(* Waiting on canonicalization *) -Lemma point_encoding_canonical : forall (x_enc : word (S sz)) (x : E.point), -point_dec x_enc = Some x -> point_enc x = x_enc. -Admitted. - -Instance point_encoding : canonical encoding of E.point as (word (S sz)) := { - enc := point_enc; - dec := point_dec; - encoding_valid := point_encoding_valid; - encoding_canonical := point_encoding_canonical -}. - -End PointEncoding. diff --git a/src/Experiments/DerivationsOptionRectLetInEncoding.v b/src/Experiments/DerivationsOptionRectLetInEncoding.v new file mode 100644 index 000000000..e5b74085e --- /dev/null +++ b/src/Experiments/DerivationsOptionRectLetInEncoding.v @@ -0,0 +1,351 @@ +Require Import Coq.omega.Omega. +Require Import Bedrock.Word. +Require Import Crypto.Spec.EdDSA. +Require Import Crypto.Tactics.VerdiTactics. +Require Import BinNat BinInt NArith Crypto.Spec.ModularArithmetic. +Require Import ModularArithmetic.ModularArithmeticTheorems. +Require Import ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.Spec.CompleteEdwardsCurve. +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding. +Require Import Crypto.CompleteEdwardsCurve.ExtendedCoordinates. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Crypto.Util.IterAssocOp Crypto.Util.WordUtil. +Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence. +Require Import Zdiv. +Require Import Crypto.Util.Tuple. +Local Open Scope equiv_scope. + +Generalizable All Variables. + + +Local Ltac set_evars := + repeat match goal with + | [ |- appcontext[?E] ] => is_evar E; let e := fresh "e" in set (e := E) + end. + +Local Ltac subst_evars := + repeat match goal with + | [ e := ?E |- _ ] => is_evar E; subst e + end. + +Definition path_sig {A P} {RA:relation A} {Rsig:relation (@sig A P)} + {HP:Proper (RA==>Basics.impl) P} + (H:forall (x y:A) (px:P x) (py:P y), RA x y -> Rsig (exist _ x px) (exist _ y py)) + (x : @sig A P) (y0:A) (pf : RA (proj1_sig x) y0) +: Rsig x (exist _ y0 (HP _ _ pf (proj2_sig x))). +Proof. destruct x. eapply H. assumption. Defined. + +Definition Let_In {A P} (x : A) (f : forall a : A, P a) : P x := let y := x in f y. +Global Instance Let_In_Proper_changebody {A P R} {Reflexive_R:@Reflexive P R} + : Proper (eq ==> pointwise_relation _ R ==> R) (@Let_In A (fun _ => P)). +Proof. + lazy; intros; try congruence. + subst; auto. +Qed. + +Lemma Let_In_Proper_changevalue {A B} RA {RB} (f:A->B) {Proper_f:Proper (RA==>RB) f} + : Proper (RA ==> RB) (fun x => Let_In x f). +Proof. intuition. Qed. + +Ltac fold_identity_lambdas := + repeat match goal with + | [ H: appcontext [fun x => ?f x] |- _ ] => change (fun x => f x) with f in * + | |- appcontext [fun x => ?f x] => change (fun x => f x) with f in * + end. + +Local Ltac replace_let_in_with_Let_In := + match goal with + | [ |- context G[let x := ?y in @?z x] ] + => let G' := context G[Let_In y z] in change G' + end. + +Local Ltac Let_In_app fn := + match goal with + | [ |- appcontext G[Let_In (fn ?x) ?f] ] + => change (Let_In (fn x) f) with (Let_In x (fun y => f (fn y))); cbv beta + end. + +Lemma if_map : forall {T U} (f:T->U) (b:bool) (x y:T), (if b then f x else f y) = f (if b then x else y). +Proof. + destruct b; trivial. +Qed. + +Lemma pull_Let_In {B C} (f : B -> C) A (v : A) (b : A -> B) + : Let_In v (fun v' => f (b v')) = f (Let_In v b). +Proof. + reflexivity. +Qed. + +Lemma Let_app_In {A B T} (g:A->B) (f:B->T) (x:A) : + @Let_In _ (fun _ => T) (g x) f = + @Let_In _ (fun _ => T) x (fun p => f (g x)). +Proof. reflexivity. Qed. + +Lemma Let_app_In' : forall {A B T} {R} {R_equiv:@Equivalence T R} + (g : A -> B) (f : B -> T) (x : A) + f' (f'_ok: forall z, f' z === f (g z)), + Let_In (g x) f === Let_In x f'. +Proof. intros; cbv [Let_In]; rewrite f'_ok; reflexivity. Qed. +Definition unfold_Let_In {A B} x (f:A->B) : Let_In x f = let y := x in f y := eq_refl. + +Lemma Let_app2_In {A B C D T} (g1:A->C) (g2:B->D) (f:C*D->T) (x:A) (y:B) : + @Let_In _ (fun _ => T) (g1 x, g2 y) f = + @Let_In _ (fun _ => T) (x, y) (fun p => f ((g1 (fst p), g2 (snd p)))). +Proof. reflexivity. Qed. + +Lemma funexp_proj {T T'} `{@Equivalence T' RT'} + (proj : T -> T') + (f : T -> T) + (f' : T' -> T') {Proper_f':Proper (RT'==>RT') f'} + (f_proj : forall a, proj (f a) === f' (proj a)) + x n + : proj (funexp f x n) === funexp f' (proj x) n. +Proof. + revert x; induction n as [|n IHn]; simpl; intros. + - reflexivity. + - rewrite f_proj. rewrite IHn. reflexivity. +Qed. + +Global Instance pair_Equivalence {A B} `{@Equivalence A RA} `{@Equivalence B RB} : @Equivalence (A*B) (fun x y => fst x = fst y /\ snd x === snd y). +Proof. + constructor; repeat intro; intuition; try congruence. + match goal with [H : _ |- _ ] => solve [rewrite H; auto] end. +Qed. + +Global Instance Proper_test_and_op {T scalar} `{Requiv:@Equivalence T RT} + {op:T->T->T} {Proper_op:Proper (RT==>RT==>RT) op} + {testbit:scalar->nat->bool} {s:scalar} {zero:T} : + let R := fun x y => fst x = fst y /\ snd x === snd y in + Proper (R==>R) (test_and_op op testbit s zero). +Proof. + unfold test_and_op; simpl; repeat intro; intuition; + repeat match goal with + | [ |- context[match ?n with _ => _ end] ] => destruct n eqn:?; simpl in *; subst; try discriminate; auto + | [ H: _ |- _ ] => setoid_rewrite H; reflexivity + end. +Qed. + +Lemma iter_op_proj {T T' S} `{T'Equiv:@Equivalence T' RT'} + (proj : T -> T') (op : T -> T -> T) (op' : T' -> T' -> T') {Proper_op':Proper (RT' ==> RT' ==> RT') op'} x y z + (testbit : S -> nat -> bool) (bound : nat) + (op_proj : forall a b, proj (op a b) === op' (proj a) (proj b)) + : proj (iter_op op x testbit y z bound) === iter_op op' (proj x) testbit y (proj z) bound. +Proof. + unfold iter_op. + lazymatch goal with + | [ |- ?proj (snd (funexp ?f ?x ?n)) === snd (funexp ?f' _ ?n) ] + => pose proof (fun pf x0 x1 => @funexp_proj _ _ _ _ (fun x' => (fst x', proj (snd x'))) f f' (Proper_test_and_op (Requiv:=T'Equiv)) pf (x0, x1)) as H'; + lazymatch type of H' with + | ?H'' -> _ => assert (H'') as pf; [clear H'|edestruct (H' pf); simpl in *; solve [eauto]] + end + end. + + intros [??]; simpl. + repeat match goal with + | [ |- context[match ?n with _ => _ end] ] => destruct n eqn:? + | _ => progress (unfold equiv; simpl) + | _ => progress (subst; intuition) + | _ => reflexivity + | _ => rewrite op_proj + end. +Qed. + +Global Instance option_rect_Proper_nd {A T} + : Proper ((pointwise_relation _ eq) ==> eq ==> eq ==> eq) (@option_rect A (fun _ => T)). +Proof. + intros ?? H ??? [|]??; subst; simpl; congruence. +Qed. + +Global Instance option_rect_Proper_nd' {A T} + : Proper ((pointwise_relation _ eq) ==> eq ==> forall_relation (fun _ => eq)) (@option_rect A (fun _ => T)). +Proof. + intros ?? H ??? [|]; subst; simpl; congruence. +Qed. + +Hint Extern 1 (Proper _ (@option_rect ?A (fun _ => ?T))) => exact (@option_rect_Proper_nd' A T) : typeclass_instances. + +Lemma option_rect_option_map : forall {A B C} (f:A->B) some none v, + option_rect (fun _ => C) (fun x => some (f x)) none v = option_rect (fun _ => C) some none (option_map f v). +Proof. + destruct v; reflexivity. +Qed. + +Lemma option_rect_function {A B C S' N' v} f + : f (option_rect (fun _ : option A => option B) S' N' v) + = option_rect (fun _ : option A => C) (fun x => f (S' x)) (f N') v. +Proof. destruct v; reflexivity. Qed. +Local Ltac commute_option_rect_Let_In := (* pull let binders out side of option_rect pattern matching *) + idtac; + lazymatch goal with + | [ |- ?LHS = option_rect ?P ?S ?N (Let_In ?x ?f) ] + => (* we want to just do a [change] here, but unification is stupid, so we have to tell it what to unfold in what order *) + cut (LHS = Let_In x (fun y => option_rect P S N (f y))); cbv beta; + [ set_evars; + let H := fresh in + intro H; + rewrite H; + clear; + abstract (cbv [Let_In]; reflexivity) + | ] + end. + +(** TODO: possibly move me, remove local *) +Local Ltac replace_option_match_with_option_rect := + idtac; + lazymatch goal with + | [ |- _ = ?RHS :> ?T ] + => lazymatch RHS with + | match ?a with None => ?N | Some x => @?S x end + => replace RHS with (option_rect (fun _ => T) S N a) by (destruct a; reflexivity) + end + end. +Local Ltac simpl_option_rect := (* deal with [option_rect _ _ _ None] and [option_rect _ _ _ (Some _)] *) + repeat match goal with + | [ |- context[option_rect ?P ?S ?N None] ] + => change (option_rect P S N None) with N + | [ |- context[option_rect ?P ?S ?N (Some ?x) ] ] + => change (option_rect P S N (Some x)) with (S x); cbv beta + end. + +Definition COMPILETIME {T} (x:T) : T := x. + +Lemma N_to_nat_le_mono : forall a b, (a <= b)%N -> (N.to_nat a <= N.to_nat b)%nat. +Proof. + intros. + pose proof (Nomega.Nlt_out a (N.succ b)). + rewrite N2Nat.inj_succ, N.lt_succ_r, <-NPeano.Nat.lt_succ_r in *; auto. +Qed. +Lemma N_size_nat_le_mono : forall a b, (a <= b)%N -> (N.size_nat a <= N.size_nat b)%nat. +Proof. + intros. + destruct (N.eq_dec a 0), (N.eq_dec b 0); try abstract (subst;rewrite ?N.le_0_r in *;subst;simpl;omega). + rewrite !Nsize_nat_equiv, !N.size_log2 by assumption. + edestruct N.succ_le_mono; eauto using N_to_nat_le_mono, N.log2_le_mono. +Qed. + +Lemma Z_to_N_Z_of_nat : forall n, Z.to_N (Z.of_nat n) = N.of_nat n. +Proof. induction n; auto. Qed. + +Lemma Z_of_nat_nonzero : forall m, m <> 0 -> (0 < Z.of_nat m)%Z. +Proof. intros. destruct m; [congruence|reflexivity]. Qed. + +Section with_unqualified_modulo. +Import NPeano Nat. +Local Infix "mod" := modulo : nat_scope. +Lemma N_of_nat_modulo : forall n m, m <> 0 -> N.of_nat (n mod m)%nat = (N.of_nat n mod N.of_nat m)%N. +Proof. + intros. + apply Znat.N2Z.inj_iff. + rewrite !Znat.nat_N_Z. + rewrite Zdiv.mod_Zmod by auto. + apply Znat.Z2N.inj_iff. + { apply Z.mod_pos_bound. apply Z_of_nat_nonzero. assumption. } + { apply Znat.N2Z.is_nonneg. } + rewrite Znat.Z2N.inj_mod by (auto using Znat.Nat2Z.is_nonneg, Z_of_nat_nonzero). + rewrite !Z_to_N_Z_of_nat, !Znat.N2Z.id; reflexivity. +Qed. +End with_unqualified_modulo. + +Lemma encoding_canonical' {T} {B} {encoding:canonical encoding of T as B} : + forall a b, enc a = enc b -> a = b. +Proof. + intros. + pose proof (f_equal dec H). + pose proof encoding_valid. + pose proof encoding_canonical. + congruence. +Qed. + +Lemma compare_encodings {T} {B} {encoding:canonical encoding of T as B} + (B_eqb:B->B->bool) (B_eqb_iff : forall a b:B, (B_eqb a b = true) <-> a = b) + : forall a b : T, (a = b) <-> (B_eqb (enc a) (enc b) = true). +Proof. + intros. + split; intro H. + { rewrite B_eqb_iff; congruence. } + { apply B_eqb_iff in H; eauto using encoding_canonical'. } +Qed. + +Lemma eqb_eq_dec' {T} (eqb:T->T->bool) (eqb_iff:forall a b, eqb a b = true <-> a = b) : + forall a b, if eqb a b then a = b else a <> b. +Proof. + intros. + case_eq (eqb a b); intros. + { eapply eqb_iff; trivial. } + { specialize (eqb_iff a b). rewrite H in eqb_iff. intuition. } +Qed. + +Definition eqb_eq_dec {T} (eqb:T->T->bool) (eqb_iff:forall a b, eqb a b = true <-> a = b) : + forall a b : T, {a=b}+{a<>b}. +Proof. + intros. + pose proof (eqb_eq_dec' eqb eqb_iff a b). + destruct (eqb a b); eauto. +Qed. + +Definition eqb_eq_dec_and_output {T} (eqb:T->T->bool) (eqb_iff:forall a b, eqb a b = true <-> a = b) : + forall a b : T, {a = b /\ eqb a b = true}+{a<>b /\ eqb a b = false}. +Proof. + intros. + pose proof (eqb_eq_dec' eqb eqb_iff a b). + destruct (eqb a b); eauto. +Qed. + +Lemma eqb_compare_encodings {T} {B} {encoding:canonical encoding of T as B} + (T_eqb:T->T->bool) (T_eqb_iff : forall a b:T, (T_eqb a b = true) <-> a = b) + (B_eqb:B->B->bool) (B_eqb_iff : forall a b:B, (B_eqb a b = true) <-> a = b) + : forall a b : T, T_eqb a b = B_eqb (enc a) (enc b). +Proof. + intros; + destruct (eqb_eq_dec_and_output T_eqb T_eqb_iff a b); + destruct (eqb_eq_dec_and_output B_eqb B_eqb_iff (enc a) (enc b)); + intuition; + try find_copy_apply_lem_hyp B_eqb_iff; + try find_copy_apply_lem_hyp T_eqb_iff; + try congruence. + apply (compare_encodings B_eqb B_eqb_iff) in H2; congruence. +Qed. + +Lemma decode_failed_neq_encoding {T B} (encoding_T_B:canonical encoding of T as B) (X:B) + (dec_failed:dec X = None) (a:T) : X <> enc a. +Proof. pose proof encoding_valid. congruence. Qed. +Lemma compare_without_decoding {T B} (encoding_T_B:canonical encoding of T as B) + (T_eqb:T->T->bool) (T_eqb_iff:forall a b, T_eqb a b = true <-> a = b) + (B_eqb:B->B->bool) (B_eqb_iff:forall a b, B_eqb a b = true <-> a = b) + (P_:B) (Q:T) : + option_rect (fun _ : option T => bool) + (fun P : T => T_eqb P Q) + false + (dec P_) + = B_eqb P_ (enc Q). +Proof. + destruct (dec P_) eqn:Hdec; simpl option_rect. + { apply encoding_canonical in Hdec; subst; auto using eqb_compare_encodings. } + { pose proof encoding_canonical. + pose proof encoding_valid. + pose proof eqb_compare_encodings. + eapply decode_failed_neq_encoding in Hdec. + destruct (B_eqb P_ (enc Q)) eqn:Heq; [rewrite B_eqb_iff in Heq; eauto | trivial]. } +Qed. + +Lemma unfoldDiv : forall {m} (x y:F m), (x/y = x * inv y)%F. Proof. unfold div. congruence. Qed. + +Definition FieldToN {m} (x:F m) := Z.to_N (FieldToZ x). +Lemma FieldToN_correct {m} (x:F m) : FieldToN (m:=m) x = Z.to_N (FieldToZ x). reflexivity. Qed. + +Definition natToField {m} x : F m := ZToField (Z.of_nat x). +Definition FieldToNat {m} (x:F m) : nat := Z.to_nat (FieldToZ x). + +Section with_unqualified_modulo2. +Import NPeano Nat. +Local Infix "mod" := modulo : nat_scope. +Lemma FieldToNat_natToField {m} : m <> 0 -> forall x, x mod m = FieldToNat (natToField (m:=Z.of_nat m) x). + unfold natToField, FieldToNat; intros. + rewrite (FieldToZ_ZToField), <-mod_Zmod, Nat2Z.id; trivial. +Qed. +End with_unqualified_modulo2. + +Lemma F_eqb_iff {q} : forall x y : F q, F_eqb x y = true <-> x = y. +Proof. + split; eauto using F_eqb_eq, F_eqb_complete. +Qed. diff --git a/src/Experiments/GenericFieldPow.v b/src/Experiments/GenericFieldPow.v new file mode 100644 index 000000000..33d524567 --- /dev/null +++ b/src/Experiments/GenericFieldPow.v @@ -0,0 +1,337 @@ +Require Import Coq.setoid_ring.Cring. +Require Import Coq.omega.Omega. +Generalizable All Variables. + + +(*TODO: move *) +Lemma Z_pos_pred_0 p : Z.pos p - 1 = 0 -> p=1%positive. +Proof. destruct p; simpl in *; try discriminate; auto. Qed. + +Lemma Z_neg_succ_neg : forall a b, (Z.neg a + 1)%Z = Z.neg b -> a = Pos.succ b. +Admitted. + +Lemma Z_pos_pred_pos : forall a b, (Z.pos a - 1)%Z = Z.pos b -> a = Pos.succ b. +Admitted. + +Lemma Z_pred_neg p : (Z.neg p - 1)%Z = Z.neg (Pos.succ p). +Admitted. + +(* teach nsatz to deal with the definition of power we are use *) +Instance reify_pow_pos (R:Type) `{Ring R} +e1 lvar t1 n +`{Ring (T:=R)} +{_:reify e1 lvar t1} +: reify (PEpow e1 (N.pos n)) lvar (pow_pos t1 n)|1. + +Class Field_ops (F:Type) + `{Ring_ops F} + {inv:F->F} := {}. + +Class Division (A B C:Type) := division : A -> B -> C. + +Local Notation "_/_" := division. +Local Notation "n / d" := (division n d). + +Module F. + + Definition div `{Field_ops F} n d := n * (inv d). + Global Instance div_notation `{Field_ops F} : @Division F F F := div. + + Class Field {F inv} `{FieldCring:Cring (R:=F)} {Fo:Field_ops F (inv:=inv)} := + { + field_inv_comp : Proper (_==_ ==> _==_) inv; + field_inv_def : forall x, (x == 0 -> False) -> inv x * x == 1; + field_one_neq_zero : not (1 == 0) + }. + Global Existing Instance field_inv_comp. + + Definition powZ `{Field_ops F} (x:F) (n:Z) := + match n with + | Z0 => 1 + | Zpos p => pow_pos x p + | Zneg p => inv (pow_pos x p) + end. + Global Instance power_field `{Field_ops F} : Power | 5 := { power := powZ }. + + Section FieldProofs. + Context `{Field F}. + + Definition unfold_div (x y:F) : x/y = x * inv y := eq_refl. + + Global Instance Proper_div : + Proper (_==_ ==> _==_ ==> _==_) div. + Proof. + unfold div; repeat intro. + repeat match goal with + | [H: _ == _ |- _ ] => rewrite H; clear H + end; reflexivity. + Qed. + + Global Instance Proper_pow_pos : Proper (_==_==>eq==>_==_) pow_pos. + Proof. + cut (forall n (y x : F), x == y -> pow_pos x n == pow_pos y n); + [repeat intro; subst; eauto|]. + induction n; simpl; intros; trivial; + repeat eapply ring_mult_comp; eauto. + Qed. + + Global Instance Propper_powZ : Proper (_==_==>eq==>_==_) powZ. + Proof. + repeat intro; subst; unfold powZ. + match goal with |- context[match ?x with _ => _ end] => destruct x end; + repeat (eapply Proper_pow_pos || f_equiv; trivial). + Qed. + + Require Import Coq.setoid_ring.Field_theory Coq.setoid_ring.Field_tac. + Lemma field_theory_for_tactic : field_theory 0 1 _+_ _*_ _-_ -_ _/_ inv _==_. + Proof. + split; repeat constructor; repeat intro; gen_rewrite; try cring; + eauto using field_one_neq_zero, field_inv_def. Qed. + + Require Import Coq.setoid_ring.Ring_theory Coq.setoid_ring.NArithRing. + Lemma power_theory_for_tactic : power_theory 1 _*_ _==_ NtoZ power. + Proof. constructor; destruct n; reflexivity. Qed. + + Create HintDb field_nonzero discriminated. + Hint Resolve field_one_neq_zero : field_nonzero. + Ltac field_nonzero := repeat split; auto 3 with field_nonzero. + Ltac field_power_isconst t := Ncst t. + Add Field FieldProofsAddField : field_theory_for_tactic + (postprocess [field_nonzero], + power_tac power_theory_for_tactic [field_power_isconst]). + + Lemma div_mul_idemp_l : forall a b, (a==0 -> False) -> a*b/a == b. + Proof. intros. field. Qed. + + Context {eq_dec:forall x y : F, {x==y}+{x==y->False}}. + Lemma mul_zero_why : forall a b, a*b == 0 -> a == 0 \/ b == 0. + intros; destruct (eq_dec a 0); intuition. + assert (a * b / a == 0) by + (match goal with [H: _ == _ |- _ ] => rewrite H; field end). + rewrite div_mul_idemp_l in *; auto. + Qed. + + Require Import Coq.nsatz.Nsatz. + Global Instance Integral_domain_Field : Integral_domain (R:=F). + Proof. + constructor; intros; eauto using mul_zero_why, field_one_neq_zero. + Qed. + + Tactic Notation (at level 0) "field_simplify_eq" "in" hyp(H) := + let t := type of H in + generalize H; + field_lookup (PackField FIELD_SIMPL_EQ) [] t; + try (exact I); + try (idtac; []; clear H;intro H). + + Require Import Util.Tactics. + Inductive field_simplify_done {x y:F} : (x==y) -> Type := + Field_simplify_done : forall (H:x==y), field_simplify_done H. + Ltac field_nsatz := + repeat match goal with + [ H: (_:F) == _ |- _ ] => + match goal with + | [ Ha : field_simplify_done H |- _ ] => fail + | _ => idtac + end; + field_simplify_eq in H; + unique pose proof (Field_simplify_done H) + end; + repeat match goal with [ H: field_simplify_done _ |- _] => clear H end; + try field_simplify_eq; + try nsatz. + + Create HintDb field discriminated. + Hint Extern 5 (_ == _) => field_nsatz : field. + Hint Extern 5 (_ <-> _) => split. + + Lemma mul_inv_l : forall x, not (x == 0) -> inv x * x == 1. Proof. auto with field. Qed. + + Lemma mul_inv_r : forall x, not (x == 0) -> x * inv x == 1. Proof. auto with field. Qed. + + Lemma mul_cancel_r' (x y z:F) : not (z == 0) -> x * z == y * z -> x == y. + Proof. + intros. + assert (x * z * inv z == y * z * inv z) by + (match goal with [H: _ == _ |- _ ] => rewrite H; auto with field end). + assert (x * z * inv z == x * (z * inv z)) by auto with field. + assert (y * z * inv z == y * (z * inv z)) by auto with field. + rewrite mul_inv_r, @ring_mul_1_r in *; auto with field. + Qed. + + Lemma mul_cancel_r (x y z:F) : not (z == 0) -> (x * z == y * z <-> x == y). + Proof. intros;split;intros Heq; try eapply mul_cancel_r' in Heq; eauto with field. Qed. + + Lemma mul_cancel_l (x y z:F) : not (z == 0) -> (z * x == z * y <-> x == y). + Proof. intros;split;intros; try eapply mul_cancel_r; eauto with field. Qed. + + Lemma mul_cancel_r_eq : forall x z:F, not(z==0) -> (x*z == z <-> x == 1). + Proof. + intros;split;intros Heq; [|nsatz]. + pose proof ring_mul_1_l z as Hz; rewrite <- Hz in Heq at 2; rewrite mul_cancel_r in Heq; eauto. + Qed. + + Lemma mul_cancel_l_eq : forall x z:F, not(z==0) -> (z*x == z <-> x == 1). + Proof. intros;split;intros Heq; try eapply mul_cancel_r_eq; eauto with field. Qed. + + Lemma inv_unique (a:F) : forall x y, x * a == 1 -> y * a == 1 -> x == y. Proof. auto with field. Qed. + + Lemma mul_nonzero_nonzero (a b:F) : not (a == 0) -> not (b == 0) -> not (a*b == 0). + Proof. intros; intro Hab. destruct (mul_zero_why _ _ Hab); auto. Qed. + Hint Resolve mul_nonzero_nonzero : field_nonzero. + + Lemma inv_nonzero (x:F) : not(x == 0) -> not(inv x==0). + Proof. + intros Hx Hi. + assert (Hc:not (inv x*x==0)) by (rewrite field_inv_def; eauto with field_nonzero); contradict Hc. + ring [Hi]. + Qed. + Hint Resolve inv_nonzero : field_nonzero. + + Lemma div_nonzero (x y:F) : not(x==0) -> not(y==0) -> not(x/y==0). + Proof. + unfold division, div_notation, div; auto with field_nonzero. + Qed. + Hint Resolve div_nonzero : field_nonzero. + + Lemma pow_pos_nonzero (x:F) p : not(x==0) -> not(Ncring.pow_pos x p == 0). + Proof. + intros; induction p using Pos.peano_ind; try assumption; []. + rewrite Ncring.pow_pos_succ; eauto using mul_nonzero_nonzero. + Qed. + Hint Resolve pow_pos_nonzero : field_nonzero. + + Lemma sub_diag_iff (x y:F) : x - y == 0 <-> x == y. Proof. auto with field. Qed. + + Lemma mul_same (x:F) : x*x == x^2%Z. Proof. auto with field. Qed. + + Lemma inv_mul (x y:F) : not(x==0) -> not (y==0) -> inv (x*y) == inv x * inv y. + Proof. intros;field;intuition. Qed. + + Lemma pow_0_r (x:F) : x^0 == 1. Proof. auto with field. Qed. + Lemma pow_1_r : forall x:F, x^1%Z == x. Proof. auto with field. Qed. + Lemma pow_2_r : forall x:F, x^2%Z == x*x. Proof. auto with field. Qed. + Lemma pow_3_r : forall x:F, x^3%Z == x*x*x. Proof. auto with field. Qed. + + Lemma pow_succ_r (x:F) (n:Z) : not (x==0)\/(n>=0)%Z -> x^(n+1) == x * x^n. + Proof. + intros Hnz; unfold power, powZ, power_field, powZ; destruct n eqn:HSn. + - simpl; ring. + - setoid_rewrite <-Pos2Z.inj_succ; rewrite Ncring.pow_pos_succ; ring. + - destruct (Z.succ (Z.neg p)) eqn:Hn. + + assert (p=1%positive) by (destruct p; simpl in *; try discriminate; auto). + subst; simpl in *; field. destruct Hnz; auto with field_nonzero. + + destruct p, p0; discriminate. + + setoid_rewrite Hn. + apply Z_neg_succ_neg in Hn; subst. + rewrite Ncring.pow_pos_succ; field; + destruct Hnz; auto with field_nonzero. + Qed. + + Lemma pow_pred_r (x:F) (n:Z) : not (x==0) -> x^(n-1) == x^n/x. + Proof. + intros; unfold power, powZ, power_field, powZ; destruct n eqn:HSn. + - simpl. rewrite unfold_div; field. + - destruct (Z.pos p - 1) eqn:Hn. + + apply Z_pos_pred_0 in Hn; subst; simpl; field. + + apply Z_pos_pred_pos in Hn; subst. + rewrite Ncring.pow_pos_succ; field; auto with field_nonzero. + + destruct p; discriminate. + - rewrite Z_pred_neg, Ncring.pow_pos_succ; field; auto with field_nonzero. + Qed. + + Local Ltac pow_peano := + repeat (setoid_rewrite pow_0_r + || setoid_rewrite pow_succ_r + || setoid_rewrite pow_pred_r). + + Lemma pow_mul (x y:F) : forall (n:Z), not(x==0)/\not(y==0)\/(n>=0)%Z -> (x * y)^n == x^n * y^n. + Proof. + match goal with |- forall n, @?P n => eapply (Z.order_induction'_0 P) end. + { repeat intro. subst. reflexivity. } + - intros; cbv [power power_field powZ]; ring. + - intros n Hn IH Hxy. + repeat setoid_rewrite pow_succ_r; try rewrite IH; try ring; (right; omega). + - intros n Hn IH Hxy. destruct Hxy as [[]|?]; try omega; []. + repeat setoid_rewrite pow_pred_r; try rewrite IH; try field; auto with field_nonzero. + Qed. + + Lemma pow_nonzero (x:F) : forall (n:Z), not(x==0) -> not(x^n==0). + match goal with |- forall n, @?P n => eapply (Z.order_induction'_0 P) end; intros; pow_peano; + eauto with field_nonzero. + { repeat intro. subst. reflexivity. } + Qed. + Hint Resolve pow_nonzero : field_nonzero. + + Lemma pow_inv (x:F) : forall (n:Z), not(x==0) -> inv x^n == inv (x^n). + match goal with |- forall n, @?P n => eapply (Z.order_induction'_0 P) end. + { repeat intro. subst. reflexivity. } + - intros; cbv [power power_field powZ]. field; eauto with field_nonzero. + - intros n Hn IH Hx. + repeat setoid_rewrite pow_succ_r; try rewrite IH; try field; eauto with field_nonzero. + - intros n Hn IH Hx. + repeat setoid_rewrite pow_pred_r; try rewrite IH; try field; eauto 3 with field_nonzero. + Qed. + + Lemma pow_0_l : forall n, (n>0)%Z -> (0:F)^n==0. + match goal with |- forall n, @?P n => eapply (Z.order_induction'_0 P) end; intros; try omega. + { repeat intro. subst. reflexivity. } + setoid_rewrite pow_succ_r; [auto with field|right;omega]. + Qed. + + Lemma pow_div (x y:F) (n:Z) : not (y==0) -> not(x==0)\/(n >= 0)%Z -> (x/y)^n == x^n/y^n. + Proof. + intros Hy Hxn. unfold division, div_notation, div. + rewrite pow_mul, pow_inv; try field; destruct Hxn; auto with field_nonzero. + Qed. + + Hint Extern 3 (_ >= _)%Z => omega : field_nonzero. + Lemma issquare_mul (x y z:F) : not (y == 0) -> x^2%Z == z * y^2%Z -> (x/y)^2%Z == z. + Proof. intros. rewrite pow_div by (auto with field_nonzero); auto with field. Qed. + + Lemma issquare_mul_sub (x y z:F) : 0 == z*y^2%Z - x^2%Z -> (x/y)^2%Z == z \/ x == 0. + Proof. destruct (eq_dec y 0); [right|left]; auto using issquare_mul with field. Qed. + + Lemma div_mul : forall x y z : F, not(y==0) -> (z == (x / y) <-> z * y == x). + Proof. auto with field. Qed. + + Lemma div_1_r : forall x : F, x/1 == x. + Proof. eauto with field field_nonzero. Qed. + + Lemma div_1_l : forall x : F, not(x==0) -> 1/x == inv x. + Proof. auto with field. Qed. + + Lemma div_opp_l : forall x y, not (y==0) -> (-_ x) / y == -_ (x / y). + Proof. auto with field. Qed. + + Lemma div_opp_r : forall x y, not (y==0) -> x / (-_ y) == -_ (x / y). + Proof. auto with field. Qed. + + Lemma eq_opp_zero : forall x : F, (~ 1 + 1 == (0:F)) -> (x == -_ x <-> x == 0). + Proof. auto with field. Qed. + + Lemma add_cancel_l : forall x y z:F, z+x == z+y <-> x == y. + Proof. auto with field. Qed. + + Lemma add_cancel_r : forall x y z:F, x+z == y+z <-> x == y. + Proof. auto with field. Qed. + + Lemma add_cancel_r_eq : forall x z:F, x+z == z <-> x == 0. + Proof. auto with field. Qed. + + Lemma add_cancel_l_eq : forall x z:F, z+x == z <-> x == 0. + Proof. auto with field. Qed. + + Lemma sqrt_solutions : forall x y:F, y ^ 2%Z == x ^ 2%Z -> y == x \/ y == -_ x. + Proof. + intros ? ? squares_eq. + remember (y - x) as z eqn:Heqz. + assert (y == x + z) as Heqy by (subst; ring); rewrite Heqy in *; clear Heqy Heqz. + assert (Hw:(x + z)^2%Z == z * (x + (x + z)) + x^2%Z) + by (auto with field); rewrite Hw in squares_eq; clear Hw. + rewrite add_cancel_r_eq in squares_eq. + apply mul_zero_why in squares_eq; destruct squares_eq; auto with field. + Qed. + + End FieldProofs. +End F. diff --git a/src/Spec/Ed25519.v b/src/Experiments/SpecEd25519.v index 4876bb8d1..4e30313d9 100644 --- a/src/Spec/Ed25519.v +++ b/src/Experiments/SpecEd25519.v @@ -1,6 +1,6 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. Require Import Coq.Numbers.Natural.Peano.NPeano Coq.NArith.NArith. -Require Import Crypto.Spec.PointEncoding Crypto.Spec.ModularWordEncoding. +Require Import Crypto.Spec.ModularWordEncoding. Require Import Crypto.Encoding.ModularWordEncodingTheorems. Require Import Crypto.Spec.EdDSA. Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. @@ -13,7 +13,7 @@ Require Import Coq.omega.Omega. Local Open Scope nat_scope. Definition q : Z := (2 ^ 255 - 19)%Z. -Lemma prime_q : prime q. Admitted. +Global Instance prime_q : prime q. Admitted. Lemma two_lt_q : (2 < q)%Z. reflexivity. Qed. Definition a : F q := opp 1%F. @@ -65,24 +65,23 @@ Lemma nonsquare_d : forall x, (x^2 <> d)%F. exact eq_refl. Qed. (* 10s *) -Instance curve25519params : TwistedEdwardsParams := { - q := q; - prime_q := prime_q; - two_lt_q := two_lt_q; - a := a; - nonzero_a := nonzero_a; - square_a := square_a; - d := d; - nonsquare_d := nonsquare_d -}. +Instance curve25519params : @E.twisted_edwards_params (F q) eq (ZToField 0) (ZToField 1) add mul a d := + { + nonzero_a := nonzero_a + (* TODO:port + char_gt_2 : ~ Feq (Fadd Fone Fone) Fzero; + nonzero_a : ~ Feq a Fzero; + nonsquare_d : forall x : F, ~ Feq (Fmul x x) d } + *) + }. +Admitted. Lemma two_power_nat_Z2Nat : forall n, Z.to_nat (two_power_nat n) = 2 ^ n. Admitted. Definition b := 256. -Lemma b_valid : (2 ^ (b - 1) > Z.to_nat CompleteEdwardsCurve.q)%nat. +Lemma b_valid : (2 ^ (b - 1) > Z.to_nat q)%nat. Proof. - replace (CompleteEdwardsCurve.q) with q by reflexivity. unfold q, gt. replace (2 ^ (b - 1)) with (Z.to_nat (2 ^ (Z.of_nat (b - 1)))) by (rewrite <- two_power_nat_equiv; apply two_power_nat_Z2Nat). @@ -143,37 +142,24 @@ Proof. reflexivity. Qed. -Definition PointEncoding : canonical encoding of E.point as (word b) := - (@point_encoding curve25519params (b - 1) q_5mod8 sqrt_minus1_valid FqEncoding sign_bit - (@sign_bit_zero _ prime_q two_lt_q _ b_valid) (@sign_bit_opp _ prime_q two_lt_q _ b_valid)). - -Definition H : forall n : nat, word n -> word (b + b). Admitted. -Definition B : E.point. Admitted. (* TODO: B = decodePoint (y=4/5, x="positive") *) -Definition B_nonzero : B <> E.zero. Admitted. -Definition l_order_B : (l * B)%E = E.zero. Admitted. - -Local Instance ed25519params : EdDSAParams := { - E := curve25519params; - b := b; - H := H; - c := c; - n := n; - B := B; - l := l; - FqEncoding := FqEncoding; - FlEncoding := FlEncoding; - PointEncoding := PointEncoding; - - b_valid := b_valid; - c_valid := c_valid; - n_ge_c := n_ge_c; - n_le_b := n_le_b; - B_not_identity := B_nonzero; - l_prime := prime_l; - l_odd := l_odd; - l_order_B := l_order_B -}. - -Definition ed25519_verify - : forall (pubkey:word b) (len:nat) (msg:word len) (sig:word (b+b)), bool - := @verify ed25519params.
\ No newline at end of file +Local Notation point := (@E.point (F q) eq (ZToField 1) add mul a d). +Local Notation zero := (E.zero(H:=field_modulo)). +Local Notation add := (E.add(H0:=curve25519params)). +Local Infix "*" := (E.mul(H0:=curve25519params)). +Axiom H : forall n : nat, word n -> word (b + b). +Axiom B : point. (* TODO: B = decodePoint (y=4/5, x="positive") *) +Axiom B_nonzero : B <> zero. +Axiom l_order_B : l * B = zero. +Axiom point_encoding : canonical encoding of point as word b. +Axiom scalar_encoding : canonical encoding of {n : nat | n < l} as word b. + +Global Instance Ed25519 : @EdDSA point E.eq add zero E.opp E.mul b H c n l B point_encoding scalar_encoding := + { + EdDSA_c_valid := c_valid; + EdDSA_n_ge_c := n_ge_c; + EdDSA_n_le_b := n_le_b; + EdDSA_B_not_identity := B_nonzero; + EdDSA_l_prime := prime_l; + EdDSA_l_odd := l_odd; + EdDSA_l_order_B := l_order_B + }.
\ No newline at end of file diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v index 2e65df9bd..9ed7d065e 100644 --- a/src/ModularArithmetic/ExtendedBaseVector.v +++ b/src/ModularArithmetic/ExtendedBaseVector.v @@ -22,19 +22,19 @@ Section ExtendedBaseVector. * * (x \dot base) * (y \dot base) = (z \dot ext_base) * - * Then we can separate z into its first and second halves: + * Then we can separate z into its first and second halves: * * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) * * Now, if we want to reduce the product modulo 2 ^ k - c: - * + * * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) * * This sum may be short enough to express using base; if not, we can reduce again. *) Definition ext_base := base ++ (map (Z.mul (2^k)) base). - + Lemma ext_base_positive : forall b, In b ext_base -> b > 0. Proof. unfold ext_base. intros b In_b_base. @@ -76,14 +76,14 @@ Section ExtendedBaseVector. intuition. Qed. - Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> + Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> nth_default 0 (map (Z.mul (2 ^ k)) base) n = (2 ^ k) * (nth_default 0 base n). Proof. intros. erewrite map_nth_default; auto. Qed. - + Lemma base_good_over_boundary : forall (i : nat) (l : (i < length base)%nat) diff --git a/src/ModularArithmetic/FField.v b/src/ModularArithmetic/FField.v deleted file mode 100644 index 4f2b623e0..000000000 --- a/src/ModularArithmetic/FField.v +++ /dev/null @@ -1,63 +0,0 @@ -Require Export Crypto.Spec.ModularArithmetic. -Require Export Coq.setoid_ring.Field. - -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. - -Local Open Scope F_scope. - -Definition OpaqueF := F. -Definition OpaqueZmodulo := BinInt.Z.modulo. -Definition Opaqueadd {p} : OpaqueF p -> OpaqueF p -> OpaqueF p := @add p. -Definition Opaquemul {p} : OpaqueF p -> OpaqueF p -> OpaqueF p := @mul p. -Definition Opaquesub {p} : OpaqueF p -> OpaqueF p -> OpaqueF p := @sub p. -Definition Opaquediv {p} : OpaqueF p -> OpaqueF p -> OpaqueF p := @div p. -Definition Opaqueopp {p} : OpaqueF p -> OpaqueF p := @opp p. -Definition Opaqueinv {p} : OpaqueF p -> OpaqueF p := @inv p. -Definition OpaqueZToField {p} : BinInt.Z -> OpaqueF p := @ZToField p. -Definition Opaqueadd_correct {p} : @Opaqueadd p = @add p := eq_refl. -Definition Opaquesub_correct {p} : @Opaquesub p = @sub p := eq_refl. -Definition Opaquemul_correct {p} : @Opaquemul p = @mul p := eq_refl. -Definition Opaquediv_correct {p} : @Opaquediv p = @div p := eq_refl. -Global Opaque F OpaqueZmodulo Opaqueadd Opaquemul Opaquesub Opaquediv Opaqueopp Opaqueinv OpaqueZToField. - -Definition OpaqueFieldTheory p {prime_p} : @field_theory (OpaqueF p) (OpaqueZToField 0%Z) (OpaqueZToField 1%Z) Opaqueadd Opaquemul Opaquesub Opaqueopp Opaquediv Opaqueinv eq := Eval hnf in @Ffield_theory p prime_p. - -Ltac FIELD_SIMPL_idtac FLD lH rl := - let Simpl := idtac (* (protect_fv "field") *) in - let lemma := get_SimplifyEqLemma FLD in - get_FldPre FLD (); - Field_Scheme Simpl Ring_tac.ring_subst_niter lemma FLD lH; - get_FldPost FLD (). -Ltac field_simplify_eq_idtac := let G := Get_goal in field_lookup (PackField FIELD_SIMPL_idtac) [] G. - -Ltac F_to_Opaque := - change F with OpaqueF in *; - change BinInt.Z.modulo with OpaqueZmodulo in *; - change @add with @Opaqueadd in *; - change @mul with @Opaquemul in *; - change @sub with @Opaquesub in *; - change @div with @Opaquediv in *; - change @opp with @Opaqueopp in *; - change @inv with @Opaqueinv in *; - change @ZToField with @OpaqueZToField in *. - -Ltac F_from_Opaque p := - change OpaqueF with F in *; - change (@sig BinNums.Z (fun z : BinNums.Z => @eq BinNums.Z z (BinInt.Z.modulo z p))) with (F p) in *; - change OpaqueZmodulo with BinInt.Z.modulo in *; - change @Opaqueopp with @opp in *; - change @Opaqueinv with @inv in *; - change @OpaqueZToField with @ZToField in *; - rewrite ?@Opaqueadd_correct, ?@Opaquesub_correct, ?@Opaquemul_correct, ?@Opaquediv_correct in *. - -Ltac F_field_simplify_eq := - lazymatch goal with |- @eq (F ?p) _ _ => - F_to_Opaque; - field_simplify_eq_idtac; - compute; - F_from_Opaque p - end. - -Ltac F_field := F_field_simplify_eq; [ring|..]. - -Ltac notConstant t := constr:NotConstant. diff --git a/src/ModularArithmetic/FNsatz.v b/src/ModularArithmetic/FNsatz.v deleted file mode 100644 index 221b8d799..000000000 --- a/src/ModularArithmetic/FNsatz.v +++ /dev/null @@ -1,40 +0,0 @@ -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Export Crypto.ModularArithmetic.FField. -Require Import Coq.nsatz.Nsatz. - -Ltac FqAsIntegralDomain := - lazymatch goal with [H:Znumtheory.prime ?q |- _ ] => - pose proof (_:@Integral_domain.Integral_domain (F q) _ _ _ _ _ _ _ _ _ _) as FqIntegralDomain; - lazymatch type of FqIntegralDomain with @Integral_domain.Integral_domain _ _ _ _ _ _ _ _ ?ringOps ?ringOk ?ringComm => - generalize dependent ringComm; intro Cring; - generalize dependent ringOk; intro Ring; - generalize dependent ringOps; intro RingOps; - lazymatch type of RingOps with @Ncring.Ring_ops ?t ?z ?o ?a ?m ?s ?p ?e => - generalize dependent e; intro equiv; - generalize dependent p; intro opp; - generalize dependent s; intro sub; - generalize dependent m; intro mul; - generalize dependent a; intro add; - generalize dependent o; intro one; - generalize dependent z; intro zero; - generalize dependent t; intro R - end - end; intros; - clear q H - end. - -Ltac fixed_equality_to_goal H x y := generalize (psos_r1 x y H); clear H. -Ltac fixed_equalities_to_goal := - match goal with - | H:?x == ?y |- _ => fixed_equality_to_goal H x y - | H:_ ?x ?y |- _ => fixed_equality_to_goal H x y - | H:_ _ ?x ?y |- _ => fixed_equality_to_goal H x y - | H:_ _ _ ?x ?y |- _ => fixed_equality_to_goal H x y - | H:_ _ _ _ ?x ?y |- _ => fixed_equality_to_goal H x y - end. -Ltac fixed_nsatz := - intros; try apply psos_r1b; - lazymatch goal with - | |- @equality ?T _ _ _ => repeat fixed_equalities_to_goal; nsatz_generic 6%N 1%Z (@nil T) (@nil T) - end. -Ltac F_nsatz := abstract (FqAsIntegralDomain; fixed_nsatz). diff --git a/src/ModularArithmetic/ModularArithmeticTheorems.v b/src/ModularArithmetic/ModularArithmeticTheorems.v index dabfcf883..8e526745c 100644 --- a/src/ModularArithmetic/ModularArithmeticTheorems.v +++ b/src/ModularArithmetic/ModularArithmeticTheorems.v @@ -1,3 +1,4 @@ +Require Import Coq.omega.Omega. Require Import Crypto.Spec.ModularArithmetic. Require Import Crypto.ModularArithmetic.Pre. @@ -10,7 +11,7 @@ Require Export Crypto.Util.IterAssocOp. Section ModularArithmeticPreliminaries. Context {m:Z}. - Local Coercion ZToFm := ZToField : BinNums.Z -> F m. Hint Unfold ZToFm. + Let ZToFm := ZToField : BinNums.Z -> F m. Hint Unfold ZToFm. Local Coercion ZToFm : Z >-> F. Theorem F_eq: forall (x y : F m), x = y <-> FieldToZ x = FieldToZ y. Proof. @@ -19,20 +20,20 @@ Section ModularArithmeticPreliminaries. f_equal. eapply UIP_dec, Z.eq_dec. Qed. - + Lemma F_opp_spec : forall (a:F m), add a (opp a) = 0. intros a. pose (@opp_with_spec m) as H. - change (@opp m) with (proj1_sig H). + change (@opp m) with (proj1_sig H). destruct H; eauto. Qed. - + Lemma F_pow_spec : forall (a:F m), pow a 0%N = 1%F /\ forall x, pow a (1 + x)%N = mul a (pow a x). Proof. intros a. pose (@pow_with_spec m) as H. - change (@pow m) with (proj1_sig H). + change (@pow m) with (proj1_sig H). destruct H; eauto. Qed. End ModularArithmeticPreliminaries. @@ -81,7 +82,7 @@ Ltac eq_remove_proofs := lazymatch goal with assert (Q := F_eq a b); simpl in *; apply Q; clear Q end. - + Ltac Fdefn := intros; rewrite ?F_opp_spec; @@ -149,6 +150,15 @@ Section FandZ. intuition; find_inversion; rewrite ?Z.mod_0_l, ?Z.mod_small in *; intuition. Qed. + Require Crypto.Algebra. + Global Instance commutative_ring_modulo : @Algebra.commutative_ring (F m) Logic.eq (ZToField 0) (ZToField 1) opp add sub mul. + Proof. + repeat split; Fdefn; try apply F_eq_dec. + { rewrite Z.add_0_r. auto. } + { rewrite <-Z.add_sub_swap, <-Z.add_sub_assoc, Z.sub_diag, Z.add_0_r. apply Z_mod_same_full. } + { rewrite Z.mul_1_r. auto. } + Qed. + Lemma ZToField_0 : @ZToField m 0 = 0. Proof. Fdefn. @@ -217,7 +227,7 @@ Section FandZ. rewrite ?N2Nat.inj_succ, ?pow_0, <-?N.add_1_l, ?pow_succ; simpl; congruence. Qed. - + Lemma mod_plus_zero_subproof a b : 0 mod m = (a + b) mod m -> b mod m = (- a) mod m. Proof. @@ -244,7 +254,7 @@ Section FandZ. intros. pose proof (FieldToZ_opp' x) as H; rewrite mod_FieldToZ in H; trivial. Qed. - + Lemma sub_intersperse_modulus : forall x y, ((x - y) mod m = (x + (m - y)) mod m)%Z. Proof. intros. @@ -282,7 +292,7 @@ Section FandZ. Proof. Fdefn. Qed. - + (* Compatibility between inject and pow *) Lemma ZToField_pow : forall x n, @ZToField m x ^ n = ZToField (x ^ (Z.of_N n) mod m). @@ -317,7 +327,7 @@ End FandZ. Section RingModuloPre. Context {m:Z}. - Local Coercion ZToFm := ZToField : Z -> F m. Hint Unfold ZToFm. + Let ZToFm := ZToField : Z -> F m. Hint Unfold ZToFm. Local Coercion ZToFm : Z >-> F. (* Substitution to prove all Compats *) Ltac compat := repeat intro; subst; trivial. @@ -362,12 +372,12 @@ Section RingModuloPre. Proof. Fdefn; rewrite Z.mul_1_r; auto. Qed. - + Lemma F_mul_assoc: forall x y z : F m, x * (y * z) = x * y * z. Proof. Fdefn. - Qed. + Qed. Lemma F_pow_pow_N (x : F m) : forall (n : N), (x ^ id n)%F = pow_N 1%F mul x n. Proof. @@ -390,7 +400,7 @@ Section RingModuloPre. Qed. (***** Division Theory *****) - Definition Fquotrem(a b: F m): F m * F m := + Definition Fquotrem(a b: F m): F m * F m := let '(q, r) := (Z.quotrem a b) in (q : F m, r : F m). Lemma Fdiv_theory : div_theory eq (@add m) (@mul m) (@id _) Fquotrem. Proof. @@ -434,7 +444,7 @@ Section RingModuloPre. Qed. (* Redefine our division theory under the ring morphism *) - Lemma Fmorph_div_theory: + Lemma Fmorph_div_theory: div_theory eq Zplus Zmult (@ZToField m) Z.quotrem. Proof. constructor; intros; intuition. @@ -451,7 +461,7 @@ Section RingModuloPre. Fdefn. Qed. End RingModuloPre. - + Ltac Fconstant t := match t with @ZToField _ ?x => x | _ => NotConstant end. Ltac Fexp_tac t := Ncst t. Ltac Fpreprocess := rewrite <-?ZToField_0, ?ZToField_1. @@ -470,49 +480,49 @@ Module RingModulo (Export M : Modulus). Definition ring_morph_modulo := @Fring_morph modulus. Definition morph_div_theory_modulo := @Fmorph_div_theory modulus. Definition power_theory_modulo := @Fpower_theory modulus. - + Add Ring GFring_Z : ring_theory_modulo (morphism ring_morph_modulo, constants [Fconstant], div morph_div_theory_modulo, - power_tac power_theory_modulo [Fexp_tac]). + power_tac power_theory_modulo [Fexp_tac]). End RingModulo. Section VariousModulo. Context {m:Z}. - + Add Ring GFring_Z : (@Fring_theory m) (morphism (@Fring_morph m), constants [Fconstant], div (@Fmorph_div_theory m), - power_tac (@Fpower_theory m) [Fexp_tac]). + power_tac (@Fpower_theory m) [Fexp_tac]). Lemma F_mul_0_l : forall x : F m, 0 * x = 0. Proof. intros; ring. Qed. - + Lemma F_mul_0_r : forall x : F m, x * 0 = 0. Proof. intros; ring. Qed. - + Lemma F_mul_nonzero_l : forall a b : F m, a*b <> 0 -> a <> 0. intros; intuition; subst. assert (0 * b = 0) by ring; intuition. Qed. - + Lemma F_mul_nonzero_r : forall a b : F m, a*b <> 0 -> b <> 0. intros; intuition; subst. assert (a * 0 = 0) by ring; intuition. Qed. - + Lemma F_pow_distr_mul : forall (x y:F m) z, (0 <= z)%N -> (x ^ z) * (y ^ z) = (x * y) ^ z. Proof. intros. replace z with (Z.to_N (Z.of_N z)) by apply N2Z.id. - apply natlike_ind with (x := Z.of_N z); simpl; [ ring | | + apply natlike_ind with (x := Z.of_N z); simpl; [ ring | | replace 0%Z with (Z.of_N 0%N) by auto; apply N2Z.inj_le; auto]. intros z' z'_nonneg IHz'. rewrite Z2N.inj_succ by auto. @@ -521,7 +531,7 @@ Section VariousModulo. rewrite <- IHz'. ring. Qed. - + Lemma F_opp_0 : opp (0 : F m) = 0%F. Proof. intros; ring. @@ -563,7 +573,7 @@ Section VariousModulo. Proof. intros; ring. Qed. - + Lemma F_add_reg_r : forall x y z : F m, y + x = z + x -> y = z. Proof. intros ? ? ? A. @@ -653,7 +663,7 @@ Section VariousModulo. Proof. split; intro A; [ replace w with (w - x + x) by ring - | replace w with (w + z - z) by ring ]; rewrite A; ring. + | replace w with (w + z - z) by ring ]; rewrite A; ring. Qed. Definition isSquare (x : F m) := exists sqrt_x, sqrt_x ^ 2 = x. diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 558b9a5a2..ca8c19d18 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -11,14 +11,14 @@ Local Open Scope Z_scope. Section PseudoMersenneBase. Context `{prm :PseudoMersenneBaseParams}. - + Definition decode (us : digits) : F modulus := ZToField (BaseSystem.decode base us). - - Definition rep (us : digits) (x : F modulus) := (length us <= length base)%nat /\ decode us = x. + + Definition rep (us : digits) (x : F modulus) := (length us = length base)%nat /\ decode us = x. Local Notation "u '~=' x" := (rep u x) (at level 70). Local Hint Unfold rep. - Definition encode (x : F modulus) := encode x. + Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat. (* Converts from length of extended base to length of base by reduction modulo M.*) Definition reduce (us : digits) : digits := @@ -35,13 +35,13 @@ Section PseudoMersenneBase. End PseudoMersenneBase. Section CarryBasePow2. - Context `{prm :PseudoMersenneBaseParams}. + Context `{prm :PseudoMersenneBaseParams}. Definition log_cap i := nth_default 0 limb_widths i. Definition add_to_nth n (x:Z) xs := set_nth n (x + nth_default 0 xs n) xs. - + Definition pow2_mod n i := Z.land n (Z.ones i). Definition carry_simple i := fun us => @@ -54,64 +54,68 @@ Section CarryBasePow2. let us' := set_nth i (pow2_mod di (log_cap i)) us in add_to_nth 0 (c * (Z.shiftr di (log_cap i))) us'. - Definition carry i : digits -> digits := + Definition carry i : digits -> digits := if eq_nat_dec i (pred (length base)) then carry_and_reduce i else carry_simple i. Definition carry_sequence is us := fold_right carry us is. -End CarryBasePow2. - -Section Canonicalization. - Context `{prm :PseudoMersenneBaseParams}. - Fixpoint make_chain i := match i with | O => nil | S i' => i' :: make_chain i' end. - (* compute at compile time *) Definition full_carry_chain := make_chain (length limb_widths). - (* compute at compile time *) - Definition max_ones := Z.ones - ((fix loop current_max lw := - match lw with - | nil => current_max - | w :: lw' => loop (Z.max w current_max) lw' - end - ) 0 limb_widths). - - (* compute at compile time? *) Definition carry_full := carry_sequence full_carry_chain. + Definition carry_mul us vs := carry_full (mul us vs). + +End CarryBasePow2. + +Section Canonicalization. + Context `{prm :PseudoMersenneBaseParams}. + + (* compute at compile time *) + Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths). + Definition max_bound i := Z.ones (log_cap i). - Definition isFull us := - (fix loop full i := - match i with - | O => full (* don't test 0; the test for 0 is the initial value of [full]. *) - | S i' => loop (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' - end - ) (Z.ltb (max_bound 0 - (c + 1)) (nth_default 0 us 0)) (length us - 1)%nat. + Fixpoint isFull' us full i := + match i with + | O => andb (Z.ltb (max_bound 0 - c) (nth_default 0 us 0)) full + | S i' => isFull' us (andb (Z.eqb (max_bound i) (nth_default 0 us i)) full) i' + end. + + Definition isFull us := isFull' us true (length base - 1)%nat. - Fixpoint range' n m := - match m with - | O => nil - | S m' => (n - m)%nat :: range' n m' + Fixpoint modulus_digits' i := + match i with + | O => max_bound i - c + 1 :: nil + | S i' => modulus_digits' i' ++ max_bound i :: nil end. - Definition range n := range' n n. + (* compute at compile time *) + Definition modulus_digits := modulus_digits' (length base - 1). + + Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := + match la with + | nil => nil + | a :: la' => match lb with + | nil => nil + | b :: lb' => f a b :: map2 f la' lb' + end + end. + + Definition and_term us := if isFull us then max_ones else 0. - Definition land_max_bound and_term i := Z.land and_term (max_bound i). - Definition freeze us := let us' := carry_full (carry_full (carry_full us)) in - let and_term := if isFull us' then max_ones else 0 in + let and_term := and_term us' in (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. Otherwise, it's all zeroes, and the subtractions do nothing. *) - map (fun x => (snd x) - land_max_bound and_term (fst x)) (combine (range (length us')) us'). - + map2 (fun x y => x - y) us' (map (Z.land and_term) modulus_digits). + End Canonicalization. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 981680b4a..116fe10e5 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -27,7 +27,12 @@ Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by. Definition nth_default_opt {A} := Eval compute in @nth_default A. Definition set_nth_opt {A} := Eval compute in @set_nth A. Definition map_opt {A B} := Eval compute in @map A B. -Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths. +Definition full_carry_chain_opt := Eval compute in @full_carry_chain. +Definition length_opt := Eval compute in length. +Definition base_opt := Eval compute in @base. +Definition max_ones_opt := Eval compute in @max_ones. +Definition max_bound_opt := Eval compute in @max_bound. +Definition minus_opt := Eval compute in minus. Definition Let_In {A P} (x : A) (f : forall y : A, P y) := let y := x in f y. @@ -71,18 +76,22 @@ Ltac construct_params prime_modulus len k := | abstract apply prime_modulus | abstract brute_force_indices lw]. -Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := +Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := match limb_widths with | nil => nil | x :: tail => 2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail end. +Ltac compute_preconditions := + cbv; intros; repeat match goal with H : _ \/ _ |- _ => + destruct H; subst; [ congruence | ] end; (congruence || omega). + Ltac subst_precondition := match goal with | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H end. -Ltac kill_precondition H := +Ltac kill_precondition H := forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; subst_precondition. @@ -95,8 +104,7 @@ Ltac compute_formula := let p := fresh "p" in set (p := P) in H at 1; change P with p at 1; let r := fresh "r" in set (r := result) in H |- *; cbv -[m p r PseudoMersenneBaseRep.rep] in H; - repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H; - exact H + repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H end. Section Carries. @@ -113,8 +121,9 @@ Section Carries. rewrite <- pull_app_if_sumbool. cbv beta delta [carry carry_and_reduce carry_simple add_to_nth log_cap - pow2_mod Z.ones Z.pred base + pow2_mod Z.ones Z.pred PseudoMersenneBaseParams.limb_widths]. + change @base with @base_opt. change @nth_default with @nth_default_opt in *. change @set_nth with @set_nth_opt in *. lazymatch goal with @@ -129,7 +138,6 @@ Section Carries. change @set_nth with @set_nth_opt. change @map with @map_opt. rewrite <- @beq_nat_eq_nat_dec. - change base_from_limb_widths with base_from_limb_widths_opt. reflexivity. Defined. @@ -179,7 +187,7 @@ Section Carries. change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf). change Z.shiftl with Z_shiftl_opt. change (-1) with (Z_opp_opt 1). - change Z.add with Z_add_opt at 8 12 20 24. + change Z.add with Z_add_opt at 5 9 17 21. reflexivity. Defined. @@ -191,6 +199,39 @@ Section Carries. @carry_opt_cps T i f b = f (carry i b) := proj2_sig (carry_opt_cps_sig i f b). + Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits) + (f : digits -> T) + : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }. + Proof. + eexists. + cbv [carry_sequence]. + transitivity (fold_right carry_opt_cps f (List.rev is) us). + Focus 2. + { + assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { + subst. intros. rewrite <- in_rev in *. auto. } + remember (rev is) as ris eqn:Heq. + rewrite <- (rev_involutive is), <- Heq. + clear H Heq is. + rewrite fold_left_rev_right. + revert us; induction ris; [ reflexivity | ]; intros. + { simpl. + rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. + rewrite carry_opt_cps_correct; [reflexivity|]. + apply Hr; left; reflexivity. + } } + Unfocus. + reflexivity. + Defined. + + Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) := + Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in + proj1_sig (carry_sequence_opt_cps2_sig is us f). + + Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T) + : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us) + := proj2_sig (carry_sequence_opt_cps2_sig is us f). + Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits) : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. Proof. @@ -198,7 +239,7 @@ Section Carries. cbv [carry_sequence]. transitivity (fold_right carry_opt_cps id (List.rev is) us). Focus 2. - { + { assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { subst. intros. rewrite <- in_rev in *. auto. } remember (rev is) as ris eqn:Heq. @@ -226,14 +267,55 @@ Section Carries. Lemma carry_sequence_opt_cps_rep : forall (is : list nat) (us : list Z) (x : F modulus), (forall i : nat, In i is -> i < length base)%nat -> - length us = length base -> rep us x -> rep (carry_sequence_opt_cps is us) x. Proof. intros. rewrite carry_sequence_opt_cps_correct by assumption. - apply carry_sequence_rep; assumption. + apply carry_sequence_rep; eauto using rep_length. Qed. + Lemma full_carry_chain_bounds : forall i, In i full_carry_chain -> (i < length base)%nat. + Proof. + unfold full_carry_chain; rewrite <-base_length; intros. + apply make_chain_lt; auto. + Qed. + + Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. + Proof. + eexists. + cbv [carry_full]. + change @full_carry_chain with full_carry_chain_opt. + rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds). + reflexivity. + Defined. + + Definition carry_full_opt (us : digits) : digits + := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us). + + Definition carry_full_opt_correct us : carry_full_opt us = carry_full us := + proj2_sig (carry_full_opt_sig us). + + Definition carry_full_opt_cps_sig + {T} + (f : digits -> T) + (us : digits) + : { d : T | d = f (carry_full us) }. + Proof. + eexists. + rewrite <- carry_full_opt_correct. + cbv beta iota delta [carry_full_opt]. + rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds. + rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds. + reflexivity. + Defined. + + Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us). + + Definition carry_full_opt_cps_correct {T} us (f : digits -> T) : + carry_full_opt_cps f us = f (carry_full us) := + proj2_sig (carry_full_opt_cps_sig f us). + End Carries. Section Addition. @@ -416,12 +498,11 @@ Section Multiplication. eexists. cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce]. rewrite <- mul'_opt_correct. - cbv [base PseudoMersenneBaseParams.limb_widths]. + change @base with base_opt. rewrite map_shiftl by apply k_nonneg. rewrite c_subst. rewrite k_subst. change @map with @map_opt. - change base_from_limb_widths with base_from_limb_widths_opt. change @Z_shiftl_by with @Z_shiftl_by_opt. reflexivity. Defined. @@ -433,31 +514,158 @@ Section Multiplication. : mul_opt us vs = mul us vs := proj2_sig (mul_opt_sig us vs). - Lemma mul_opt_rep: + Definition carry_mul_opt_sig (us vs : T) : { b : digits | b = carry_mul us vs }. + Proof. + eexists. + cbv [carry_mul]. + erewrite <-carry_full_opt_correct by eauto. + erewrite <-mul_opt_correct. + reflexivity. + Defined. + + Definition carry_mul_opt (us vs : T) : digits + := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig us vs). + + Definition carry_mul_opt_correct us vs + : carry_mul_opt us vs = carry_mul us vs + := proj2_sig (carry_mul_opt_sig us vs). + + Lemma carry_mul_opt_rep: forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y -> - PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F. + PseudoMersenneBaseRep.rep (carry_mul_opt u v) (x * y)%F. Proof. intros. - rewrite mul_opt_correct. - change mul with PseudoMersenneBaseRep.mul. + rewrite carry_mul_opt_correct. + change carry_mul with PseudoMersenneBaseRep.mul. auto using PseudoMersenneBaseRep.mul_rep. Qed. - Definition carry_mul_opt - (is : list nat) - (us vs : list Z) - : list Z - := carry_sequence_opt_cps c_ is (mul_opt us vs). - - Lemma carry_mul_opt_correct - : forall (is : list nat) (us vs : list Z) (x y: F modulus), - PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y -> - (forall i : nat, In i is -> i < length base)%nat -> - length (mul_opt us vs) = length base -> - PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F. +End Multiplication. + +Record freezePreconditions {modulus} (prm : PseudoMersenneBaseParams modulus) int_width := +mkFreezePreconditions { + lt_1_length_base : (1 < length base)%nat; + int_width_pos : 0 < int_width; + int_width_compat : forall w, In w limb_widths -> w <= int_width; + c_pos : 0 < c; + c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < max_bound 0 + 1; + c_reduce2 : c <= max_bound 0 - c; + two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus +}. +Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos + c_reduce1 c_reduce2 two_pow_k_le_2modulus. + +Section Canonicalization. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm} + (* allows caller to precompute k and c *) + (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) + {int_width} (preconditions : freezePreconditions prm int_width). + + Definition modulus_digits_opt_sig : + { b : digits | b = modulus_digits }. + Proof. + eexists. + cbv beta iota delta [modulus_digits modulus_digits' app]. + change @max_bound with max_bound_opt. + rewrite c_subst. + change length with length_opt. + change minus with minus_opt. + change Z.add with Z_add_opt. + change Z.sub with Z_sub_opt. + change @base with base_opt. + reflexivity. + Defined. + + Definition modulus_digits_opt : digits + := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig). + + Definition modulus_digits_opt_correct + : modulus_digits_opt = modulus_digits + := proj2_sig (modulus_digits_opt_sig). + + + Definition carry_full_3_opt_cps_sig + {T} (f : digits -> T) + (us : digits) + : { d : T | d = f (carry_full (carry_full (carry_full us))) }. + Proof. + eexists. + transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us). + Focus 2. { + rewrite !carry_full_opt_cps_correct by assumption; reflexivity. + } + Unfocus. + reflexivity. + Defined. + + Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us). + + Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us : + carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) := + proj2_sig (carry_full_3_opt_cps_sig f us). + + Definition freeze_opt_sig (us : T) : + { b : digits | b = freeze us }. Proof. - intros is us vs x y; intros. - change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)). - apply carry_sequence_opt_cps_rep, mul_opt_rep; auto. + eexists. + cbv [freeze]. + cbv [and_term]. + let LHS := match goal with |- ?LHS = ?RHS => LHS end in + let RHS := match goal with |- ?LHS = ?RHS => RHS end in + let RHSf := match (eval pattern (isFull (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (isFull(carry_full (carry_full (carry_full us)))) RHSf). + let LHS := match goal with |- ?LHS = ?RHS => LHS end in + let RHS := match goal with |- ?LHS = ?RHS => RHS end in + let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in + rewrite <-carry_full_3_opt_cps_correct with (f := RHSf). + cbv beta iota delta [and_term isFull isFull']. + change length with length_opt. + change @max_bound with max_bound_opt. + rewrite c_subst. + change @max_ones with max_ones_opt. + change @base with base_opt. + change minus with minus_opt. + change @map with @map_opt. + change Z.sub with Z_sub_opt at 1. + rewrite <-modulus_digits_opt_correct. + reflexivity. + Defined. + + Definition freeze_opt (us : T) : digits + := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us). + + Definition freeze_opt_correct us + : freeze_opt us = freeze us + := proj2_sig (freeze_opt_sig us). + + Lemma freeze_opt_canonical: forall us vs x, + @pre_carry_bounds _ _ int_width us -> PseudoMersenneBaseRep.rep us x -> + @pre_carry_bounds _ _ int_width vs -> PseudoMersenneBaseRep.rep vs x -> + freeze_opt us = freeze_opt vs. + Proof. + intros. + rewrite !freeze_opt_correct. + change PseudoMersenneBaseRep.rep with rep in *. + eapply freeze_canonical with (B := int_width); eauto. Qed. -End Multiplication.
\ No newline at end of file + + Lemma freeze_opt_preserves_rep : forall us x, PseudoMersenneBaseRep.rep us x -> + PseudoMersenneBaseRep.rep (freeze_opt us) x. + Proof. + intros. + rewrite freeze_opt_correct. + change PseudoMersenneBaseRep.rep with rep in *. + eapply freeze_preserves_rep; eauto. + Qed. + + Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x -> + @pre_carry_bounds _ _ int_width us -> + @pre_carry_bounds _ _ int_width vs -> + (PseudoMersenneBaseRep.rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs). + Proof. + split; eauto using freeze_opt_canonical. + auto using freeze_opt_preserves_rep. + Qed. + +End Canonicalization.
\ No newline at end of file diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 274acff5a..0462b0f37 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -1,7 +1,7 @@ Require Import Zpower ZArith. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. Require Import VerdiTactics. Require Crypto.BaseSystem. Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. @@ -22,14 +22,21 @@ Section PseudoMersenneProofs. autounfold; intuition. Qed. + Lemma rep_length : forall us x, us ~= x -> length us = length base. + Proof. + autounfold; intuition. + Qed. + Lemma encode_rep : forall x : F modulus, encode x ~= x. Proof. intros. unfold encode, rep. split. { unfold encode; simpl. - apply base_length_nonzero. + rewrite length_zeros. + pose proof base_length_nonzero; omega. } { unfold decode. + rewrite decode_highzeros. rewrite encode_rep. apply ZToField_FieldToZ. apply bv. @@ -40,8 +47,7 @@ Section PseudoMersenneProofs. Proof. autounfold; intuition. { unfold add. - rewrite add_length_le_max. - case_max; try rewrite Max.max_r; omega. + auto using add_same_length. } unfold decode in *; unfold decode in *. rewrite add_rep. @@ -49,15 +55,14 @@ Section PseudoMersenneProofs. subst; auto. Qed. - Lemma sub_rep : forall c c_0modq, (length c <= length base)%nat -> - forall u v x y, u ~= x -> v ~= y -> + Lemma sub_rep : forall c c_0modq, (length c = length base)%nat -> + forall u v x y, u ~= x -> v ~= y -> ModularBaseSystem.sub c c_0modq u v ~= (x-y)%F. Proof. autounfold; unfold ModularBaseSystem.sub; intuition. { - rewrite sub_length_le_max. + rewrite sub_length. case_max; try rewrite Max.max_r; try omega. - rewrite add_length_le_max. - case_max; try rewrite Max.max_r; omega. + auto using add_same_length. } unfold decode in *; unfold BaseSystem.decode in *. rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep. @@ -66,7 +71,7 @@ Section PseudoMersenneProofs. subst; auto. Qed. - Lemma decode_short : forall (us : BaseSystem.digits), + Lemma decode_short : forall (us : BaseSystem.digits), (length us <= length base)%nat -> BaseSystem.decode base us = BaseSystem.decode ext_base us. Proof. @@ -80,11 +85,11 @@ Section PseudoMersenneProofs. Qed. Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), - (length us <= length base)%nat -> + (length us <= length base)%nat -> (length vs <= length base)%nat -> (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode ext_base (BaseSystem.mul ext_base us vs). Proof. - intros. + intros. rewrite mul_rep by (apply ExtBaseVector || unfold ext_base; simpl_list; omega). f_equal; rewrite decode_short; auto. Qed. @@ -93,7 +98,7 @@ Section PseudoMersenneProofs. pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. Qed. - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) + (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. Proof. intros. @@ -137,34 +142,16 @@ Section PseudoMersenneProofs. rewrite mul_each_rep; auto. Qed. - Lemma reduce_length : forall us, - (length us <= length ext_base)%nat -> - (length (reduce us) <= length base)%nat. + Lemma reduce_length : forall us, + (length base <= length us <= length ext_base)%nat -> + (length (reduce us) = length base)%nat. Proof. - intros. - unfold reduce. - remember (map (Z.mul c) (skipn (length base) us)) as high. - remember (firstn (length base) us) as low. - assert (length low >= length high)%nat. { - subst. rewrite firstn_length. - rewrite map_length. - rewrite skipn_length. - destruct (le_dec (length base) (length us)). { - rewrite Min.min_l by omega. - rewrite extended_base_length in H. omega. - } { - rewrite Min.min_r; omega. - } - } - assert ((length low <= length base)%nat) - by (rewrite Heqlow; rewrite firstn_length; apply Min.le_min_l). - assert (length high <= length base)%nat - by (rewrite Heqhigh; rewrite map_length; rewrite skipn_length; - rewrite extended_base_length in H; omega). - rewrite add_trailing_zeros; auto. - rewrite (add_same_length _ _ (length low)); auto. - rewrite app_length. - rewrite length_zeros; intuition. + rewrite extended_base_length. + unfold reduce; intros. + rewrite add_length_exact. + rewrite map_length, firstn_length, skipn_length. + rewrite Min.min_l by omega. + apply Max.max_l; omega. Qed. Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> u .* v ~= (x*y)%F. @@ -172,20 +159,22 @@ Section PseudoMersenneProofs. autounfold; unfold ModularBaseSystem.mul; intuition. { apply reduce_length. - rewrite mul_length, extended_base_length. - omega. + rewrite mul_length_exact, extended_base_length; try omega. + destruct u; try congruence. + rewrite @nil_length0 in *. + pose proof base_length_nonzero; omega. } { rewrite ZToField_mod, reduce_rep, <-ZToField_mod. rewrite mul_rep by (apply ExtBaseVector || rewrite extended_base_length; omega). subst. - do 2 rewrite decode_short by auto. + do 2 rewrite decode_short by omega. apply ZToField_mul. } Qed. Lemma set_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (set_nth n x us) = + BaseSystem.decode base (set_nth n x us) = (x - nth_default 0 us n) * nth_default 0 base n + BaseSystem.decode base us. Proof. intros. @@ -213,12 +202,27 @@ Section PseudoMersenneProofs. Qed. Lemma add_to_nth_sum : forall n x us, (n < length us)%nat -> - BaseSystem.decode base (add_to_nth n x us) = + BaseSystem.decode base (add_to_nth n x us) = x * nth_default 0 base n + BaseSystem.decode base us. Proof. unfold add_to_nth; intros; rewrite set_nth_sum; try ring_simplify; auto. Qed. + Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> + nth_default 0 (add_to_nth n x l) i = + if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Proof. + intros. + unfold add_to_nth. + rewrite set_nth_nth_default by assumption. + break_if; subst; reflexivity. + Qed. + + Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. + Proof. + unfold add_to_nth; intros; apply length_set_nth. + Qed. + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> nth_default 0 base i > 0. Proof. @@ -240,13 +244,21 @@ Section PseudoMersenneProofs. apply base_succ; auto. Qed. + Lemma Fdecode_decode_mod : forall us x, (length us = length base) -> + decode us = x -> BaseSystem.decode base us mod modulus = x. + Proof. + unfold decode; intros ? ? ? decode_us. + rewrite <-decode_us. + apply FieldToZ_ZToField. + Qed. + End PseudoMersenneProofs. Section CarryProofs. Context `{prm : PseudoMersenneBaseParams}. Local Notation "u '~=' x" := (rep u x) (at level 70). Hint Unfold log_cap. - + Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. pose proof base_length_nonzero; omega. @@ -260,7 +272,7 @@ Section CarryProofs. apply limb_widths_nonneg. eapply nth_error_value_In; eauto. Qed. - + Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. Proof. @@ -342,8 +354,8 @@ Section CarryProofs. Qed. Lemma carry_length : forall i us, - (length us <= length base)%nat -> - (length (carry i us) <= length base)%nat. + (length us = length base)%nat -> + (length (carry i us) = length base)%nat. Proof. unfold carry, carry_simple, carry_and_reduce, add_to_nth. intros; break_if; subst; repeat (rewrite length_set_nth); auto. @@ -356,36 +368,19 @@ Section CarryProofs. us ~= x -> carry i us ~= x. Proof. pose carry_length. pose carry_decode_eq_reduce. pose carry_simple_decode_eq. - unfold rep, decode, carry in *; intros. - intuition; break_if; subst; eauto; - apply F_eq; simpl; intuition. + intros; split; auto. + unfold rep, decode, carry in *. + intuition; break_if; subst; eauto; apply F_eq; simpl; intuition. Qed. Hint Resolve carry_rep. Lemma carry_sequence_length: forall is us, - (length us <= length base)%nat -> - (length (carry_sequence is us) <= length base)%nat. - Proof. - induction is; boring. - Qed. - Hint Resolve carry_sequence_length. - - Lemma carry_length_exact : forall i us, - (length us = length base)%nat -> - (length (carry i us) = length base)%nat. - Proof. - unfold carry, carry_simple, carry_and_reduce, add_to_nth. - intros; break_if; subst; repeat (rewrite length_set_nth); auto. - Qed. - - Lemma carry_sequence_length_exact: forall is us, (length us = length base)%nat -> (length (carry_sequence is us) = length base)%nat. Proof. induction is; boring. - apply carry_length_exact; auto. Qed. - Hint Resolve carry_sequence_length_exact. + Hint Resolve carry_sequence_length. Lemma carry_sequence_rep : forall is us x, (forall i, In i is -> (i < length base)%nat) -> @@ -395,46 +390,45 @@ Section CarryProofs. induction is; boring. Qed. -End CarryProofs. -Section CanonicalizationProofs. - Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) (c_pos : 0 < c) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B). - - (* TODO : move *) - Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> - nth_default d (set_nth n x l) i = - if (eq_nat_dec i n) then x else nth_default d l i. + (* TODO : move? *) + Lemma make_chain_lt : forall x i : nat, In i (make_chain x) -> (i < x)%nat. Proof. - induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; - destruct i; break_if; try omega; intros; try apply nth_default_cons; - rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. + induction x; simpl; intuition. Qed. - (* TODO : move *) - Lemma add_to_nth_nth_default : forall n x l i, (0 <= i < length l)%nat -> - nth_default 0 (add_to_nth n x l) i = - if (eq_nat_dec i n) then x + nth_default 0 l i else nth_default 0 l i. + Lemma carry_full_preserves_rep : forall us x, + rep us x -> rep (carry_full us) x. Proof. - intros. - unfold add_to_nth. - rewrite set_nth_nth_default by assumption. - break_if; subst; reflexivity. + unfold carry_full; intros. + apply carry_sequence_rep; auto. + unfold full_carry_chain; rewrite base_length; apply make_chain_lt. + eauto using rep_length. Qed. - (* TODO : move *) - Lemma length_add_to_nth : forall n x l, length (add_to_nth n x l) = length l. - Proof. - unfold add_to_nth; intros; apply length_set_nth. - Qed. + Opaque carry_full. - (* TODO : move *) - Lemma singleton_list : forall {A} (l : list A), length l = 1%nat -> exists x, l = x :: nil. + Lemma carry_mul_rep : forall us vs x y, rep us x -> rep vs y -> + rep (carry_mul us vs) (x * y)%F. Proof. - intros; destruct l; simpl in *; try congruence. - eexists; f_equal. - apply length0_nil; omega. + unfold carry_mul; intros; apply carry_full_preserves_rep. + auto using mul_rep. Qed. +End CarryProofs. + +Section CanonicalizationProofs. + Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) + {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) + (c_pos : 0 < c) + (* on the first reduce step, we add at most one bit of width to the first digit *) + (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) + (* on the second reduce step, we add at most one bit of width to the first digit, + and leave room to carry c one more time after the highest bit is carried *) + (c_reduce2 : c <= max_bound 0 - c) + (* this condition is probably implied by c_reduce2, but is more straighforward to compute than to prove *) + (two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus). + (* BEGIN groundwork proofs *) Lemma pow_2_log_cap_pos : forall i, 0 < 2 ^ log_cap i. @@ -451,7 +445,7 @@ Section CanonicalizationProofs. omega. Qed. - Hint Resolve log_cap_nonneg. + Local Hint Resolve log_cap_nonneg. Lemma pow2_mod_log_cap_range : forall a i, 0 <= pow2_mod a (log_cap i) <= max_bound i. Proof. intros. @@ -488,6 +482,16 @@ Section CanonicalizationProofs. omega. Qed. + Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i. + Proof. + unfold max_bound, log_cap; intros; apply Z_ones_pos_pos. + apply limb_widths_pos. + rewrite nth_default_eq. + apply nth_In. + rewrite <-base_length; assumption. + Qed. + Local Hint Resolve max_bound_pos. + Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. unfold max_bound; intros; auto using Z_ones_nonneg. @@ -501,15 +505,6 @@ Section CanonicalizationProofs. rewrite Z.land_ones; auto. Qed. - Lemma pow2_mod_upper_bound : forall a b, (0 <= a) -> (0 <= b) -> pow2_mod a b <= a. - Proof. - intros. - unfold pow2_mod. - rewrite Z.land_ones; auto. - apply Z.mod_le; auto. - apply Z.pow_pos_nonneg; omega. - Qed. - Lemma shiftr_eq_0_max_bound : forall i a, Z.shiftr a (log_cap i) = 0 -> a <= max_bound i. Proof. @@ -550,26 +545,26 @@ Section CanonicalizationProofs. omega. Qed. + Lemma log_cap_eq : forall i, log_cap i = nth_default 0 limb_widths i. + Proof. + reflexivity. + Qed. + (* END groundwork proofs *) Opaque pow2_mod log_cap max_bound. (* automation *) Ltac carry_length_conditions' := unfold carry_full, add_to_nth; - rewrite ?length_set_nth, ?carry_length_exact, ?carry_sequence_length_exact, ?carry_sequence_length_exact; + rewrite ?length_set_nth, ?carry_length, ?carry_sequence_length; try omega; try solve [pose proof base_length; pose proof base_length_nonzero; omega || auto ]. Ltac carry_length_conditions := try split; try omega; repeat carry_length_conditions'. - Ltac add_set_nth := rewrite ?add_to_nth_nth_default; try solve [carry_length_conditions]; - try break_if; try omega; rewrite ?set_nth_nth_default; try solve [carry_length_conditions]; - try break_if; try omega. + Ltac add_set_nth := + rewrite ?add_to_nth_nth_default by carry_length_conditions; break_if; try omega; + rewrite ?set_nth_nth_default by carry_length_conditions; break_if; try omega. (* BEGIN defs *) - Definition c_carry_constraint : Prop := - (c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) - /\ (max_bound 0 + c < 2 ^ (log_cap 0 + 1)) - /\ (c <= max_bound 0 - c). - Definition pre_carry_bounds us := forall i, 0 <= nth_default 0 us i < if (eq_nat_dec i 0) then 2 ^ B else 2 ^ B - 2 ^ (B - log_cap (pred i)). @@ -581,26 +576,10 @@ Section CanonicalizationProofs. specialize (PCB i). omega. Qed. - Hint Resolve pre_carry_bounds_nonzero. + Local Hint Resolve pre_carry_bounds_nonzero. - Definition carry_done us := forall i, (i < length base)%nat -> Z.shiftr (nth_default 0 us i) (log_cap i) = 0. - - Lemma carry_carry_done_done : forall i us, - (length us = length base)%nat -> - (i < length base)%nat -> - (forall i, 0 <= nth_default 0 us i) -> - carry_done us -> carry_done (carry i us). - Proof. - unfold carry_done; intros until 3. intros Hcarry_done ? ?. - unfold carry, carry_simple, carry_and_reduce; break_if; subst. - + rewrite Hcarry_done by omega. - rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). - destruct i0; add_set_nth; rewrite ?Z.mul_0_r, ?Z.add_0_l; auto. - match goal with H : S _ = pred (length base) |- _ => rewrite H; auto end. - + rewrite Hcarry_done by omega. - rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). - destruct i0; add_set_nth; subst; rewrite ?Z.add_0_l; auto. - Qed. + Definition carry_done us := forall i, (i < length base)%nat -> + 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. (* END defs *) @@ -620,6 +599,7 @@ Section CanonicalizationProofs. apply pow2_mod_log_cap_bounds_upper. - rewrite nth_default_out_of_bounds by carry_length_conditions; auto. Qed. + Local Hint Resolve nth_default_carry_bound_upper. Lemma nth_default_carry_bound_lower : forall i us, (length us = length base) -> 0 <= nth_default 0 (carry i us) i. @@ -635,6 +615,7 @@ Section CanonicalizationProofs. apply pow2_mod_log_cap_bounds_lower. - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. Qed. + Local Hint Resolve nth_default_carry_bound_lower. Lemma nth_default_carry_bound_succ_lower : forall i us, (forall i, 0 <= nth_default 0 us i) -> (length us = length base) -> @@ -645,18 +626,15 @@ Section CanonicalizationProofs. + subst. replace (S (pred (length base))) with (length base) by omega. rewrite nth_default_out_of_bounds; carry_length_conditions. unfold carry_and_reduce. - add_set_nth. + carry_length_conditions. + unfold carry_simple. destruct (lt_dec (S i) (length us)). - - add_set_nth. - apply Z.add_nonneg_nonneg; [ apply Z.shiftr_nonneg | ]; unfold pre_carry_bounds in PCB. - * specialize (PCB i). omega. - * specialize (PCB (S i)). omega. + - add_set_nth; zero_bounds. - rewrite nth_default_out_of_bounds by carry_length_conditions; omega. Qed. Lemma carry_unaffected_low : forall i j us, ((0 < i < j)%nat \/ (i = 0 /\ j <> 0 /\ j <> pred (length base))%nat)-> - (length us = length base) -> + (length us = length base) -> nth_default 0 (carry j us) i = nth_default 0 us i. Proof. intros. @@ -671,7 +649,7 @@ Section CanonicalizationProofs. (omega || rewrite length_add_to_nth; rewrite length_set_nth; pose proof base_length_nonzero; omega). reflexivity. Qed. - + Lemma carry_unaffected_high : forall i j us, (S j < i)%nat -> (length us = length base) -> nth_default 0 (carry j us) i = nth_default 0 us i. Proof. @@ -679,7 +657,7 @@ Section CanonicalizationProofs. destruct (lt_dec i (length us)); [ | rewrite !nth_default_out_of_bounds by carry_length_conditions; reflexivity]. unfold carry, carry_simple. - break_if; add_set_nth. + break_if; [omega | add_set_nth]. Qed. Lemma carry_nothing : forall i j us, (i < length base)%nat -> @@ -688,23 +666,65 @@ Section CanonicalizationProofs. nth_default 0 (carry j us) i = nth_default 0 us i. Proof. unfold carry, carry_simple, carry_and_reduce; intros. - break_if; (add_set_nth; + break_if; (add_set_nth; [ rewrite max_bound_shiftr_eq_0 by omega; ring | subst; apply pow2_mod_log_cap_small; assumption ]). Qed. + Lemma carry_done_bounds : forall us, (length us = length base) -> + (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). + Proof. + intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. + + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. + - specialize (Hcarry_done i i_lt). + split; [ intuition | ]. + rewrite <- max_bound_log_cap. + apply Z.lt_succ_r. + apply shiftr_eq_0_max_bound; intuition. + - rewrite nth_default_out_of_bounds; try split; try omega; auto. + + specialize (Hbounds i). + split; intuition. + apply max_bound_shiftr_eq_0; auto. + rewrite <-max_bound_log_cap in *; omega. + Qed. + + Lemma carry_carry_done_done : forall i us, + (length us = length base)%nat -> + (i < length base)%nat -> + carry_done us -> carry_done (carry i us). + Proof. + unfold carry_done; intros i ? ? i_bound Hcarry_done x x_bound. + destruct (Hcarry_done x x_bound) as [lower_bound_x shiftr_0_x]. + destruct (Hcarry_done i i_bound) as [lower_bound_i shiftr_0_i]. + split. + + rewrite carry_nothing; auto. + split; [ apply Hcarry_done; auto | ]. + apply shiftr_eq_0_max_bound. + apply Hcarry_done; auto. + + unfold carry, carry_simple, carry_and_reduce; break_if; subst. + - add_set_nth; subst. + * rewrite shiftr_0_i, Z.mul_0_r, Z.add_0_l. + assumption. + * rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + assumption. + - rewrite shiftr_0_i by omega. + rewrite pow2_mod_log_cap_small by (intuition; auto using shiftr_eq_0_max_bound). + add_set_nth; subst; rewrite ?Z.add_0_l; auto. + Qed. + + Lemma carry_sequence_chain_step : forall i us, + carry_sequence (make_chain (S i)) us = carry i (carry_sequence (make_chain i) us). + Proof. + reflexivity. + Qed. + Lemma carry_bounds_0_upper : forall us j, (length us = length base) -> (0 < j < length base)%nat -> nth_default 0 (carry_sequence (make_chain j) us) 0 <= max_bound 0. Proof. - unfold carry_sequence; induction j; [simpl; intros; omega | ]. - intros. - simpl in *. - destruct (eq_nat_dec 0 j). - + subst. - apply nth_default_carry_bound_upper; fold (carry_sequence (make_chain 0) us); carry_length_conditions. - + rewrite carry_unaffected_low; try omega. - fold (carry_sequence (make_chain j) us); carry_length_conditions. + induction j as [ | [ | j ] IHj ]; [simpl; intros; omega | | ]; intros. + + subst; simpl; auto. + + rewrite carry_sequence_chain_step, carry_unaffected_low; carry_length_conditions. Qed. Lemma carry_bounds_upper : forall i us j, (0 < i < j)%nat -> (length us = length base) -> @@ -721,7 +741,7 @@ Section CanonicalizationProofs. fold (carry_sequence (make_chain j) us); carry_length_conditions. Qed. - Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat -> + Lemma carry_sequence_unaffected : forall i us j, (j < i)%nat -> (length us = length base)%nat -> nth_default 0 (carry_sequence (make_chain j) us) i = nth_default 0 us i. Proof. induction j; [simpl; intros; omega | ]. @@ -731,33 +751,41 @@ Section CanonicalizationProofs. apply IHj; omega. Qed. + (* makes omega run faster *) + Ltac clear_obvious := + match goal with + | [H : ?a <= ?a |- _] => clear H + | [H : ?a <= S ?a |- _] => clear H + | [H : ?a < S ?a |- _] => clear H + | [H : ?a = ?a |- _] => clear H + end. + Lemma carry_sequence_bounds_lower : forall j i us, (length us = length base) -> (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. Proof. - induction j; intros. - + simpl. auto. - + simpl. - destruct (lt_dec (S j) i). - - rewrite carry_unaffected_high by carry_length_conditions. - apply IHj; auto; omega. - - assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega. - destruct cases as [? | [? | ?]]. - * subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions. - intros. - eapply IHj; auto; omega. - * subst. apply nth_default_carry_bound_lower; carry_length_conditions. - * destruct (eq_nat_dec j (pred (length base))); - [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ]. - subst. - unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth; [ | apply IHj; auto; omega ]. - apply Z.add_nonneg_nonneg; [ | apply IHj; auto; omega ]. - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. - apply IHj; auto; omega. + induction j; intros; simpl; auto. + destruct (lt_dec (S j) i). + + rewrite carry_unaffected_high by carry_length_conditions. + apply IHj; auto; omega. + + assert ((i = S j) \/ (i = j) \/ (i < j))%nat as cases by omega. + destruct cases as [? | [? | ?]]. + - subst. apply nth_default_carry_bound_succ_lower; carry_length_conditions. + intros; eapply IHj; auto; omega. + - subst. apply nth_default_carry_bound_lower; carry_length_conditions. + - destruct (eq_nat_dec j (pred (length base))); + [ | rewrite carry_unaffected_low by carry_length_conditions; apply IHj; auto; omega ]. + subst. + do 2 match goal with H : appcontext[S (pred (length base))] |- _ => + erewrite <-(S_pred (length base)) in H by eauto end. + unfold carry; break_if; [ unfold carry_and_reduce | omega ]. + clear_obvious. + add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega. Qed. + Ltac carry_seq_lower_bound := + repeat (intros; eapply carry_sequence_bounds_lower; eauto; carry_length_conditions). + Lemma carry_bounds_lower : forall i us j, (0 < i <= j)%nat -> (length us = length base) -> (forall i, 0 <= nth_default 0 us i) -> (j <= length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) us) i. @@ -769,13 +797,12 @@ Section CanonicalizationProofs. destruct (eq_nat_dec i (S j)). + subst. apply nth_default_carry_bound_succ_lower; auto; fold (carry_sequence (make_chain j) us); carry_length_conditions. - intros. - apply carry_sequence_bounds_lower; auto; omega. + carry_seq_lower_bound. + assert (i = j \/ i < j)%nat as cases by omega. destruct cases as [eq_j_i | lt_i_j]; subst; [apply nth_default_carry_bound_lower| rewrite carry_unaffected_low]; try omega; fold (carry_sequence (make_chain j) us); carry_length_conditions. - apply carry_sequence_bounds_lower; auto; omega. + carry_seq_lower_bound. Qed. Lemma carry_full_bounds : forall us i, (i <> 0)%nat -> (forall i, 0 <= nth_default 0 us i) -> @@ -799,18 +826,15 @@ Section CanonicalizationProofs. unfold carry, carry_simple; break_if; try omega. add_set_nth. replace (2 ^ B) with (2 ^ (B - log_cap i) + (2 ^ B - 2 ^ (B - log_cap i))) by omega. - split. - + apply Z.add_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg; try omega. - + apply Z.add_lt_mono; try omega. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. - apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. - rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). - replace (log_cap i + (B - log_cap i)) with B by ring. - omega. + split; [ zero_bounds | ]. + apply Z.add_lt_mono; try omega. + rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. + apply Z.div_lt_upper_bound; try apply pow_2_log_cap_pos. + rewrite <-Z.pow_add_r by (apply log_cap_nonneg || apply B_compat_log_cap). + replace (log_cap i + (B - log_cap i)) with B by ring. + omega. Qed. - Lemma carry_sequence_no_overflow : forall i us, pre_carry_bounds us -> (length us = length base) -> nth_default 0 (carry_sequence (make_chain i) us) i < 2 ^ B. @@ -822,16 +846,13 @@ Section CanonicalizationProofs. intuition. + simpl. destruct (lt_eq_lt_dec i (pred (length base))) as [[? | ? ] | ? ]. - - apply carry_simple_no_overflow; carry_length_conditions. - apply carry_sequence_bounds_lower; carry_length_conditions. - apply carry_sequence_bounds_lower; carry_length_conditions. - rewrite carry_sequence_unaffected; try omega. + - apply carry_simple_no_overflow; carry_length_conditions; carry_seq_lower_bound. + rewrite carry_sequence_unaffected; try omega. specialize (PCB (S i)); rewrite Nat.pred_succ in PCB. break_if; intuition. - unfold carry; break_if; try omega. rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. - subst. - unfold carry_and_reduce. + subst; unfold carry_and_reduce. carry_length_conditions. - rewrite nth_default_out_of_bounds; [ apply Z.pow_pos_nonneg; omega | ]. carry_length_conditions. @@ -843,25 +864,20 @@ Section CanonicalizationProofs. Proof. unfold carry_full, full_carry_chain; intros. rewrite <- base_length. - replace (length base) with (S (pred (length base))) at 1 2 by omega. + replace (length base) with (S (pred (length base))) by omega. simpl. unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth. - split. - + apply Z.add_nonneg_nonneg. - - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. - apply carry_sequence_bounds_lower; auto; omega. - - apply carry_sequence_bounds_lower; auto; omega. - + rewrite Z.add_comm. - apply Z.add_le_mono. - - apply carry_bounds_0_upper; auto; omega. - - apply Z.mul_le_mono_pos_l; auto. - apply Z_shiftr_ones; auto; - [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. - split. - * apply carry_bounds_lower; auto; try omega. - * apply carry_sequence_no_overflow; auto. + clear_obvious; add_set_nth. + split; [zero_bounds; carry_seq_lower_bound | ]. + rewrite Z.add_comm. + apply Z.add_le_mono. + + apply carry_bounds_0_upper; auto; omega. + + apply Z.mul_le_mono_pos_l; auto. + apply Z_shiftr_ones; auto; + [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. + split. + - apply carry_bounds_lower; auto; omega. + - apply carry_sequence_no_overflow; auto. Qed. Lemma carry_full_bounds_lower : forall i us, pre_carry_bounds us -> @@ -874,12 +890,12 @@ Section CanonicalizationProofs. - apply carry_bounds_lower; carry_length_conditions. - rewrite nth_default_out_of_bounds; carry_length_conditions. Qed. - + (* END proofs about first carry loop *) - + (* BEGIN proofs about second carry loop *) - Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_sequence_carry_full_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full us)) i <= 2 ^ log_cap i. Proof. @@ -888,12 +904,9 @@ Section CanonicalizationProofs. unfold carry, carry_simple; break_if; try omega. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - - apply Z.shiftr_nonneg. - destruct (eq_nat_dec i 0); subst. - * simpl. - apply carry_full_bounds_0; auto. - * apply IHi; auto; omega. + + zero_bounds; [destruct (eq_nat_dec i 0); subst | ]. + - simpl; apply carry_full_bounds_0; auto. + - apply IHi; auto; omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; auto; omega. + rewrite <-max_bound_log_cap, <-Z.add_1_l. @@ -905,16 +918,14 @@ Section CanonicalizationProofs. eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. replace (2 ^ log_cap 0 * 2) with (2 ^ log_cap 0 + 2 ^ log_cap 0) by ring. rewrite <-max_bound_log_cap, <-Z.add_1_l. - apply Z.add_lt_le_mono; try omega. - unfold c_carry_constraint in *. - intuition. + apply Z.add_lt_le_mono; omega. * eapply Z.le_lt_trans; [ apply IHi; auto; omega | ]. - apply Z.lt_mul_diag_r; auto; omega. + apply Z.lt_mul_diag_r; auto; omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; auto; omega. Qed. - Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_0 : forall us, pre_carry_bounds us -> (length us = length base)%nat -> (1 < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) 0 <= max_bound 0 + c. Proof. @@ -924,19 +935,14 @@ Section CanonicalizationProofs. replace (length base) with (S (pred (length base))) by (pose proof base_length_nonzero; omega). simpl. unfold carry, carry_and_reduce; break_if; try omega. - add_set_nth. + clear_obvious; add_set_nth. split. - + apply Z.add_nonneg_nonneg. - apply Z.mul_nonneg_nonneg; try omega. - apply Z.shiftr_nonneg. + + zero_bounds; [ | carry_seq_lower_bound]. apply carry_sequence_carry_full_bounds_same; auto; omega. - eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. - intros. - eapply carry_sequence_bounds_lower; eauto; carry_length_conditions. + rewrite Z.add_comm. apply Z.add_le_mono. - apply carry_bounds_0_upper; carry_length_conditions. - - replace c with (c * 1) at 2 by ring. + - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try omega. rewrite Z.shiftr_div_pow2 by auto. apply Z.div_le_upper_bound; auto. @@ -945,7 +951,7 @@ Section CanonicalizationProofs. omega. Qed. - Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_succ : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < pred (length base))%nat -> ((0 < i < length base)%nat -> 0 <= nth_default 0 @@ -954,20 +960,14 @@ Section CanonicalizationProofs. 0 <= nth_default 0 (carry_simple i (carry_sequence (make_chain i) (carry_full (carry_full us)))) (S i) <= 2 ^ log_cap (S i). Proof. - unfold carry_simple; intros ? ? PCB CCC length_eq ? IH. + unfold carry_simple; intros ? ? PCB length_eq ? IH. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - apply Z.shiftr_nonneg. - destruct i; - [ simpl; pose proof (carry_full_2_bounds_0 us PCB CCC length_eq); omega | ]. - - assert (0 < S i < length base)%nat as IHpre by omega. - specialize (IH IHpre). - omega. - - rewrite carry_sequence_unaffected by carry_length_conditions. - apply carry_full_bounds; carry_length_conditions. - intros. - apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + zero_bounds. destruct i; + [ simpl; pose proof (carry_full_2_bounds_0 us PCB length_eq); omega | ]. + rewrite carry_sequence_unaffected by carry_length_conditions. + apply carry_full_bounds; carry_length_conditions. + carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. @@ -975,10 +975,10 @@ Section CanonicalizationProofs. ring_simplify. apply IH. omega. - rewrite carry_sequence_unaffected by carry_length_conditions. apply carry_full_bounds; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + carry_seq_lower_bound. Qed. - Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_same : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) i <= 2 ^ log_cap i. Proof. @@ -988,36 +988,33 @@ Section CanonicalizationProofs. split; (destruct (eq_nat_dec i 0); subst; [ cbv [make_chain carry_sequence fold_right carry_simple]; add_set_nth | eapply carry_full_2_bounds_succ; eauto; omega]). - + apply Z.add_nonneg_nonneg. - apply Z.shiftr_nonneg. - eapply carry_full_2_bounds_0; eauto. - eapply carry_full_bounds; eauto; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + + zero_bounds. + - eapply carry_full_2_bounds_0; eauto. + - eapply carry_full_bounds; eauto; carry_length_conditions. + carry_seq_lower_bound. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. - apply Z_div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. replace (Z.succ 1) with (2 ^ 1) by ring. - rewrite <-Z.pow_add_r by (omega || auto). - unfold c_carry_constraint in *. - intuition. - - apply carry_full_bounds; carry_length_conditions. - intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions. + rewrite <-max_bound_log_cap. + ring_simplify. omega. + - apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound. Qed. - Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds' : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i + j < length base)%nat -> (j <> 0)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain (i + j)) (carry_full (carry_full us))) i <= max_bound i. Proof. induction j; intros; try omega. - split; (destruct j; [ rewrite Nat.add_1_r; simpl + split; (destruct j; [ rewrite Nat.add_1_r; simpl | rewrite <-plus_n_Sm; simpl; rewrite carry_unaffected_low by carry_length_conditions; eapply IHj; eauto; omega ]). + apply nth_default_carry_bound_lower; carry_length_conditions. + apply nth_default_carry_bound_upper; carry_length_conditions. Qed. - Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds : forall us i j, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (i < j < length base)%nat -> 0 <= nth_default 0 (carry_sequence (make_chain j) (carry_full (carry_full us))) i <= max_bound i. Proof. @@ -1026,12 +1023,12 @@ Section CanonicalizationProofs. eapply carry_full_2_bounds'; eauto; omega. Qed. - Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_carry_full_2_bounds_0_lower : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (0 <= nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + unfold carry. break_if; @@ -1041,91 +1038,82 @@ Section CanonicalizationProofs. add_set_nth. apply pow2_mod_log_cap_bounds_lower. + rewrite carry_unaffected_low by carry_length_conditions. - assert (0 < S i < length base)%nat by omega. + assert (0 < S i < length base)%nat by omega. intuition. Qed. - Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_2_bounds_lower :forall us i, pre_carry_bounds us -> (length us = length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full us)) i. Proof. - intros. - destruct i. + intros; destruct i. + apply carry_full_2_bounds_0; auto. + apply carry_full_bounds; try solve [carry_length_conditions]. - intro j. - destruct j. + intro j; destruct j. - apply carry_full_bounds_0; auto. - apply carry_full_bounds; carry_length_conditions. Qed. - Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_length : forall us, (length us = length base)%nat -> + length (carry_full us) = length us. + Proof. + intros; carry_length_conditions. + Qed. + Local Hint Resolve carry_full_length. + + Lemma carry_carry_full_2_bounds_0_upper : forall us i, pre_carry_bounds us -> (length us = length base)%nat -> (0 < i < length base)%nat -> (nth_default 0 (carry_sequence (make_chain i) (carry_full (carry_full us))) 0 <= max_bound 0 - c) \/ carry_done (carry_sequence (make_chain i) (carry_full (carry_full us))). Proof. induction i; try omega. - intros ? ? length_eq ?; simpl. + intros ? length_eq ?; simpl. destruct i. + destruct (Z_le_dec (nth_default 0 (carry_full (carry_full us)) 0) (max_bound 0)). - right. - unfold carry_done. + apply carry_carry_done_done; try solve [carry_length_conditions]. + apply carry_done_bounds; try solve [carry_length_conditions]. intros. - apply max_bound_shiftr_eq_0; simpl; rewrite carry_nothing; try solve [carry_length_conditions]. - * apply carry_full_2_bounds_lower; auto. - * split; try apply carry_full_2_bounds_lower; auto. - * destruct i; auto. - apply carry_full_bounds; try solve [carry_length_conditions]. - auto using carry_full_bounds_lower. - * split; auto. - apply carry_full_2_bounds_lower; auto. - - unfold carry. + simpl. + split; [ auto using carry_full_2_bounds_lower | ]. + * destruct i; rewrite <-max_bound_log_cap, Z.lt_succ_r; auto. + apply carry_full_bounds; auto using carry_full_bounds_lower. + rewrite carry_full_length; auto. + - left; unfold carry, carry_simple. break_if; [ pose proof base_length_nonzero; replace (length base) with 1%nat in *; omega | ]. - simpl. - unfold carry_simple. - add_set_nth. left. + add_set_nth. simpl. remember ((nth_default 0 (carry_full (carry_full us)) 0)) as x. - apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)). - * replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. - rewrite pow2_mod_spec by auto. - rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). - rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small. - apply Z.sub_le_mono_r. - subst; apply carry_full_2_bounds_0; auto. - split; try omega. - pose proof carry_full_2_bounds_0. - apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); - [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; - ring_simplify; unfold c_carry_constraint in *; omega | ]. - ring_simplify; unfold c_carry_constraint in *; omega. - * ring_simplify; unfold c_carry_constraint in *; omega. + apply Z.le_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); try omega. + replace x with ((x - 2 ^ log_cap 0) + (1 * 2 ^ log_cap 0)) by ring. + rewrite pow2_mod_spec by auto. + cbv [make_chain carry_sequence fold_right]. + rewrite Z.mod_add by (pose proof (pow_2_log_cap_pos 0); omega). + rewrite <-max_bound_log_cap, <-Z.add_1_l, Z.mod_small; + [ apply Z.sub_le_mono_r; subst; apply carry_full_2_bounds_0; auto | ]. + split; try omega. + pose proof carry_full_2_bounds_0. + apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); + [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; + ring_simplify | ]; omega. + rewrite carry_unaffected_low by carry_length_conditions. - assert (0 < S i < length base)%nat by omega. - intuition. - right. + assert (0 < S i < length base)%nat by omega. + intuition; right. apply carry_carry_done_done; try solve [carry_length_conditions]. - intro j. - destruct j. - - apply carry_carry_full_2_bounds_0_lower; auto. - - destruct (lt_eq_lt_dec j i) as [[? | ?] | ?]. - * apply carry_full_2_bounds; auto; omega. - * subst. apply carry_full_2_bounds_same; auto; omega. - * rewrite carry_sequence_unaffected; try solve [carry_length_conditions]. - apply carry_full_2_bounds_lower; auto; omega. - Qed. - + assumption. + Qed. + (* END proofs about second carry loop *) - + (* BEGIN proofs about third carry loop *) - Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> c_carry_constraint -> + Lemma carry_full_3_bounds : forall us i, pre_carry_bounds us -> (length us = length base)%nat ->(i < length base)%nat -> 0 <= nth_default 0 (carry_full (carry_full (carry_full us))) i <= max_bound i. Proof. intros. destruct i; [ | apply carry_full_bounds; carry_length_conditions; - do 2 (intros; apply carry_sequence_bounds_lower; eauto; carry_length_conditions) ]. + carry_seq_lower_bound ]. unfold carry_full at 1 4, full_carry_chain. case_eq limb_widths; [intros; pose proof limb_widths_nonnil; congruence | ]. simpl. @@ -1135,45 +1123,967 @@ Section CanonicalizationProofs. unfold carry, carry_and_reduce; break_if; try omega; intros. add_set_nth. split. - + apply Z.add_nonneg_nonneg. - - apply Z.mul_nonneg_nonneg; auto; try omega. - apply Z.shiftr_nonneg. - eapply carry_full_2_bounds_same; eauto; omega. + + zero_bounds. + - eapply carry_full_2_bounds_same; eauto; omega. - eapply carry_carry_full_2_bounds_0_lower; eauto; omega. + pose proof (carry_carry_full_2_bounds_0_upper us (pred (length base))). assert (0 < pred (length base) < length base)%nat by omega. intuition. - replace (max_bound 0) with (c + (max_bound 0 - c)) by ring. apply Z.add_le_mono; try assumption. - replace c with (c * 1) at 2 by ring. + etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. apply Z.mul_le_mono_pos_l; try omega. rewrite Z.shiftr_div_pow2 by auto. apply Z.div_le_upper_bound; auto. ring_simplify. apply carry_full_2_bounds_same; auto. - - match goal with H : carry_done _ |- _ => unfold carry_done in H; rewrite H by omega end. + - match goal with H0 : (pred (length base) < length base)%nat, + H : carry_done _ |- _ => + destruct (H (pred (length base)) H0) as [Hcd1 Hcd2]; rewrite Hcd2 by omega end. ring_simplify. - apply shiftr_eq_0_max_bound; auto; omega. + apply shiftr_eq_0_max_bound; auto. + assert (0 < length base)%nat as zero_lt_length by omega. + match goal with H : carry_done _ |- _ => + destruct (H 0%nat zero_lt_length) end. + assumption. Qed. - Lemma nth_error_combine : forall {A B} i (x : A) (x' : B) l l', nth_error l i = Some x -> - nth_error l' i = Some x' -> nth_error (combine l l') i = Some (x, x'). - Admitted. - - Lemma nth_error_range : forall {A} i (l : list A), (i < length l)%nat -> - nth_error (range (length l)) i = Some i. - Admitted. + Lemma carry_full_3_done : forall us, pre_carry_bounds us -> + (length us = length base)%nat -> + carry_done (carry_full (carry_full (carry_full us))). + Proof. + intros. + apply carry_done_bounds; [ carry_length_conditions | intros ]. + destruct (lt_dec i (length base)). + + rewrite <-max_bound_log_cap, Z.lt_succ_r. + auto using carry_full_3_bounds. + + rewrite nth_default_out_of_bounds; carry_length_conditions. + Qed. (* END proofs about third carry loop *) - Opaque carry_full. - Lemma freeze_in_bounds : forall us i, (us <> nil)%nat -> - 0 <= nth_default 0 (freeze us) i < 2 ^ log_cap i. + Lemma isFull'_false : forall us n, isFull' us false n = false. Proof. - Admitted. + unfold isFull'; induction n; intros; rewrite Bool.andb_false_r; auto. + Qed. - Lemma freeze_canonical : forall us vs x, rep us x -> rep vs x -> + Lemma isFull'_last : forall us b j, (j <> 0)%nat -> isFull' us b j = true -> + max_bound j = nth_default 0 us j. + Proof. + induction j; simpl; intros; try omega. + match goal with + | [H : isFull' _ ((?comp ?a ?b) && _) _ = true |- _ ] => + case_eq (comp a b); rewrite ?Z.eqb_eq; intro comp_eq; try assumption; + rewrite comp_eq, Bool.andb_false_l, isFull'_false in H; congruence + end. + Qed. + + Lemma isFull'_lower_bound_0 : forall j us b, isFull' us b j = true -> + max_bound 0 - c < nth_default 0 us 0. + Proof. + induction j; intros. + + match goal with H : isFull' _ _ 0 = _ |- _ => cbv [isFull'] in H; + apply Bool.andb_true_iff in H; destruct H end. + apply Z.ltb_lt; assumption. + + eauto. + Qed. + + Lemma isFull'_true_full : forall us i j b, (i <> 0)%nat -> (i <= j)%nat -> isFull' us b j = true -> + max_bound i = nth_default 0 us i. + Proof. + induction j; intros; try omega. + assert (i = S j \/ i <= j)%nat as cases by omega. + destruct cases. + + subst. eapply isFull'_last; eauto. + + eapply IHj; eauto. + Qed. + + Lemma max_ones_nonneg : 0 <= max_ones. + Proof. + unfold max_ones. + apply Z_ones_nonneg. + pose proof limb_widths_nonneg. + induction limb_widths. + cbv; congruence. + simpl. + apply Z.max_le_iff. + right. + apply IHl; auto using in_cons. + Qed. + + Lemma land_max_ones_noop : forall x i, 0 <= x < 2 ^ log_cap i -> Z.land max_ones x = x. + Proof. + unfold max_ones. + intros ? ? x_range. + rewrite Z.land_comm. + rewrite Z.land_ones by apply Z_le_fold_right_max_initial. + apply Z.mod_small. + split; try omega. + eapply Z.lt_le_trans; try eapply x_range. + apply Z.pow_le_mono_r; try omega. + rewrite log_cap_eq. + destruct (lt_dec i (length limb_widths)). + + apply Z_le_fold_right_max. + - apply limb_widths_nonneg. + - rewrite nth_default_eq. + auto using nth_In. + + rewrite nth_default_out_of_bounds by omega. + apply Z_le_fold_right_max_initial. + Qed. + + Lemma full_isFull'_true : forall j us, (length us = length base) -> + ( max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)) -> + isFull' us true j = true. + Proof. + induction j; intros. + + cbv [isFull']; apply Bool.andb_true_iff. + rewrite Z.ltb_lt; intuition. + + intuition. + simpl. + match goal with H : forall j, _ -> ?b j = ?a j |- appcontext[?a ?i =? ?b ?i] => + replace (a i =? b i) with true by (symmetry; apply Z.eqb_eq; symmetry; apply H; omega) end. + apply IHj; auto; intuition. + Qed. + + Lemma isFull'_true_iff : forall j us, (length us = length base) -> (isFull' us true j = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= j)%nat -> nth_default 0 us i = max_bound i)). + Proof. + intros; split; intros; auto using full_isFull'_true. + split; eauto using isFull'_lower_bound_0. + intros. + symmetry; eapply isFull'_true_full; [ omega | | eauto]. + omega. + Qed. + + Lemma isFull'_true_step : forall us j, isFull' us true (S j) = true -> + isFull' us true j = true. + Proof. + simpl; intros ? ? succ_true. + destruct (max_bound (S j) =? nth_default 0 us (S j)); auto. + rewrite isFull'_false in succ_true. + congruence. + Qed. + + Opaque isFull' max_ones. + + Lemma carry_full_3_length : forall us, (length us = length base) -> + length (carry_full (carry_full (carry_full us))) = length us. + Proof. + intros. + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); auto. + Qed. + Local Hint Resolve carry_full_3_length. + + Lemma nth_default_map2 : forall {A B C} (f : A -> B -> C) ls1 ls2 i d d1 d2, + nth_default d (map2 f ls1 ls2) i = + if lt_dec i (min (length ls1) (length ls2)) + then f (nth_default d1 ls1 i) (nth_default d2 ls2 i) + else d. + Proof. + induction ls1, ls2. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + simpl. + destruct i. + - intros. rewrite !nth_default_cons. + break_if; auto; omega. + - intros. rewrite !nth_default_cons_S. + rewrite IHls1 with (d1 := d1) (d2 := d2). + repeat break_if; auto; omega. + Qed. + + Lemma map2_cons : forall A B C (f : A -> B -> C) ls1 ls2 a b, + map2 f (a :: ls1) (b :: ls2) = f a b :: map2 f ls1 ls2. + Proof. + reflexivity. + Qed. + + Lemma map2_nil_l : forall A B C (f : A -> B -> C) ls2, + map2 f nil ls2 = nil. + Proof. + reflexivity. + Qed. + + Lemma map2_nil_r : forall A B C (f : A -> B -> C) ls1, + map2 f ls1 nil = nil. + Proof. + destruct ls1; reflexivity. + Qed. + Local Hint Resolve map2_nil_r map2_nil_l. + + Opaque map2. + + Lemma map2_length : forall A B C (f : A -> B -> C) ls1 ls2, + length (map2 f ls1 ls2) = min (length ls1) (length ls2). + Proof. + induction ls1, ls2; intros; try solve [cbv; auto]. + rewrite map2_cons, !length_cons, IHls1. + auto. + Qed. + + Lemma modulus_digits'_length : forall i, length (modulus_digits' i) = S i. + Proof. + induction i; intros; [ cbv; congruence | ]. + unfold modulus_digits'; fold modulus_digits'. + rewrite app_length, IHi. + cbv [length]; omega. + Qed. + + Lemma modulus_digits_length : length modulus_digits = length base. + Proof. + unfold modulus_digits. + rewrite modulus_digits'_length; omega. + Qed. + + (* Helps with solving goals of the form [x = y -> min x y = x] or [x = y -> min x y = y] *) + Local Hint Resolve Nat.eq_le_incl eq_le_incl_rev. + + Hint Rewrite app_length cons_length map2_length modulus_digits_length length_zeros + map_length combine_length firstn_length map_app : lengths. + Ltac simpl_lengths := autorewrite with lengths; + repeat rewrite carry_full_length by (repeat rewrite carry_full_length; auto); + auto using Min.min_l; auto using Min.min_r. + + Lemma freeze_length : forall us, (length us = length base) -> + length (freeze us) = length us. + Proof. + unfold freeze; intros; simpl_lengths. + Qed. + + Lemma decode_firstn_succ : forall n us, (length us = length base) -> + (n < length base)%nat -> + BaseSystem.decode' (firstn (S n) base) (firstn (S n) us) = + BaseSystem.decode' (firstn n base) (firstn n us) + + nth_default 0 base n * nth_default 0 us n. + Proof. + intros. + rewrite !firstn_succ with (d := 0) by omega. + rewrite base_app, firstn_app. + autorewrite with lengths; rewrite !Min.min_l by omega. + rewrite Nat.sub_diag, firstn_firstn, firstn0, app_nil_r by omega. + rewrite skipn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + rewrite decode'_cons, decode_nil, Z.add_0_r. + reflexivity. + Qed. + + Local Hint Resolve sum_firstn_limb_widths_nonneg. + Local Hint Resolve limb_widths_nonneg. + Local Hint Resolve nth_error_value_In. + + (* TODO : move *) + Lemma sum_firstn_all_succ : forall n l, (length l <= n)%nat -> + sum_firstn l (S n) = sum_firstn l n. + Proof. + unfold sum_firstn; intros. + rewrite !firstn_all_strong by omega. + congruence. + Qed. + + Lemma decode_carry_done_upper_bound' : forall n us, carry_done us -> + (length us = length base) -> + BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n). + Proof. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ; auto. + rewrite base_length in n_lt_length. + destruct (nth_error_length_exists_value _ _ n_lt_length). + erewrite sum_firstn_succ; eauto. + rewrite Z.pow_add_r; eauto. + rewrite nth_default_base by (rewrite base_length; assumption). + rewrite Z.lt_add_lt_sub_r. + eapply Z.lt_le_trans; eauto. + rewrite Z.mul_comm at 1. + rewrite <-Z.mul_sub_distr_l. + rewrite <-Z.mul_1_r at 1. + apply Z.mul_le_mono_nonneg_l; [ apply Z.pow_nonneg; omega | ]. + replace 1 with (Z.succ 0) by reflexivity. + rewrite Z.le_succ_l, Z.lt_0_sub. + match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end. + replace x with (log_cap n); try intuition. + rewrite log_cap_eq. + apply nth_error_value_eq_nth_default; auto. + + repeat erewrite firstn_all_strong by omega. + rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). + eapply Z.le_lt_trans; [ | eauto]. + repeat erewrite firstn_all_strong by omega. + omega. + Qed. + + Lemma decode_carry_done_upper_bound : forall us, carry_done us -> + (length us = length base) -> BaseSystem.decode base us < 2 ^ k. + Proof. + unfold k; intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_upper_bound'. + Qed. + + Lemma decode_carry_done_lower_bound' : forall n us, carry_done us -> + (length us = length base) -> + 0 <= BaseSystem.decode (firstn n base) (firstn n us). + Proof. + induction n; intros; [ cbv; congruence | ]. + destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + + rewrite decode_firstn_succ by auto. + zero_bounds. + - rewrite nth_default_base by assumption. + apply Z.pow_nonneg; omega. + - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end. + intuition. + + eapply Z.le_trans; [ apply IHn; eauto | ]. + repeat rewrite firstn_all_strong by omega. + omega. + Qed. + + Lemma decode_carry_done_lower_bound : forall us, carry_done us -> + (length us = length base) -> 0 <= BaseSystem.decode base us. + Proof. + intros. + rewrite <-(firstn_all_strong base (length limb_widths)) by (rewrite <-base_length; auto). + rewrite <-(firstn_all_strong us (length limb_widths)) by (rewrite <-base_length; auto). + auto using decode_carry_done_lower_bound'. + Qed. + + + Lemma nth_default_modulus_digits' : forall d j i, + nth_default d (modulus_digits' j) i = + if lt_dec i (S j) + then (if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i) + else d. + Proof. + induction j; intros; (break_if; [| apply nth_default_out_of_bounds; rewrite modulus_digits'_length; omega]). + + replace i with 0%nat by omega. + apply nth_default_cons. + + simpl. rewrite nth_default_app. + rewrite modulus_digits'_length. + break_if. + - rewrite IHj; break_if; try omega; reflexivity. + - replace i with (S j) by omega. + rewrite Nat.sub_diag, nth_default_cons. + reflexivity. + Qed. + + Lemma nth_default_modulus_digits : forall d i, + nth_default d modulus_digits i = + if lt_dec i (length base) + then (if (eq_nat_dec i 0) then max_bound i - c + 1 else max_bound i) + else d. + Proof. + unfold modulus_digits; intros. + rewrite nth_default_modulus_digits'. + replace (S (length base - 1)) with (length base) by omega. + reflexivity. + Qed. + + Lemma carry_done_modulus_digits : carry_done modulus_digits. + Proof. + apply carry_done_bounds; [apply modulus_digits_length | ]. + intros. + rewrite nth_default_modulus_digits. + break_if; [ | split; auto; omega]. + break_if; subst; split; auto; try rewrite <- max_bound_log_cap; omega. + Qed. + Local Hint Resolve carry_done_modulus_digits. + + (* TODO : move *) + Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) -> + decode us = x -> + BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus -> + decode vs = x. + Proof. + unfold decode; intros until 2; intros decode_us_x BSdecode_eq. + rewrite ZToField_mod in decode_us_x |- *. + rewrite <-BSdecode_eq. + assumption. + Qed. + + Ltac simpl_list_lengths := repeat match goal with + | H : appcontext[length (@nil ?A)] |- _ => rewrite (@nil_length0 A) in H + | H : appcontext[length (_ :: _)] |- _ => rewrite length_cons in H + | |- appcontext[length (@nil ?A)] => rewrite (@nil_length0 A) + | |- appcontext[length (_ :: _)] => rewrite length_cons + end. + + Lemma map2_app : forall A B C (f : A -> B -> C) ls1 ls2 ls1' ls2', + (length ls1 = length ls2) -> + map2 f (ls1 ++ ls1') (ls2 ++ ls2') = map2 f ls1 ls2 ++ map2 f ls1' ls2'. + Proof. + induction ls1, ls2; intros; rewrite ?map2_nil_r, ?app_nil_l; try congruence; + simpl_list_lengths; try omega. + rewrite <-!app_comm_cons, !map2_cons. + rewrite IHls1; auto. + Qed. + + Lemma decode_map2_sub : forall us vs, + (length us = length vs) -> + BaseSystem.decode' base (map2 (fun x y => x - y) us vs) + = BaseSystem.decode' base us - BaseSystem.decode' base vs. + Proof. + induction us using rev_ind; induction vs using rev_ind; + intros; autorewrite with lengths in *; simpl_list_lengths; + rewrite ?decode_nil; try omega. + rewrite map2_app by omega. + rewrite map2_cons, map2_nil_l. + rewrite !set_higher. + autorewrite with lengths. + rewrite Min.min_l by omega. + rewrite IHus by omega. + replace (length vs) with (length us) by omega. + ring. + Qed. + + Lemma decode_modulus_digits' : forall i, (i <= length base)%nat -> + BaseSystem.decode' base (modulus_digits' i) = 2 ^ (sum_firstn limb_widths (S i)) - c. + Proof. + induction i; intros; unfold modulus_digits'; fold modulus_digits'. + + case_eq base; + [ intro base_eq; rewrite base_eq, (@nil_length0 Z) in lt_1_length_base; omega | ]. + intros z ? base_eq. + rewrite decode'_cons, decode_nil, Z.add_0_r. + replace z with (nth_default 0 base 0) by (rewrite base_eq; auto). + rewrite nth_default_base by omega. + replace (max_bound 0 - c + 1) with (Z.succ (max_bound 0) - c) by ring. + rewrite max_bound_log_cap. + rewrite sum_firstn_succ with (x := log_cap 0) by (rewrite log_cap_eq; + apply nth_error_Some_nth_default; rewrite <-base_length; omega). + rewrite Z.pow_add_r by auto. + cbv [sum_firstn fold_right firstn]. + ring. + + assert (S i < length base \/ S i = length base)%nat as cases by omega. + destruct cases. + - rewrite sum_firstn_succ with (x := log_cap (S i)) by + (rewrite log_cap_eq; apply nth_error_Some_nth_default; + rewrite <-base_length; omega). + rewrite Z.pow_add_r, <-max_bound_log_cap, set_higher by auto. + rewrite IHi, modulus_digits'_length, nth_default_base by omega. + ring. + - rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). + rewrite decode'_splice, modulus_digits'_length, firstn_all by auto. + rewrite skipn_all, decode_base_nil, Z.add_0_r by omega. + apply IHi. + omega. + Qed. + + Lemma decode_modulus_digits : BaseSystem.decode' base modulus_digits = modulus. + Proof. + unfold modulus_digits; rewrite decode_modulus_digits' by omega. + replace (S (length base - 1)) with (length base) by omega. + rewrite base_length. + fold k. unfold c. + ring. + Qed. + + Lemma map_land_max_ones_modulus_digits' : forall i, + map (Z.land max_ones) (modulus_digits' i) = (modulus_digits' i). + Proof. + induction i; intros. + + cbv [modulus_digits' map]. + f_equal. + apply land_max_ones_noop with (i := 0%nat). + rewrite <-max_bound_log_cap. + omega. + + unfold modulus_digits'; fold modulus_digits'. + rewrite map_app. + f_equal; [ apply IHi; omega | ]. + cbv [map]; f_equal. + apply land_max_ones_noop with (i := S i). + rewrite <-max_bound_log_cap. + split; auto; omega. + Qed. + + Lemma map_land_max_ones_modulus_digits : map (Z.land max_ones) modulus_digits = modulus_digits. + Proof. + apply map_land_max_ones_modulus_digits'. + Qed. + + Opaque modulus_digits. + + Lemma map_land_zero : forall ls, map (Z.land 0) ls = BaseSystem.zeros (length ls). + Proof. + induction ls; boring. + Qed. + + Lemma carry_full_preserves_Fdecode : forall us x, (length us = length base) -> + decode us = x -> decode (carry_full us) = x. + Proof. + intros. + apply carry_full_preserves_rep; auto. + unfold rep; auto. + Qed. + + Lemma freeze_preserves_rep : forall us x, rep us x -> rep (freeze us) x. + Proof. + unfold rep; intros. + intuition; rewrite ?freeze_length; auto. + unfold freeze, and_term. + break_if. + + apply decode_mod with (us := carry_full (carry_full (carry_full us))). + - rewrite carry_full_3_length; auto. + - autorewrite with lengths. + apply Min.min_r. + simpl_lengths; omega. + - repeat apply carry_full_preserves_rep; repeat rewrite carry_full_length; auto. + unfold rep; intuition. + - rewrite decode_map2_sub by (simpl_lengths; omega). + rewrite map_land_max_ones_modulus_digits. + rewrite decode_modulus_digits. + destruct (Z_eq_dec modulus 0); [ subst; rewrite !Zmod_0_r; reflexivity | ]. + rewrite <-Z.add_opp_r. + replace (-modulus) with (-1 * modulus) by ring. + symmetry; auto using Z.mod_add. + + eapply decode_mod; eauto. + simpl_lengths. + rewrite map_land_zero, decode_map2_sub, zeros_rep, Z.sub_0_r by simpl_lengths. + match goal with H : decode ?us = ?x |- _ => erewrite Fdecode_decode_mod; eauto; + do 3 apply carry_full_preserves_Fdecode in H; simpl_lengths + end. + erewrite Fdecode_decode_mod; eauto; simpl_lengths. + Qed. + Hint Resolve freeze_preserves_rep. + + Lemma isFull_true_iff : forall us, (length us = length base) -> (isFull us = true <-> + max_bound 0 - c < nth_default 0 us 0 + /\ (forall i, (0 < i <= length base - 1)%nat -> nth_default 0 us i = max_bound i)). + Proof. + unfold isFull; intros; auto using isFull'_true_iff. + Qed. + + Definition minimal_rep us := BaseSystem.decode base us = (BaseSystem.decode base us) mod modulus. + + Fixpoint compare' us vs i := + match i with + | O => Eq + | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i') + then compare' us vs i' + else Z.compare (nth_default 0 us i') (nth_default 0 vs i') + end. + + (* Lexicographically compare two vectors of equal length, starting from the END of the list + (in our context, this is the most significant end). NOT constant time. *) + Definition compare us vs := compare' us vs (length us). + + Lemma compare'_Eq : forall us vs i, (length us = length vs) -> + compare' us vs i = Eq -> firstn i us = firstn i vs. + Proof. + induction i; intros; [ cbv; congruence | ]. + destruct (lt_dec i (length us)). + + repeat rewrite firstn_succ with (d := 0) by omega. + match goal with H : compare' _ _ (S _) = Eq |- _ => + inversion H end. + break_if; f_equal; auto. + - f_equal; auto. + - rewrite Z.compare_eq_iff in *. congruence. + - rewrite Z.compare_eq_iff in *. congruence. + + rewrite !firstn_all_strong in IHi by omega. + match goal with H : compare' _ _ (S _) = Eq |- _ => + inversion H end. + rewrite (nth_default_out_of_bounds i us) in * by omega. + rewrite (nth_default_out_of_bounds i vs) in * by omega. + break_if; try congruence. + f_equal; auto. + Qed. + + Lemma compare_Eq : forall us vs, (length us = length vs) -> + compare us vs = Eq -> us = vs. + Proof. + intros. + erewrite <-(firstn_all _ us); eauto. + erewrite <-(firstn_all _ vs); eauto. + apply compare'_Eq; auto. + Qed. + + Lemma decode_lt_next_digit : forall us n, (length us = length base) -> + (n < length base)%nat -> (n < length us)%nat -> + carry_done us -> + BaseSystem.decode' (firstn n base) (firstn n us) < + (nth_default 0 base n). + Proof. + induction n; intros ? ? ? bounded. + + cbv [firstn]. + rewrite decode_base_nil. + apply Z.gt_lt; auto using nth_default_base_positive. + + rewrite decode_firstn_succ by (auto || omega). + rewrite nth_default_base_succ by omega. + eapply Z.lt_le_trans. + - apply Z.add_lt_mono_r. + apply IHn; auto; omega. + - rewrite <-(Z.mul_1_r (nth_default 0 base n)) at 1. + rewrite <-Z.mul_add_distr_l, Z.mul_comm. + apply Z.mul_le_mono_pos_r. + * apply Z.gt_lt. apply nth_default_base_positive; omega. + * rewrite Z.add_1_l. + apply Z.le_succ_l. + rewrite carry_done_bounds in bounded by assumption. + apply bounded. + Qed. + + Lemma highest_digit_determines : forall us vs n x, (x < 0) -> + (length us = length base) -> + (length vs = length base) -> + (n < length us)%nat -> carry_done us -> + (n < length vs)%nat -> carry_done vs -> + BaseSystem.decode (firstn n base) (firstn n us) + + nth_default 0 base n * x - + BaseSystem.decode (firstn n base) (firstn n vs) < 0. + Proof. + intros. + eapply Z.le_lt_trans. + + apply Z.le_sub_nonneg. + apply decode_carry_done_lower_bound'; auto. + + eapply Z.le_lt_trans. + - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. + apply Z.mul_le_mono_nonneg_l; try omega. + rewrite nth_default_base by omega; apply Z.pow_nonneg; omega. + - ring_simplify. + apply Z.lt_sub_0. + apply decode_lt_next_digit; auto. + omega. + Qed. + + Lemma Z_compare_decode_step_eq : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + (nth_default 0 us n = nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (BaseSystem.decode (firstn n base) us ?= + BaseSystem.decode (firstn n base) vs). + Proof. + intros until 3; intro nth_default_eq. + destruct (lt_dec n (length us)); try omega. + rewrite firstn_succ with (d := 0), !base_app by omega. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub, nth_default_eq, Z.add_add_simpl_r_r. + rewrite BaseSystem.decode'_truncate with (us := us). + rewrite BaseSystem.decode'_truncate with (us := vs). + rewrite firstn_length, Min.min_l, <-Z.compare_sub by omega. + reflexivity. + Qed. + + Lemma Z_compare_decode_step_lt : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n < nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = Lt. + Proof. + intros until 5; intro nth_default_lt. + destruct (lt_dec n (length us)). + + rewrite firstn_succ with (d := 0) by omega. + rewrite !base_app. + autorewrite with lengths; rewrite Min.min_l by omega. + do 2 (rewrite skipn_nth_default with (d := 0) by omega; + rewrite decode'_cons, decode_base_nil, Z.add_0_r). + rewrite Z.compare_sub. + apply Z.compare_lt_iff. + ring_simplify. + rewrite <-Z.add_sub_assoc. + rewrite <-Z.mul_sub_distr_l. + apply highest_digit_determines; auto; omega. + + rewrite !nth_default_out_of_bounds in nth_default_lt; omega. + Qed. + + Lemma Z_compare_decode_step_neq : forall n us vs, + (length us = length base) -> (length us = length vs) -> + (S n <= length base)%nat -> + carry_done us -> carry_done vs -> + (nth_default 0 us n <> nth_default 0 vs n) -> + (BaseSystem.decode (firstn (S n) base) us ?= + BaseSystem.decode (firstn (S n) base) vs) = + (nth_default 0 us n ?= nth_default 0 vs n). + Proof. + intros. + destruct (Z_dec (nth_default 0 us n) (nth_default 0 vs n)) as [[?|Hgt]|?]; try congruence. + + etransitivity; try apply Z_compare_decode_step_lt; auto. + + match goal with |- (?a ?= ?b) = (?c ?= ?d) => + rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end. + apply CompOpp_inj; rewrite !CompOpp_involutive. + apply gt_lt_symmetry in Hgt. + etransitivity; try apply Z_compare_decode_step_lt; auto; omega. + Qed. + + Lemma decode_compare' : forall n us vs, + (length us = length base) -> + (length us = length vs) -> + (n <= length base)%nat -> + carry_done us -> carry_done vs -> + (BaseSystem.decode (firstn n base) us ?= BaseSystem.decode (firstn n base) vs) + = compare' us vs n. + Proof. + induction n; intros. + + cbv [firstn compare']; rewrite !decode_base_nil; auto. + + unfold compare'; fold compare'. + break_if. + - rewrite Z_compare_decode_step_eq by (auto || omega). + apply IHn; auto; omega. + - rewrite Z_compare_decode_step_neq; (auto || omega). + Qed. + + Lemma decode_compare : forall us vs, + (length us = length base) -> carry_done us -> + (length vs = length base) -> carry_done vs -> + Z.compare (BaseSystem.decode base us) (BaseSystem.decode base vs) = compare us vs. + Proof. + unfold compare; intros. + erewrite <-(firstn_all _ base). + + apply decode_compare'; auto; omega. + + assumption. + Qed. + + Lemma compare'_succ : forall us j vs, compare' us vs (S j) = + if Z.eq_dec (nth_default 0 us j) (nth_default 0 vs j) + then compare' us vs j + else nth_default 0 us j ?= nth_default 0 vs j. + Proof. + reflexivity. + Qed. + + Lemma compare'_firstn_r_small_index : forall us j vs, (j <= length vs)%nat -> + compare' us vs j = compare' us (firstn j vs) j. + Proof. + induction j; intros; auto. + rewrite !compare'_succ by omega. + rewrite firstn_succ with (d := 0) by omega. + rewrite nth_default_app. + simpl_lengths. + rewrite Min.min_l by omega. + destruct (lt_dec j j); try omega. + rewrite Nat.sub_diag. + rewrite nth_default_cons. + break_if; try reflexivity. + rewrite IHj with (vs := firstn j vs ++ nth_default 0 vs j :: nil) by + (autorewrite with lengths; rewrite Min.min_l; omega). + rewrite firstn_app_sharp by (autorewrite with lengths; apply Min.min_l; omega). + apply IHj; omega. + Qed. + + Lemma compare'_firstn_r : forall us j vs, + compare' us vs j = compare' us (firstn j vs) j. + Proof. + intros. + destruct (le_dec j (length vs)). + + auto using compare'_firstn_r_small_index. + + f_equal. symmetry. + apply firstn_all_strong. + omega. + Qed. + + Lemma compare'_not_Lt : forall us vs j, j <> 0%nat -> + (forall i, (0 < i < j)%nat -> 0 <= nth_default 0 us i <= nth_default 0 vs i) -> + compare' us vs j <> Lt -> + nth_default 0 vs 0 <= nth_default 0 us 0 /\ + (forall i : nat, (0 < i < j)%nat -> nth_default 0 us i = nth_default 0 vs i). + Proof. + induction j; try congruence. + rewrite compare'_succ. + intros; destruct (eq_nat_dec j 0). + + break_if; subst; split; intros; try omega. + rewrite Z.compare_ge_iff in *; omega. + + break_if. + - split; intros; [ | destruct (eq_nat_dec i j); subst; auto ]; + apply IHj; auto; intros; try omega; + match goal with H : forall i, _ -> 0 <= ?f i <= ?g i |- 0 <= ?f _ <= ?g _ => + apply H; omega end. + - exfalso. rewrite Z.compare_ge_iff in *. + match goal with H : forall i, ?P -> 0 <= ?f i <= ?g i |- _ => + specialize (H j) end; omega. + Qed. + + Lemma isFull'_compare' : forall us j, j <> 0%nat -> (length us = length base) -> + (j <= length base)%nat -> carry_done us -> + (isFull' us true (j - 1) = true <-> compare' us modulus_digits j <> Lt). + Proof. + unfold compare; induction j; intros; try congruence. + replace (S j - 1)%nat with j by omega. + split; intros. + + simpl. + break_if; [destruct (eq_nat_dec j 0) | ]. + - subst. cbv; congruence. + - apply IHj; auto; try omega. + apply isFull'_true_step. + replace (S (j - 1)) with j by omega; auto. + - rewrite nth_default_modulus_digits in *. + repeat (break_if; try omega). + * subst. + match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_lower_bound_0 in H end. + apply Z.compare_ge_iff. + omega. + * match goal with H : isFull' _ _ _ = true |- _ => + apply isFull'_true_iff in H; try assumption; destruct H as [? eq_max_bound] end. + specialize (eq_max_bound j). + omega. + + apply isFull'_true_iff; try assumption. + match goal with H : compare' _ _ _ <> Lt |- _ => apply compare'_not_Lt in H; [ destruct H as [Hdigit0 Hnonzero] | | ] end. + - split; [ | intros i i_range; assert (0 < i < S j)%nat as i_range' by omega; + specialize (Hnonzero i i_range')]; + rewrite nth_default_modulus_digits in *; + repeat (break_if; try omega). + - congruence. + - intros. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + rewrite <-Z.lt_succ_r with (m := max_bound i). + rewrite max_bound_log_cap; apply carry_done_bounds; assumption. + Qed. + + Lemma isFull_compare : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> compare us modulus_digits <> Lt). + Proof. + unfold compare, isFull; intros ? lengths_eq. intros. + rewrite lengths_eq. + apply isFull'_compare'; try omega. + assumption. + Qed. + + Lemma isFull_decode : forall us, (length us = length base) -> carry_done us -> + (isFull us = true <-> + (BaseSystem.decode base us ?= BaseSystem.decode base modulus_digits <> Lt)). + Proof. + intros. + rewrite decode_compare; autorewrite with lengths; auto. + apply isFull_compare; auto. + Qed. + + Lemma isFull_false_upper_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = false -> + BaseSystem.decode base us < modulus. + Proof. + intros. + destruct (Z_lt_dec (BaseSystem.decode base us) modulus) as [? | nlt_modulus]; + [assumption | exfalso]. + apply Z.compare_nlt_iff in nlt_modulus. + rewrite <-decode_modulus_digits in nlt_modulus at 2. + apply isFull_decode in nlt_modulus; try assumption; congruence. + Qed. + + Lemma isFull_true_lower_bound : forall us, (length us = length base) -> + carry_done us -> isFull us = true -> + modulus <= BaseSystem.decode base us. + Proof. + intros. + rewrite <-decode_modulus_digits at 1. + apply Z.compare_ge_iff. + apply isFull_decode; auto. + Qed. + + Lemma freeze_in_bounds : forall us, + pre_carry_bounds us -> (length us = length base) -> + carry_done (freeze us). + Proof. + unfold freeze, and_term; intros ? PCB lengths_eq. + rewrite carry_done_bounds by simpl_lengths; intro i. + rewrite nth_default_map2 with (d1 := 0) (d2 := 0). + simpl_lengths. + break_if; [ | split; (omega || auto)]. + break_if. + + rewrite map_land_max_ones_modulus_digits. + apply isFull_true_iff in Heqb; [ | simpl_lengths]. + destruct Heqb as [first_digit high_digits]. + destruct (eq_nat_dec i 0). + - subst. + clear high_digits. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done. + rewrite carry_done_bounds in cf3_done by simpl_lengths. + specialize (cf3_done 0%nat). + omega. + - assert ((0 < i <= length base - 1)%nat) as i_range by + (simpl_lengths; apply lt_min_l in l; omega). + specialize (high_digits i i_range). + clear first_digit i_range. + rewrite high_digits. + rewrite <-max_bound_log_cap. + rewrite nth_default_modulus_digits. + repeat (break_if; try omega). + * rewrite Z.sub_diag. + split; try omega. + apply Z.lt_succ_r; auto. + * rewrite Z.lt_succ_r, Z.sub_0_r. split; (omega || auto). + + rewrite map_land_zero, nth_default_zeros. + rewrite Z.sub_0_r. + apply carry_done_bounds; [ simpl_lengths | ]. + auto using carry_full_3_done. + Qed. + Local Hint Resolve freeze_in_bounds. + + Local Hint Resolve carry_full_3_done. + + Lemma freeze_minimal_rep : forall us, pre_carry_bounds us -> (length us = length base) -> + minimal_rep (freeze us). + Proof. + unfold minimal_rep, freeze, and_term. + intros. + symmetry. apply Z.mod_small. + split; break_if; rewrite decode_map2_sub; simpl_lengths. + + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits. + apply Z.le_0_sub. + apply isFull_true_lower_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply decode_carry_done_lower_bound; simpl_lengths. + + rewrite map_land_max_ones_modulus_digits, decode_modulus_digits. + rewrite Z.lt_sub_lt_add_r. + apply Z.lt_le_trans with (m := 2 * modulus); try omega. + eapply Z.lt_le_trans; [ | apply two_pow_k_le_2modulus ]. + apply decode_carry_done_upper_bound; simpl_lengths. + + rewrite map_land_zero, zeros_rep, Z.sub_0_r. + apply isFull_false_upper_bound; simpl_lengths. + Qed. + Local Hint Resolve freeze_minimal_rep. + + Lemma rep_decode_mod : forall us vs x, rep us x -> rep vs x -> + (BaseSystem.decode base us) mod modulus = (BaseSystem.decode base vs) mod modulus. + Proof. + unfold rep, decode; intros. + intuition. + repeat rewrite <-FieldToZ_ZToField. + congruence. + Qed. + + Lemma minimal_rep_unique : forall us vs x, + rep us x -> minimal_rep us -> carry_done us -> + rep vs x -> minimal_rep vs -> carry_done vs -> + us = vs. + Proof. + intros. + match goal with Hrep1 : rep _ ?x, Hrep2 : rep _ ?x |- _ => + pose proof (rep_decode_mod _ _ _ Hrep1 Hrep2) as eqmod end. + repeat match goal with Hmin : minimal_rep ?us |- _ => unfold minimal_rep in Hmin; + rewrite <- Hmin in eqmod; clear Hmin end. + apply Z.compare_eq_iff in eqmod. + rewrite decode_compare in eqmod; unfold rep in *; auto; intuition; try congruence. + apply compare_Eq; auto. + congruence. + Qed. + + Lemma freeze_canonical : forall us vs x, + pre_carry_bounds us -> rep us x -> + pre_carry_bounds vs -> rep vs x -> freeze us = freeze vs. - Admitted. + Proof. + intros. + assert (length us = length base) by (unfold rep in *; intuition). + assert (length vs = length base) by (unfold rep in *; intuition). + eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption. + Qed. End CanonicalizationProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/Pre.v b/src/ModularArithmetic/Pre.v index 2978fdd42..fca5576b7 100644 --- a/src/ModularArithmetic/Pre.v +++ b/src/ModularArithmetic/Pre.v @@ -2,6 +2,7 @@ Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.Numbers.BinNums Coq.ZArit Require Import Coq.Logic.Eqdep_dec. Require Import Coq.Logic.EqdepFacts. Require Import Crypto.Tactics.VerdiTactics. +Require Import Coq.omega.Omega. Lemma Z_mod_mod x m : x mod m = (x mod m) mod m. symmetry. @@ -46,7 +47,7 @@ Defined. Definition mulmod m := fun a b => a * b mod m. Definition powmod_pos m := Pos.iter_op (mulmod m). Definition powmod m a x := match x with N0 => 1 mod m | Npos p => powmod_pos m p (a mod m) end. - + Lemma mulmod_assoc: forall m x y z : Z, mulmod m x (mulmod m y z) = mulmod m (mulmod m x y) z. Proof. @@ -144,7 +145,7 @@ Definition mod_inv_eucl (a m:Z) : Z. (match d with Z.pos _ => u | _ => -u end) end) mod m). Defined. - + Lemma reduced_nonzero_pos: forall a m : Z, m > 0 -> a <> 0 -> a = a mod m -> 0 < a. Proof. @@ -209,7 +210,7 @@ Proof. unfold mod_inv_eucl; simpl. lazymatch goal with [ |- context [euclid ?a ?b] ] => destruct (euclid a b) end. auto. - - + - destruct a. cbv [proj1_sig mod_inv_eucl_sig]. rewrite Z.mul_comm. @@ -217,4 +218,4 @@ Proof. rewrite mod_inv_eucl_correct; eauto. intro; destruct H0. eapply exist_reduced_eq. congruence. -Qed.
\ No newline at end of file +Qed. diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v index 70a2c4a87..2021e8514 100644 --- a/src/ModularArithmetic/PrimeFieldTheorems.v +++ b/src/ModularArithmetic/PrimeFieldTheorems.v @@ -10,6 +10,7 @@ Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.ZArith.ZArith Coq.ZArith. Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Util.NumTheoryUtil Crypto.Util.ZUtil. Require Import Crypto.Util.Tactics. +Require Crypto.Algebra. Existing Class prime. @@ -51,6 +52,14 @@ Section FieldModuloPre. Proof. constructor; auto using Fring_theory, Fq_1_neq_0, F_mul_inv_l. Qed. + + Global Instance field_modulo : @Algebra.field (F q) Logic.eq (ZToField 0) (ZToField 1) opp add sub mul inv div. + Proof. + constructor; try solve_proper. + - apply commutative_ring_modulo. + - split. auto using F_mul_inv_l. + - split. auto using Fq_1_neq_0. + Qed. End FieldModuloPre. Module Type PrimeModulus. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v index 1a7b3316e..49b1875ce 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -32,7 +32,7 @@ Section PseudoMersenneBaseParamProofs. unfold value in *. congruence. Qed. - + Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat -> nth_error base i = Some b -> nth_error limb_widths i = Some w -> @@ -45,7 +45,7 @@ Section PseudoMersenneBaseParamProofs. case_eq i; intros; subst. + subst; apply nth_error_first in nth_err_w. apply nth_error_first in nth_err_b; subst. - apply map_nth_error. + apply map_nth_error. case_eq l; intros; subst; [simpl in *; omega | ]. unfold base_from_limb_widths; fold base_from_limb_widths. reflexivity. @@ -65,7 +65,7 @@ Section PseudoMersenneBaseParamProofs. apply nth_error_first in H. subst; eauto. Qed. - + Lemma sum_firstn_succ : forall l i x, nth_error l i = Some x -> sum_firstn l (S i) = x + sum_firstn l i. @@ -89,6 +89,13 @@ Section PseudoMersenneBaseParamProofs. - rewrite IHl by auto; ring. Qed. + Lemma limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w. + Proof. + intros. + apply Z.lt_le_incl. + auto using limb_widths_pos. + Qed. + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. Proof. unfold sum_firstn; intros. @@ -110,7 +117,7 @@ Section PseudoMersenneBaseParamProofs. induction i; intros. + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. - + + + assert (i < length base)%nat as lt_i_length by omega. specialize (IHi lt_i_length). rewrite base_length in lt_i_length. @@ -131,7 +138,7 @@ Section PseudoMersenneBaseParamProofs. apply limb_widths_nonneg. eapply nth_error_value_In; eauto. Qed. - + Lemma nth_default_base : forall d i, (i < length base)%nat -> nth_default d base i = 2 ^ (sum_firstn limb_widths i). Proof. @@ -171,7 +178,7 @@ Section PseudoMersenneBaseParamProofs. + rewrite base_length in *; apply limb_widths_match_modulus; assumption. Qed. - Lemma base_succ : forall i, ((S i) < length base)%nat -> + Lemma base_succ : forall i, ((S i) < length base)%nat -> nth_default 0 base (S i) mod nth_default 0 base i = 0. Proof. intros. @@ -219,7 +226,7 @@ Section PseudoMersenneBaseParamProofs. Proof. unfold base; case_eq limb_widths; intros; [pose proof limb_widths_nonnil; congruence | reflexivity]. Qed. - + Lemma base_good : forall i j : nat, (i + j < length base)%nat -> let b := nth_default 0 base in diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v index 3914d6219..e20a7ed09 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParams.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v @@ -7,7 +7,7 @@ Definition sum_firstn l n := fold_right Z.add 0 (firstn n l). Class PseudoMersenneBaseParams (modulus : Z) := { limb_widths : list Z; - limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w; + limb_widths_pos : forall w, In w limb_widths -> 0 < w; limb_widths_nonnil : limb_widths <> nil; limb_widths_good : forall i j, (i + j < length limb_widths)%nat -> sum_firstn limb_widths (i + j) <= diff --git a/src/ModularArithmetic/PseudoMersenneBaseRep.v b/src/ModularArithmetic/PseudoMersenneBaseRep.v index c16cc8d38..6b4d29a35 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseRep.v +++ b/src/ModularArithmetic/PseudoMersenneBaseRep.v @@ -25,7 +25,7 @@ Class RepZMod (modulus : Z) := { Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := { coeff : BaseSystem.digits; - coeff_length : (length coeff <= length PseudoMersenneBaseParamProofs.base)%nat; + coeff_length : (length coeff = length PseudoMersenneBaseParamProofs.base)%nat; coeff_mod: (BaseSystem.decode PseudoMersenneBaseParamProofs.base coeff) mod m = 0 }. @@ -45,6 +45,6 @@ Instance PseudoMersenneBase m (prm : PseudoMersenneBaseParams m) (sc : Subtracti sub := ModularBaseSystem.sub coeff coeff_mod; sub_rep := ModularBaseSystemProofs.sub_rep coeff coeff_mod coeff_length; - mul := ModularBaseSystem.mul; - mul_rep := ModularBaseSystemProofs.mul_rep + mul := ModularBaseSystem.carry_mul; + mul_rep := ModularBaseSystemProofs.carry_mul_rep }. diff --git a/src/ModularArithmetic/Tutorial.v b/src/ModularArithmetic/Tutorial.v index d6c7fa4b8..7d354ab3e 100644 --- a/src/ModularArithmetic/Tutorial.v +++ b/src/ModularArithmetic/Tutorial.v @@ -9,9 +9,9 @@ Section Mod24. (* Specify modulus *) Let q := 24. - + (* Boilerplate for letting Z numbers be interpreted as field elements *) - Local Coercion ZToFq := ZToField : BinNums.Z -> F q. Hint Unfold ZToFq. + Let ZToFq := ZToField : BinNums.Z -> F q. Hint Unfold ZToFq. Local Coercion ZToFq : Z >-> F. (* Boilerplate for [ring]. Similar boilerplate works for [field] if the modulus is prime . *) @@ -21,7 +21,7 @@ Section Mod24. postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], constants [Fconstant], div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + power_tac (@Fpower_theory q) [Fexp_tac]). Lemma sumOfSquares : forall a b: F q, (a+b)^2 = a^2 + 2*a*b + b^2. Proof. @@ -37,9 +37,9 @@ Section Modq. (* Set notations + - * / refer to F operations *) Local Open Scope F_scope. - + (* Boilerplate for letting Z numbers be interpreted as field elements *) - Local Coercion ZToFq := ZToField : BinNums.Z -> F q. Hint Unfold ZToFq. + Let ZToFq := ZToField : BinNums.Z -> F q. Hint Unfold ZToFq. Local Coercion ZToFq : Z >-> F. (* Boilerplate for [field]. Similar boilerplate works for [ring] if the modulus is not prime . *) @@ -49,7 +49,7 @@ Section Modq. postprocess [Fpostprocess; try exact Fq_1_neq_0; try assumption], constants [Fconstant], div (@Fmorph_div_theory q), - power_tac (@Fpower_theory q) [Fexp_tac]). + power_tac (@Fpower_theory q) [Fexp_tac]). Lemma sumOfSquares' : forall a b c: F q, c <> 0 -> ((a+b)/c)^2 = a^2/c^2 + ZToField 2*(a/c)*(b/c) + b^2/c^2. Proof. @@ -170,7 +170,7 @@ Module TimesZeroParametricTestModule (M: PrimeModulus). field; try exact Fq_1_neq_0. Qed. - Lemma biggerFraction : forall XP YP ZP TP XQ YQ ZQ TQ d : F modulus, + Lemma biggerFraction : forall XP YP ZP TP XQ YQ ZQ TQ d : F modulus, ZQ <> 0 -> ZP <> 0 -> ZP * ZQ * ZP * ZQ + d * XP * XQ * YP * YQ <> 0 -> @@ -187,4 +187,3 @@ Module TimesZeroParametricTestModule (M: PrimeModulus). field; assumption. Qed. End TimesZeroParametricTestModule. - diff --git a/src/Rep.v b/src/Rep.v deleted file mode 100644 index b7e7f10c5..000000000 --- a/src/Rep.v +++ /dev/null @@ -1,13 +0,0 @@ -Class RepConversions (T:Type) (RT:Type) : Type := - { - toRep : T -> RT; - unRep : RT -> T - }. - -Definition RepConversionsOK {T RT} (RC:RepConversions T RT) := forall x, unRep (toRep x) = x. - -Definition RepFunOK {T RT} `(RC:RepConversions T RT) (f:T->T) (rf : RT -> RT) := - forall x, f (unRep x) = unRep (rf x). - -Definition RepBinOpOK {T RT} `(RC:RepConversions T RT) (op:T->T->T) (rop : RT -> RT -> RT) := - forall x y, op (unRep x) (unRep y) = unRep (rop x y). diff --git a/src/Spec/CompleteEdwardsCurve.v b/src/Spec/CompleteEdwardsCurve.v index 3348be1d9..06c3f8fdb 100644 --- a/src/Spec/CompleteEdwardsCurve.v +++ b/src/Spec/CompleteEdwardsCurve.v @@ -1,46 +1,45 @@ -Require Coq.ZArith.BinInt Coq.ZArith.Znumtheory. - Require Crypto.CompleteEdwardsCurve.Pre. -Require Import Crypto.Spec.ModularArithmetic. -Local Open Scope F_scope. - -Class TwistedEdwardsParams := { - q : BinInt.Z; - a : F q; - d : F q; - prime_q : Znumtheory.prime q; - two_lt_q : BinInt.Z.lt 2 q; - nonzero_a : a <> 0; - square_a : exists sqrt_a, sqrt_a^2 = a; - nonsquare_d : forall x, x^2 <> d -}. - Module E. Section TwistedEdwardsCurves. - Context {prm:TwistedEdwardsParams}. - (* Twisted Edwards curves with complete addition laws. References: * <https://eprint.iacr.org/2008/013.pdf> * <http://ed25519.cr.yp.to/ed25519-20110926.pdf> * <https://eprint.iacr.org/2015/677.pdf> *) - Definition onCurve P := let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2. - Definition point := { P | onCurve P}. - - Definition zero : point := exist _ (0, 1) (@Pre.zeroOnCurve _ _ _ prime_q). - - Definition add' P1' P2' := - let '(x1, y1) := P1' in - let '(x2, y2) := P2' in - (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2))). - - Definition add (P1 P2 : point) : point := - let 'exist P1' pf1 := P1 in - let 'exist P2' pf2 := P2 in - exist _ (add' P1' P2') - (@Pre.unifiedAdd'_onCurve _ _ _ prime_q two_lt_q nonzero_a square_a nonsquare_d _ _ pf1 pf2). - + + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} `{Algebra.field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv}. + Local Infix "=" := Feq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := Fzero. Local Notation "1" := Fone. + Local Infix "+" := Fadd. Local Infix "*" := Fmul. + Local Infix "-" := Fsub. Local Infix "/" := Fdiv. + Local Notation "x ^2" := (x*x) (at level 30). + + Context {a d: F}. + Class twisted_edwards_params := + { + char_gt_2 : 1 + 1 <> 0; + nonzero_a : a <> 0; + square_a : exists sqrt_a, sqrt_a^2 = a; + nonsquare_d : forall x, x^2 <> d + }. + Context `{twisted_edwards_params}. + + Definition point := { P | let '(x,y) := P in a*x^2 + y^2 = 1 + d*x^2*y^2 }. + Definition coordinates (P:point) : (F*F) := proj1_sig P. + + (** The following points are indeed on the curve -- see [CompleteEdwardsCurve.Pre] for proof *) + Local Obligation Tactic := intros; apply Pre.zeroOnCurve + || apply (Pre.unifiedAdd'_onCurve (char_gt_2:=char_gt_2) (d_nonsquare:=nonsquare_d) + (a_nonzero:=nonzero_a) (a_square:=square_a) _ _ (proj2_sig _) (proj2_sig _)). + + Program Definition zero : point := (0, 1). + + Program Definition add (P1 P2:point) : point := exist _ ( + let (x1, y1) := coordinates P1 in + let (x2, y2) := coordinates P2 in + (((x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2)) , ((y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2)))) _. + Fixpoint mul (n:nat) (P : point) : point := match n with | O => zero @@ -48,7 +47,7 @@ Module E. end. End TwistedEdwardsCurves. End E. - + Delimit Scope E_scope with E. Infix "+" := E.add : E_scope. Infix "*" := E.mul : E_scope.
\ No newline at end of file diff --git a/src/Spec/EdDSA.v b/src/Spec/EdDSA.v index 99f0766e0..d71f2ad44 100644 --- a/src/Spec/EdDSA.v +++ b/src/Spec/EdDSA.v @@ -1,87 +1,80 @@ Require Import Crypto.Spec.Encoding. -Require Import Crypto.Spec.ModularArithmetic. -Require Import Crypto.Spec.CompleteEdwardsCurve. - -Require Import Crypto.Util.WordUtil. -Require Bedrock.Word. +Require Bedrock.Word Crypto.Util.WordUtil. Require Coq.ZArith.Znumtheory Coq.ZArith.BinInt. Require Coq.Numbers.Natural.Peano.NPeano. Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Coercion Word.wordToNat : Word.word >-> nat. +Local Infix "^" := NPeano.pow. +Local Infix "mod" := NPeano.modulo (at level 40, no associativity). +Local Infix "++" := Word.combine. + +Generalizable All Variables. +Section EdDSA. + Class EdDSA (* <https://eprint.iacr.org/2015/677.pdf> *) + {E Eeq Eadd Ezero Eopp} {EscalarMult} (* the underllying elliptic curve operations *) -Infix "^" := NPeano.pow. -Infix "mod" := NPeano.modulo. -Infix "++" := Word.combine. + {b : nat} (* public keys are k bits, signatures are 2*k bits *) + {H : forall {n}, Word.word n -> Word.word (b + b)} (* main hash function *) + {c : nat} (* cofactor E = 2^c *) + {n : nat} (* secret keys are (n+1) bits *) + {l : nat} (* order of the subgroup of E generated by B *) -Section EdDSAParams. + {B : E} (* base point *) - Class EdDSAParams := { (* <https://eprint.iacr.org/2015/677.pdf> *) - E : TwistedEdwardsParams; (* underlying elliptic curve *) + {PointEncoding : canonical encoding of E as Word.word b} (* wire format *) + {FlEncoding : canonical encoding of { n | n < l } as Word.word b} + := + { + EdDSA_group:@Algebra.group E Eeq Eadd Ezero Eopp; - b : nat; (* public keys are k bits, signatures are 2*k bits *) - b_valid : 2^(b - 1) > BinInt.Z.to_nat q; - FqEncoding : canonical encoding of F q as Word.word (b-1); - PointEncoding : canonical encoding of E.point as Word.word b; + EdDSA_c_valid : c = 2 \/ c = 3; - H : forall {n}, Word.word n -> Word.word (b + b); (* main hash function *) + EdDSA_n_ge_c : n >= c; + EdDSA_n_le_b : n <= b; - c : nat; (* cofactor E = 2^c *) - c_valid : c = 2 \/ c = 3; + EdDSA_B_not_identity : B <> Ezero; - n : nat; (* secret keys are (n+1) bits *) - n_ge_c : n >= c; - n_le_b : n <= b; + EdDSA_l_prime : Znumtheory.prime (BinInt.Z.of_nat l); + EdDSA_l_odd : l > 2; + EdDSA_l_order_B : EscalarMult l B = Ezero + }. + Global Existing Instance EdDSA_group. - B : E.point; - B_not_identity : B <> E.zero; + Context `{prm:EdDSA}. - l : nat; (* order of the subgroup of E generated by B *) - l_prime : Znumtheory.prime (BinInt.Z.of_nat l); - l_odd : l > 2; - l_order_B : (l*B)%E = E.zero; - FlEncoding : canonical encoding of F (BinInt.Z.of_nat l) as Word.word b - }. -End EdDSAParams. + Local Infix "=" := Eeq. + Local Coercion Word.wordToNat : Word.word >-> nat. + Local Notation secretkey := (Word.word b) (only parsing). + Local Notation publickey := (Word.word b) (only parsing). + Local Notation signature := (Word.word (b + b)) (only parsing). -Section EdDSA. - Context {prm:EdDSAParams}. - Existing Instance E. - Existing Instance PointEncoding. - Existing Instance FlEncoding. - Existing Class le. - Existing Instance n_le_b. - - Notation secretkey := (Word.word b) (only parsing). - Notation publickey := (Word.word b) (only parsing). - Notation signature := (Word.word (b + b)) (only parsing). - Local Infix "==" := CompleteEdwardsCurveTheorems.E.point_eq_dec (at level 70) : E_scope . - - (* TODO: proofread curveKey and definition of n *) - Definition curveKey (sk:secretkey) : nat := - let x := wfirstn n sk in (* first half of the secret key is a scalar *) + Local Arguments H {n} _. + Local Notation wfirstn n w := (@WordUtil.wfirstn n _ w _) (only parsing). + + Require Import Omega. + Obligation Tactic := simpl; intros; try apply NPeano.Nat.mod_upper_bound; destruct prm; omega. + + Program Definition curveKey (sk:secretkey) : nat := + let x := wfirstn n (H sk) in (* hash the key, use first "half" for secret scalar *) let x := x - (x mod (2^c)) in (* it is implicitly 0 mod (2^c) *) x + 2^n. (* and the high bit is always set *) + + Local Infix "+" := Eadd. + Local Infix "*" := EscalarMult. + Definition prngKey (sk:secretkey) : Word.word b := Word.split2 b b (H sk). - Definition public (sk:secretkey) : publickey := enc (curveKey sk * B)%E. + Definition public (sk:secretkey) : publickey := enc (curveKey sk*B). - Definition sign (A_:publickey) sk {n} (M : Word.word n) := + Program Definition sign (A_:publickey) sk {n} (M : Word.word n) := let r : nat := H (prngKey sk ++ M) in (* secret nonce *) - let R : E.point := (r * B)%E in (* commitment to nonce *) + let R : E := r * B in (* commitment to nonce *) let s : nat := curveKey sk in (* secret scalar *) - let S : F (BinInt.Z.of_nat l) := ZToField (BinInt.Z.of_nat - (r + H (enc R ++ public sk ++ M) * s)) in + let S : {n|n<l} := exist _ ((r + H (enc R ++ public sk ++ M) * s) mod l) _ in enc R ++ enc S. - Definition verify (A_:publickey) {n:nat} (M : Word.word n) (sig:signature) : bool := - let R_ := Word.split1 b b sig in - let S_ := Word.split2 b b sig in - match dec S_ : option (F (BinInt.Z.of_nat l)) with None => false | Some S' => - match dec A_ : option E.point with None => false | Some A => - match dec R_ : option E.point with None => false | Some R => - if BinInt.Z.to_nat (FieldToZ S') * B == R + (H (R_ ++ A_ ++ M)) * A - then true else false - end - end - end%E. -End EdDSA.
\ No newline at end of file + (* For a [n]-bit [message] from public key [A_], validity of a signature [R_ ++ S_] *) + Inductive valid {n:nat} : Word.word n -> publickey -> signature -> Prop := + ValidityRule : forall (message:Word.word n) (A:E) (R:E) (S:nat) S_lt_l, + S * B = R + (H (enc R ++ enc A ++ message) mod l) * A + -> valid message (enc A) (enc R ++ enc (exist _ S S_lt_l)). +End EdDSA. diff --git a/src/Spec/ModularArithmetic.v b/src/Spec/ModularArithmetic.v index 76efe3d79..8ee07fe5d 100644 --- a/src/Spec/ModularArithmetic.v +++ b/src/Spec/ModularArithmetic.v @@ -26,8 +26,8 @@ Section FieldOperations. Context {m : BinInt.Z}. (* Coercion without Context {m} --> non-uniform inheritance --> Anomalies *) - Local Coercion ZToFm := ZToField : BinNums.Z -> F m. - + Let ZToFm := ZToField : BinNums.Z -> F m. Local Coercion ZToFm : BinNums.Z >-> F. + Definition add (a b:F m) : F m := ZToField (a + b). Definition mul (a b:F m) : F m := ZToField (a * b). @@ -69,4 +69,4 @@ Infix "-" := sub : F_scope. Infix "/" := div : F_scope. Infix "^" := pow : F_scope. Notation "0" := (ZToField 0) : F_scope. -Notation "1" := (ZToField 1) : F_scope.
\ No newline at end of file +Notation "1" := (ZToField 1) : F_scope. diff --git a/src/Spec/ModularWordEncoding.v b/src/Spec/ModularWordEncoding.v index d6f6bcb3c..acd2bedbd 100644 --- a/src/Spec/ModularWordEncoding.v +++ b/src/Spec/ModularWordEncoding.v @@ -28,7 +28,7 @@ Section ModularWordEncoding. | Word.WS b _ w' => b end. - Instance modular_word_encoding : canonical encoding of F m as word sz := { + Global Instance modular_word_encoding : canonical encoding of F m as word sz := { enc := Fm_enc; dec := Fm_dec; encoding_valid := diff --git a/src/Spec/PointEncoding.v b/src/Spec/PointEncoding.v deleted file mode 100644 index f4634f52f..000000000 --- a/src/Spec/PointEncoding.v +++ /dev/null @@ -1,47 +0,0 @@ -Require Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Coq.Numbers.Natural.Peano.NPeano. -Require Crypto.Encoding.EncodingTheorems. -Require Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Bedrock.Word. -Require Crypto.Tactics.VerdiTactics. -Require Crypto.Encoding.PointEncodingPre. -Obligation Tactic := eauto; exact PointEncodingPre.point_encoding_canonical. - -Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding. -Require Import Crypto.Spec.CompleteEdwardsCurve Crypto.Spec.ModularArithmetic. - -Local Open Scope F_scope. - -Section PointEncoding. - Context {prm: TwistedEdwardsParams} {sz : nat} {sz_nonzero : (0 < sz)%nat} - {bound_check : (BinInt.Z.to_nat q < NPeano.Nat.pow 2 sz)%nat} {q_5mod8 : (q mod 8 = 5)%Z} - {sqrt_minus1_valid : (@ZToField q 2 ^ BinInt.Z.to_N (q / 4)) ^ 2 = opp 1} - {FqEncoding : canonical encoding of (F q) as (Word.word sz)} - {sign_bit : F q -> bool} {sign_bit_zero : sign_bit 0 = false} - {sign_bit_opp : forall x, x <> 0 -> negb (sign_bit x) = sign_bit (opp x)}. - Existing Instance prime_q. - - Definition point_enc (p : E.point) : Word.word (S sz) := let '(x,y) := proj1_sig p in - Word.WS (sign_bit x) (enc y). - - Program Definition point_dec_with_spec : - {point_dec : Word.word (S sz) -> option E.point - | forall w x, point_dec w = Some x -> (point_enc x = w) - } := @PointEncodingPre.point_dec _ _ _ sign_bit. - - Definition point_dec := Eval hnf in (proj1_sig point_dec_with_spec). - - Definition point_encoding_valid : forall p : E.point, point_dec (point_enc p) = Some p := - @PointEncodingPre.point_encoding_valid _ _ q_5mod8 sqrt_minus1_valid _ _ sign_bit_zero sign_bit_opp. - - Definition point_encoding_canonical : forall x_enc x, point_dec x_enc = Some x -> point_enc x = x_enc := - PointEncodingPre.point_encoding_canonical. - - Instance point_encoding : canonical encoding of E.point as (Word.word (S sz)) := { - enc := point_enc; - dec := point_dec; - encoding_valid := point_encoding_valid; - encoding_canonical := point_encoding_canonical - }. -End PointEncoding.
\ No newline at end of file diff --git a/src/Specific/Ed25519.v b/src/Specific/Ed25519.v deleted file mode 100644 index 3b90b5cdf..000000000 --- a/src/Specific/Ed25519.v +++ /dev/null @@ -1,581 +0,0 @@ -Require Import Bedrock.Word. -Require Import Crypto.Spec.Ed25519. -Require Import Crypto.Tactics.VerdiTactics. -Require Import BinNat BinInt NArith Crypto.Spec.ModularArithmetic. -Require Import ModularArithmetic.ModularArithmeticTheorems. -Require Import ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.Spec.CompleteEdwardsCurve. -Require Import Crypto.Encoding.PointEncodingPre. -Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.PointEncoding. -Require Import Crypto.CompleteEdwardsCurve.ExtendedCoordinates. -Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. -Require Import Crypto.Util.IterAssocOp Crypto.Util.WordUtil Crypto.Rep. - -Local Infix "++" := Word.combine. -Local Notation " a '[:' i ']' " := (Word.split1 i _ a) (at level 40). -Local Notation " a '[' i ':]' " := (Word.split2 i _ a) (at level 40). -Local Arguments H {_} _. -Local Arguments scalarMultM1 {_} {_} _ _ _. -Local Arguments unifiedAddM1 {_} {_} _ _. - -Local Ltac set_evars := - repeat match goal with - | [ |- appcontext[?E] ] => is_evar E; let e := fresh "e" in set (e := E) - end. -Local Ltac subst_evars := - repeat match goal with - | [ e := ?E |- _ ] => is_evar E; subst e - end. - -Lemma funexp_proj {T T'} (proj : T -> T') (f : T -> T) (f' : T' -> T') x n - (f_proj : forall a, proj (f a) = f' (proj a)) - : proj (funexp f x n) = funexp f' (proj x) n. -Proof. - revert x; induction n as [|n IHn]; simpl; congruence. -Qed. - -Lemma iter_op_proj {T T' S} (proj : T -> T') (op : T -> T -> T) (op' : T' -> T' -> T') x y z - (testbit : S -> nat -> bool) (bound : nat) - (op_proj : forall a b, proj (op a b) = op' (proj a) (proj b)) - : proj (iter_op op x testbit y z bound) = iter_op op' (proj x) testbit y (proj z) bound. -Proof. - unfold iter_op. - simpl. - lazymatch goal with - | [ |- ?proj (snd (funexp ?f ?x ?n)) = snd (funexp ?f' _ ?n) ] - => pose proof (fun x0 x1 => funexp_proj (fun x => (fst x, proj (snd x))) f f' (x0, x1)) as H' - end. - simpl in H'. - rewrite <- H'. - { reflexivity. } - { intros [??]; simpl. - repeat match goal with - | [ |- context[match ?n with _ => _ end] ] - => destruct n eqn:? - | _ => progress simpl - | _ => progress subst - | _ => reflexivity - | _ => rewrite op_proj - end. } -Qed. - -Lemma B_proj : proj1_sig B = (fst(proj1_sig B), snd(proj1_sig B)). destruct B as [[]]; reflexivity. Qed. - -Require Import Coq.Setoids.Setoid. -Require Import Coq.Classes.Morphisms. -Global Instance option_rect_Proper_nd {A T} - : Proper ((pointwise_relation _ eq) ==> eq ==> eq ==> eq) (@option_rect A (fun _ => T)). -Proof. - intros ?? H ??? [|]??; subst; simpl; congruence. -Qed. - -Global Instance option_rect_Proper_nd' {A T} - : Proper ((pointwise_relation _ eq) ==> eq ==> forall_relation (fun _ => eq)) (@option_rect A (fun _ => T)). -Proof. - intros ?? H ??? [|]; subst; simpl; congruence. -Qed. - -Hint Extern 1 (Proper _ (@option_rect ?A (fun _ => ?T))) => exact (@option_rect_Proper_nd' A T) : typeclass_instances. - -Lemma option_rect_option_map : forall {A B C} (f:A->B) some none v, - option_rect (fun _ => C) (fun x => some (f x)) none v = option_rect (fun _ => C) some none (option_map f v). -Proof. - destruct v; reflexivity. -Qed. - -Axiom decode_scalar : word b -> option N. -Local Existing Instance Ed25519.FlEncoding. -Axiom decode_scalar_correct : forall x, decode_scalar x = option_map (fun x : F (Z.of_nat Ed25519.l) => Z.to_N x) (dec x). - -Local Infix "==?" := E.point_eqb (at level 70) : E_scope. -Local Infix "==?" := ModularArithmeticTheorems.F_eq_dec (at level 70) : F_scope. - -Lemma solve_for_R_eq : forall A B C, (A = B + C <-> B = A - C)%E. -Proof. - intros; split; intros; subst; unfold E.sub; - rewrite <-E.add_assoc, ?E.add_opp_r, ?E.add_opp_l, E.add_0_r; reflexivity. -Qed. - -Lemma solve_for_R : forall A B C, (A ==? B + C)%E = (B ==? A - C)%E. -Proof. - intros. - repeat match goal with |- context [(?P ==? ?Q)%E] => - let H := fresh "H" in - destruct (E.point_eq_dec P Q) as [H|H]; - (rewrite (E.point_eqb_complete _ _ H) || rewrite (E.point_eqb_neq_complete _ _ H)) - end; rewrite solve_for_R_eq in H; congruence. -Qed. - -Local Notation "'(' X ',' Y ',' Z ',' T ')'" := (mkExtended X Y Z T). -Local Notation "2" := (ZToField 2) : F_scope. - -Local Existing Instance PointEncoding. -Lemma decode_point_eq : forall (P_ Q_ : word (S (b-1))) (P Q:E.point), - dec P_ = Some P -> - dec Q_ = Some Q -> - weqb P_ Q_ = (P ==? Q)%E. -Proof. - intros. - replace P_ with (enc P) in * by (auto using encoding_canonical). - replace Q_ with (enc Q) in * by (auto using encoding_canonical). - rewrite E.point_eqb_correct. - edestruct E.point_eq_dec; (apply weqb_true_iff || apply weqb_false_iff); congruence. -Qed. - -Lemma decode_test_encode_test : forall S_ X, option_rect (fun _ : option E.point => bool) - (fun S : E.point => (S ==? X)%E) false (dec S_) = weqb S_ (enc X). -Proof. - intros. - destruct (dec S_) eqn:H. - { symmetry; eauto using decode_point_eq, encoding_valid. } - { simpl @option_rect. - destruct (weqb S_ (enc X)) eqn:Heqb; trivial. - apply weqb_true_iff in Heqb. subst. rewrite encoding_valid in H; discriminate. } -Qed. - -Definition enc' : F q * F q -> word b. -Proof. - intro x. - let enc' := (eval hnf in (@enc (@E.point curve25519params) _ _)) in - match (eval cbv [proj1_sig] in (fun pf => enc' (exist _ x pf))) with - | (fun _ => ?enc') => exact enc' - end. -Defined. - -Definition enc'_correct : @enc (@E.point curve25519params) _ _ = (fun x => enc' (proj1_sig x)) - := eq_refl. - -Definition Let_In {A P} (x : A) (f : forall a : A, P a) : P x := let y := x in f y. -Global Instance Let_In_Proper_nd {A P} - : Proper (eq ==> pointwise_relation _ eq ==> eq) (@Let_In A (fun _ => P)). -Proof. - lazy; intros; congruence. -Qed. -Lemma option_rect_function {A B C S' N' v} f - : f (option_rect (fun _ : option A => option B) S' N' v) - = option_rect (fun _ : option A => C) (fun x => f (S' x)) (f N') v. -Proof. destruct v; reflexivity. Qed. -Local Ltac commute_option_rect_Let_In := (* pull let binders out side of option_rect pattern matching *) - idtac; - lazymatch goal with - | [ |- ?LHS = option_rect ?P ?S ?N (Let_In ?x ?f) ] - => (* we want to just do a [change] here, but unification is stupid, so we have to tell it what to unfold in what order *) - cut (LHS = Let_In x (fun y => option_rect P S N (f y))); cbv beta; - [ set_evars; - let H := fresh in - intro H; - rewrite H; - clear; - abstract (cbv [Let_In]; reflexivity) - | ] - end. -Local Ltac replace_let_in_with_Let_In := - repeat match goal with - | [ |- context G[let x := ?y in @?z x] ] - => let G' := context G[Let_In y z] in change G' - | [ |- _ = Let_In _ _ ] - => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ] - end. -Local Ltac simpl_option_rect := (* deal with [option_rect _ _ _ None] and [option_rect _ _ _ (Some _)] *) - repeat match goal with - | [ |- context[option_rect ?P ?S ?N None] ] - => change (option_rect P S N None) with N - | [ |- context[option_rect ?P ?S ?N (Some ?x) ] ] - => change (option_rect P S N (Some x)) with (S x); cbv beta - end. - -Section Ed25519Frep. - Generalizable All Variables. - Context `(rcS:RepConversions N SRep) (rcSOK:RepConversionsOK rcS). - Context `(rcF:RepConversions (F (Ed25519.q)) FRep) (rcFOK:RepConversionsOK rcF). - Context (FRepAdd FRepSub FRepMul:FRep->FRep->FRep) (FRepAdd_correct:RepBinOpOK rcF add FRepMul). - Context (FRepSub_correct:RepBinOpOK rcF sub FRepSub) (FRepMul_correct:RepBinOpOK rcF mul FRepMul). - Local Notation rep2F := (unRep : FRep -> F (Ed25519.q)). - Local Notation F2Rep := (toRep : F (Ed25519.q) -> FRep). - Local Notation rep2S := (unRep : SRep -> N). - Local Notation S2Rep := (toRep : N -> SRep). - - Axiom FRepOpp : FRep -> FRep. - Axiom FRepOpp_correct : forall x, opp (rep2F x) = rep2F (FRepOpp x). - - Axiom wltu : forall {b}, word b -> word b -> bool. - Axiom wltu_correct : forall {b} (x y:word b), wltu x y = (wordToN x <? wordToN y)%N. - - Axiom compare_enc : forall x y, F_eqb x y = weqb (@enc _ _ FqEncoding x) (@enc _ _ FqEncoding y). - - Axiom wire2FRep : word (b-1) -> option FRep. - Axiom wire2FRep_correct : forall x, Fm_dec x = option_map rep2F (wire2FRep x). - - Axiom FRep2wire : FRep -> word (b-1). - Axiom FRep2wire_correct : forall x, FRep2wire x = @enc _ _ FqEncoding (rep2F x). - - Axiom SRep_testbit : SRep -> nat -> bool. - Axiom SRep_testbit_correct : forall (x0 : SRep) (i : nat), SRep_testbit x0 i = N.testbit_nat (unRep x0) i. - - Definition FSRepPow x n := iter_op FRepMul (toRep 1%F) SRep_testbit n x 255. - Lemma FSRepPow_correct : forall x n, (N.size_nat (unRep n) <= 255)%nat -> (unRep x ^ unRep n)%F = unRep (FSRepPow x n). - Proof. (* this proof derives the required formula, which I copy-pasted above to be able to reference it without the length precondition *) - unfold FSRepPow; intros. - erewrite <-pow_nat_iter_op_correct by auto. - erewrite <-(fun x => iter_op_spec (scalar := SRep) (mul (m:=Ed25519.q)) F_mul_assoc _ F_mul_1_l _ unRep SRep_testbit_correct n x 255%nat) by auto. - rewrite <-(rcFOK 1%F) at 1. - erewrite <-iter_op_proj by auto. - reflexivity. - Qed. - - Definition FRepInv x : FRep := FSRepPow x (S2Rep (Z.to_N (Ed25519.q - 2))). - Lemma FRepInv_correct : forall x, inv (rep2F x)%F = rep2F (FRepInv x). - unfold FRepInv; intros. - rewrite <-FSRepPow_correct; rewrite rcSOK; try reflexivity. - pose proof @Fq_inv_fermat_correct as H; unfold inv_fermat in H; rewrite H by - auto using Ed25519.prime_q, Ed25519.two_lt_q. - reflexivity. - Qed. - - Lemma unfoldDiv : forall {m} (x y:F m), (x/y = x * inv y)%F. Proof. unfold div. congruence. Qed. - - Definition rep2E (r:FRep * FRep * FRep * FRep) : extended := - match r with (((x, y), z), t) => mkExtended (rep2F x) (rep2F y) (rep2F z) (rep2F t) end. - - Lemma if_map : forall {T U} (f:T->U) (b:bool) (x y:T), (if b then f x else f y) = f (if b then x else y). - Proof. - destruct b; trivial. - Qed. - - Local Ltac Let_In_unRep := - match goal with - | [ |- appcontext G[Let_In (unRep ?x) ?f] ] - => change (Let_In (unRep x) f) with (Let_In x (fun y => f (unRep y))); cbv beta - end. - - - (** TODO: Move me *) - Lemma pull_Let_In {B C} (f : B -> C) A (v : A) (b : A -> B) - : Let_In v (fun v' => f (b v')) = f (Let_In v b). - Proof. - reflexivity. - Qed. - - Lemma Let_app_In {A B T} (g:A->B) (f:B->T) (x:A) : - @Let_In _ (fun _ => T) (g x) f = - @Let_In _ (fun _ => T) x (fun p => f (g x)). - Proof. reflexivity. Qed. - - Lemma Let_app2_In {A B C D T} (g1:A->C) (g2:B->D) (f:C*D->T) (x:A) (y:B) : - @Let_In _ (fun _ => T) (g1 x, g2 y) f = - @Let_In _ (fun _ => T) (x, y) (fun p => f ((g1 (fst p), g2 (snd p)))). - Proof. reflexivity. Qed. - - Create HintDb FRepOperations discriminated. - Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : FRepOperations. - - Create HintDb EdDSA_opts discriminated. - Hint Rewrite FRepMul_correct FRepAdd_correct FRepSub_correct FRepInv_correct FSRepPow_correct FRepOpp_correct : EdDSA_opts. - - Lemma unifiedAddM1Rep_sig : forall a b : FRep * FRep * FRep * FRep, { unifiedAddM1Rep | rep2E unifiedAddM1Rep = unifiedAddM1' (rep2E a) (rep2E b) }. - Proof. - destruct a as [[[]]]; destruct b as [[[]]]. - eexists. - lazymatch goal with |- ?LHS = ?RHS :> ?T => - evar (e:T); replace LHS with e; [subst e|] - end. - unfold rep2E. cbv beta delta [unifiedAddM1']. - pose proof (rcFOK twice_d) as H; rewrite <-H; clear H. (* XXX: this is a hack -- rewrite misresolves typeclasses? *) - - { etransitivity; [|replace_let_in_with_Let_In; reflexivity]. - repeat ( - autorewrite with FRepOperations; - Let_In_unRep; - eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [Proper respectful pointwise_relation]; intro]). - lazymatch goal with |- ?LHS = (unRep ?x, unRep ?y, unRep ?z, unRep ?t) => - change (LHS = (rep2E (((x, y), z), t))) - end. - reflexivity. } - - subst e. - Local Opaque Let_In. - repeat setoid_rewrite (pull_Let_In rep2E). - Local Transparent Let_In. - reflexivity. - Defined. - - Definition unifiedAddM1Rep (a b:FRep * FRep * FRep * FRep) : FRep * FRep * FRep * FRep := Eval hnf in proj1_sig (unifiedAddM1Rep_sig a b). - Definition unifiedAddM1Rep_correct a b : rep2E (unifiedAddM1Rep a b) = unifiedAddM1' (rep2E a) (rep2E b) := Eval hnf in proj2_sig (unifiedAddM1Rep_sig a b). - - Definition rep2T (P:FRep * FRep) := (rep2F (fst P), rep2F (snd P)). - Definition erep2trep (P:FRep * FRep * FRep * FRep) := Let_In P (fun P => Let_In (FRepInv (snd (fst P))) (fun iZ => (FRepMul (fst (fst (fst P))) iZ, FRepMul (snd (fst (fst P))) iZ))). - Lemma erep2trep_correct : forall P, rep2T (erep2trep P) = extendedToTwisted (rep2E P). - Proof. - unfold rep2T, rep2E, erep2trep, extendedToTwisted; destruct P as [[[]]]; simpl. - rewrite !unfoldDiv, <-!FRepMul_correct, <-FRepInv_correct. reflexivity. - Qed. - - (** TODO: possibly move me, remove local *) - Local Ltac replace_option_match_with_option_rect := - idtac; - lazymatch goal with - | [ |- _ = ?RHS :> ?T ] - => lazymatch RHS with - | match ?a with None => ?N | Some x => @?S x end - => replace RHS with (option_rect (fun _ => T) S N a) by (destruct a; reflexivity) - end - end. - - (** TODO: Move me, remove Local *) - Definition proj1_sig_unmatched {A P} := @proj1_sig A P. - Definition proj1_sig_nounfold {A P} := @proj1_sig A P. - Definition proj1_sig_unfold {A P} := Eval cbv [proj1_sig] in @proj1_sig A P. - Local Ltac unfold_proj1_sig_exist := - (** Change the first [proj1_sig] into [proj1_sig_unmatched]; if it's applied to [exist], mark it as unfoldable, otherwise mark it as not unfoldable. Then repeat. Finally, unfold. *) - repeat (change @proj1_sig with @proj1_sig_unmatched at 1; - match goal with - | [ |- context[proj1_sig_unmatched (exist _ _ _)] ] - => change @proj1_sig_unmatched with @proj1_sig_unfold - | _ => change @proj1_sig_unmatched with @proj1_sig_nounfold - end); - (* [proj1_sig_nounfold] is a thin wrapper around [proj1_sig]; unfolding it restores [proj1_sig]. Unfolding [proj1_sig_nounfold] exposes the pattern match, which is reduced by ι. *) - cbv [proj1_sig_nounfold proj1_sig_unfold]. - - (** TODO: possibly move me, remove Local *) - Local Ltac reflexivity_when_unification_is_stupid_about_evars - := repeat first [ reflexivity - | apply f_equal ]. - - - Local Existing Instance eq_Reflexive. (* To get some of the [setoid_rewrite]s below to work, we need to infer [Reflexive eq] before [Reflexive Equivalence.equiv] *) - - (* TODO: move me *) - Lemma fold_rep2E x y z t - : (rep2F x, rep2F y, rep2F z, rep2F t) = rep2E (((x, y), z), t). - Proof. reflexivity. Qed. - Lemma commute_negateExtended'_rep2E x y z t - : negateExtended' (rep2E (((x, y), z), t)) - = rep2E (((FRepOpp x, y), z), FRepOpp t). - Proof. simpl; autorewrite with FRepOperations; reflexivity. Qed. - Lemma fold_rep2E_ffff x y z t - : (x, y, z, t) = rep2E (((toRep x, toRep y), toRep z), toRep t). - Proof. simpl; rewrite !rcFOK; reflexivity. Qed. - Lemma fold_rep2E_rrfr x y z t - : (rep2F x, rep2F y, z, rep2F t) = rep2E (((x, y), toRep z), t). - Proof. simpl; rewrite !rcFOK; reflexivity. Qed. - Lemma fold_rep2E_0fff y z t - : (0%F, y, z, t) = rep2E (((toRep 0%F, toRep y), toRep z), toRep t). - Proof. apply fold_rep2E_ffff. Qed. - Lemma fold_rep2E_ff1f x y t - : (x, y, 1%F, t) = rep2E (((toRep x, toRep y), toRep 1%F), toRep t). - Proof. apply fold_rep2E_ffff. Qed. - Lemma commute_negateExtended'_rep2E_rrfr x y z t - : negateExtended' (unRep x, unRep y, z, unRep t) - = rep2E (((FRepOpp x, y), toRep z), FRepOpp t). - Proof. rewrite <- commute_negateExtended'_rep2E; simpl; rewrite !rcFOK; reflexivity. Qed. - - Hint Rewrite @F_mul_0_l commute_negateExtended'_rep2E_rrfr fold_rep2E_0fff (@fold_rep2E_ff1f (fst (proj1_sig B))) @if_F_eq_dec_if_F_eqb compare_enc (if_map unRep) (fun T => Let_app2_In (T := T) unRep unRep) @F_pow_2_r @unfoldDiv : EdDSA_opts. - Hint Rewrite <- unifiedAddM1Rep_correct erep2trep_correct (fun x y z bound => iter_op_proj rep2E unifiedAddM1Rep unifiedAddM1' x y z N.testbit_nat bound unifiedAddM1Rep_correct) FRep2wire_correct: EdDSA_opts. - - Lemma sharper_verify : forall pk l msg sig, { verify | verify = ed25519_verify pk l msg sig}. - Proof. - eexists; intros. - cbv [ed25519_verify EdDSA.verify - ed25519params curve25519params - EdDSA.E EdDSA.B EdDSA.b EdDSA.l EdDSA.H - EdDSA.PointEncoding EdDSA.FlEncoding EdDSA.FqEncoding]. - - etransitivity. - Focus 2. - { repeat match goal with - | [ |- ?x = ?x ] => reflexivity - | _ => replace_option_match_with_option_rect - | [ |- _ = option_rect _ _ _ _ ] - => eapply option_rect_Proper_nd; [ intro | reflexivity.. ] - end. - set_evars. - rewrite<- E.point_eqb_correct. - rewrite solve_for_R; unfold E.sub. - rewrite E.opp_mul. - let p1 := constr:(scalarMultM1_rep eq_refl) in - let p2 := constr:(unifiedAddM1_rep eq_refl) in - repeat match goal with - | |- context [(_ * E.opp ?P)%E] => - rewrite <-(unExtendedPoint_mkExtendedPoint P); - rewrite negateExtended_correct; - rewrite <-p1 - | |- context [(_ * ?P)%E] => - rewrite <-(unExtendedPoint_mkExtendedPoint P); - rewrite <-p1 - | _ => rewrite p2 - end; - rewrite ?Znat.Z_nat_N, <-?Word.wordToN_nat; - subst_evars; - reflexivity. - } Unfocus. - - etransitivity. - Focus 2. - { lazymatch goal with |- _ = option_rect _ _ ?false ?dec => - symmetry; etransitivity; [|eapply (option_rect_option_map (fun (x:F _) => Z.to_N x) _ false dec)] - end. - eapply option_rect_Proper_nd; [intro|reflexivity..]. - match goal with - | [ |- ?RHS = ?e ?v ] - => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in - unify e RHS' - end. - reflexivity. - } Unfocus. - rewrite <-decode_scalar_correct. - - etransitivity. - Focus 2. - { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - symmetry; apply decode_test_encode_test. - } Unfocus. - - rewrite enc'_correct. - cbv [unExtendedPoint unifiedAddM1 negateExtended scalarMultM1]. - unfold_proj1_sig_exist. - - etransitivity. - Focus 2. - { do 2 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - set_evars. - repeat match goal with - | [ |- appcontext[@proj1_sig ?A ?P (@iter_op ?T ?f ?neutral ?T' ?testbit ?exp ?base ?bound)] ] - => erewrite (@iter_op_proj T _ _ (@proj1_sig _ _)) by reflexivity - end. - subst_evars. - reflexivity. } - Unfocus. - - cbv [mkExtendedPoint E.zero]. - unfold_proj1_sig_exist. - rewrite B_proj. - - etransitivity. - Focus 2. - { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - set_evars. - lazymatch goal with |- _ = option_rect _ _ ?false ?dec => - symmetry; etransitivity; [|eapply (option_rect_option_map (@proj1_sig _ _) _ false dec)] - end. - eapply option_rect_Proper_nd; [intro|reflexivity..]. - match goal with - | [ |- ?RHS = ?e ?v ] - => let RHS' := (match eval pattern v in RHS with ?RHS' _ => RHS' end) in - unify e RHS' - end. - reflexivity. - } Unfocus. - - cbv [dec PointEncoding point_encoding]. - etransitivity. - Focus 2. - { do 1 (eapply option_rect_Proper_nd; [intro|reflexivity..]). - etransitivity. - Focus 2. - { apply f_equal. - symmetry. - apply point_dec_coordinates_correct. } - Unfocus. - reflexivity. } - Unfocus. - - cbv iota beta delta [point_dec_coordinates sign_bit dec FqEncoding modular_word_encoding E.solve_for_x2 sqrt_mod_q]. - - etransitivity. - Focus 2. { - do 1 (eapply option_rect_Proper_nd; [|reflexivity..]). cbv beta delta [pointwise_relation]. intro. - etransitivity. - Focus 2. - { apply f_equal. - lazymatch goal with - | [ |- _ = ?term :> ?T ] - => lazymatch term with (match ?a with None => ?N | Some x => @?S x end) - => let term' := constr:((option_rect (fun _ => T) S N) a) in - replace term with term' by reflexivity - end - end. - reflexivity. } Unfocus. reflexivity. } Unfocus. - - etransitivity. - Focus 2. { - do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). - do 1 (eapply option_rect_Proper_nd; [ intro; reflexivity | reflexivity | ]). - eapply option_rect_Proper_nd; [ cbv beta delta [pointwise_relation]; intro | reflexivity.. ]. - replace_let_in_with_Let_In. - reflexivity. - } Unfocus. - - etransitivity. - Focus 2. { - do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). - set_evars. - rewrite option_rect_function. (* turn the two option_rects into one *) - subst_evars. - simpl_option_rect. - do 1 (eapply option_rect_Proper_nd; [cbv beta delta [pointwise_relation]; intro|reflexivity..]). - (* push the [option_rect] inside until it hits a [Some] or a [None] *) - repeat match goal with - | _ => commute_option_rect_Let_In - | [ |- _ = Let_In _ _ ] - => apply Let_In_Proper_nd; [ reflexivity | cbv beta delta [pointwise_relation]; intro ] - | [ |- ?LHS = option_rect ?P ?S ?N (if ?b then ?t else ?f) ] - => transitivity (if b then option_rect P S N t else option_rect P S N f); - [ - | destruct b; reflexivity ] - | [ |- _ = if ?b then ?t else ?f ] - => apply (f_equal2 (fun x y => if b then x else y)) - | [ |- _ = false ] => reflexivity - | _ => progress simpl_option_rect - end. - reflexivity. - } Unfocus. - - cbv iota beta delta [q d a]. - - rewrite wire2FRep_correct. - - etransitivity. - Focus 2. { - eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro. - rewrite <-!(option_rect_option_map rep2F). - eapply option_rect_Proper_nd; [|reflexivity..]. cbv beta delta [pointwise_relation]. intro. - autorewrite with EdDSA_opts. - rewrite <-(rcFOK 1%F). - pattern Ed25519.d at 1. rewrite <-(rcFOK Ed25519.d) at 1. - pattern Ed25519.a at 1. rewrite <-(rcFOK Ed25519.a) at 1. - rewrite <- (rcSOK (Z.to_N (Ed25519.q / 8 + 1))). - autorewrite with EdDSA_opts. - (Let_In_unRep). - eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. - etransitivity. Focus 2. eapply Let_In_Proper_nd; [|cbv beta delta [pointwise_relation]; intro;reflexivity]. { - rewrite FSRepPow_correct by (rewrite rcSOK; cbv; omega). - (Let_In_unRep). - etransitivity. Focus 2. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. { - set_evars. - rewrite <-(rcFOK sqrt_minus1). - autorewrite with EdDSA_opts. - subst_evars. - reflexivity. } Unfocus. - rewrite pull_Let_In. - reflexivity. } Unfocus. - set_evars. - (Let_In_unRep). - - subst_evars. eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. set_evars. - - autorewrite with EdDSA_opts. - - subst_evars. - lazymatch goal with |- _ = if ?b then ?t else ?f => apply (f_equal2 (fun x y => if b then x else y)) end; [|reflexivity]. - eapply Let_In_Proper_nd; [reflexivity|cbv beta delta [pointwise_relation]; intro]. - set_evars. - - unfold twistedToExtended. - autorewrite with EdDSA_opts. - progress cbv beta delta [erep2trep]. - - subst_evars. - reflexivity. } Unfocus. - reflexivity. - Defined. -End Ed25519Frep.
\ No newline at end of file diff --git a/src/Specific/GF1305.v b/src/Specific/GF1305.v index b004a60d1..02ef714d9 100644 --- a/src/Specific/GF1305.v +++ b/src/Specific/GF1305.v @@ -15,6 +15,7 @@ Local Open Scope Z. Definition modulus : Z := 2^130 - 5. Lemma prime_modulus : prime modulus. Admitted. +Definition int_width := 32%Z. Instance params1305 : PseudoMersenneBaseParams modulus. construct_params prime_modulus 5%nat 130. @@ -26,16 +27,22 @@ Instance subCoeff : SubtractionCoefficient modulus params1305. apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto. Defined. +Definition freezePreconditions1305 : freezePreconditions params1305 int_width. +Proof. + constructor; compute_preconditions. +Defined. + (* END PseudoMersenneBaseParams instance construction. *) (* Precompute k and c *) Definition k_ := Eval compute in k. Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. (* Makes Qed not take forever *) Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul - Let_In digits Z.add Pos.add Z.pos_sub. + Let_In digits Z.add Pos.add Z.pos_sub andb Z.eqb Z.ltb. Local Open Scope nat_scope. Lemma GF1305Base26_mul_reduce_formula : @@ -45,8 +52,9 @@ Lemma GF1305Base26_mul_reduce_formula : -> rep ls (f*g)%F}. Proof. eexists; intros ? ? Hf Hg. - pose proof (carry_mul_opt_correct k_ c_ (eq_refl k) (eq_refl c_) [0;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg. + pose proof (carry_mul_opt_rep k_ c_ (eq_refl k) c_subst _ _ _ _ Hf Hg) as Hfg. compute_formula. + exact Hfg. Defined. Lemma GF1305Base26_add_formula : @@ -58,17 +66,30 @@ Proof. eexists; intros ? ? Hf Hg. pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg. compute_formula. + exact Hfg. Defined. -Lemma GF25519Base25Point5_sub_formula : - forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 - g0 g1 g2 g3 g4 g5 g6 g7 g8 g9, - {ls | forall f g, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] f - -> rep [g0;g1;g2;g3;g4;g5;g6;g7;g8;g9] g +Lemma GF1305Base26_sub_formula : + forall f0 f1 f2 f3 f4 g0 g1 g2 g3 g4, + {ls | forall f g, rep [f0;f1;f2;f3;f4] f + -> rep [g0;g1;g2;g3;g4] g -> rep ls (f - g)%F}. Proof. eexists. intros f g Hf Hg. pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg. compute_formula. -Defined.
\ No newline at end of file + exact Hfg. +Defined. + +Lemma GF1305Base26_freeze_formula : + forall f0 f1 f2 f3 f4, + {ls | forall x, rep [f0;f1;f2;f3;f4] x + -> rep ls x}. +Proof. + eexists. + intros x Hf. + pose proof (freeze_opt_preserves_rep _ c_subst freezePreconditions1305 _ _ Hf) as Hfreeze_rep. + compute_formula. + exact Hfreeze_rep. +Defined. diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v index 8aaf8caf6..471c1d548 100644 --- a/src/Specific/GF25519.v +++ b/src/Specific/GF25519.v @@ -7,7 +7,6 @@ Require Import Coq.Lists.List Crypto.Util.ListUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.BaseSystem. -Require Import Crypto.Rep. Import ListNotations. Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. Local Open Scope Z. @@ -16,6 +15,7 @@ Local Open Scope Z. Definition modulus : Z := 2^255 - 19. Lemma prime_modulus : prime modulus. Admitted. +Definition int_width := 32%Z. Instance params25519 : PseudoMersenneBaseParams modulus. construct_params prime_modulus 10%nat 255. @@ -27,16 +27,22 @@ Instance subCoeff : SubtractionCoefficient modulus params25519. apply Build_SubtractionCoefficient with (coeff := mul2modulus); cbv; auto. Defined. +Definition freezePreconditions25519 : freezePreconditions params25519 int_width. +Proof. + constructor; compute_preconditions. +Defined. + (* END PseudoMersenneBaseParams instance construction. *) (* Precompute k and c *) Definition k_ := Eval compute in k. Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. (* Makes Qed not take forever *) Opaque Z.shiftr Pos.iter Z.div2 Pos.div2 Pos.div2_up Pos.succ Z.land Z.of_N Pos.land N.ldiff Pos.pred_N Pos.pred_double Z.opp Z.mul Pos.mul - Let_In digits Z.add Pos.add Z.pos_sub. + Let_In digits Z.add Pos.add Z.pos_sub andb Z.eqb Z.ltb. Local Open Scope nat_scope. Lemma GF25519Base25Point5_mul_reduce_formula : @@ -47,9 +53,23 @@ Lemma GF25519Base25Point5_mul_reduce_formula : -> rep ls (f*g)%F}. Proof. eexists; intros ? ? Hf Hg. - pose proof (carry_mul_opt_correct k_ c_ (eq_refl k_) (eq_refl c_) [0;9;8;7;6;5;4;3;2;1;0]_ _ _ _ Hf Hg) as Hfg. + pose proof (carry_mul_opt_rep k_ c_ (eq_refl k_) c_subst _ _ _ _ Hf Hg) as Hfg. compute_formula. -Time Defined. + exact Hfg. +(*Time*) Defined. + +(* Uncomment this to see a pretty-printed mulmod +Local Transparent Let_In. +Infix "<<" := Z.shiftr (at level 50). +Infix "&" := Z.land (at level 50). +Eval cbv beta iota delta [proj1_sig GF25519Base25Point5_mul_reduce_formula Let_In] in + fun f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 => proj1_sig ( + GF25519Base25Point5_mul_reduce_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 + g0 g1 g2 g3 g4 g5 g6 g7 g8 g9). +Local Opaque Let_In. +*) + Extraction "/tmp/test.ml" GF25519Base25Point5_mul_reduce_formula. (* It's easy enough to use extraction to get the proper nice-looking formula. @@ -69,6 +89,7 @@ Proof. intros f g Hf Hg. pose proof (add_opt_rep _ _ _ _ Hf Hg) as Hfg. compute_formula. + exact Hfg. Defined. Lemma GF25519Base25Point5_sub_formula : @@ -82,99 +103,28 @@ Proof. intros f g Hf Hg. pose proof (sub_opt_rep _ _ _ _ Hf Hg) as Hfg. compute_formula. + exact Hfg. Defined. -Definition F25519Rep := (Z * Z * Z * Z * Z * Z * Z * Z * Z * Z)%type. - -Definition F25519toRep (x:F (2^255 - 19)) : F25519Rep := (0, 0, 0, 0, 0, 0, 0, 0, 0, FieldToZ x)%Z. -Definition F25519unRep (rx:F25519Rep) := - let '(x9, x8, x7, x6, x5, x4, x3, x2, x1, x0) := rx in - ModularBaseSystem.decode [x0;x1;x2;x3;x4;x5;x6;x7;x8;x9]. - -Global Instance F25519RepConversions : RepConversions (F (2^255 - 19)) F25519Rep := - { - toRep := F25519toRep; - unRep := F25519unRep - }. - -Lemma F25519RepConversionsOK : RepConversionsOK F25519RepConversions. +Lemma GF25519Base25Point5_freeze_formula : + forall f0 f1 f2 f3 f4 f5 f6 f7 f8 f9, + {ls | forall x, rep [f0;f1;f2;f3;f4;f5;f6;f7;f8;f9] x + -> rep ls x}. Proof. - unfold F25519RepConversions, RepConversionsOK, unRep, toRep, F25519toRep, F25519unRep; intros. - change (ModularBaseSystem.decode (ModularBaseSystem.encode x) = x). - eauto using ModularBaseSystemProofs.rep_decode, ModularBaseSystemProofs.encode_rep. -Qed. - -Definition F25519Rep_mul (f g:F25519Rep) : F25519Rep. - refine ( - let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in - let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). - (* FIXME: the r should not be present in generated code *) - pose (r := proj1_sig (GF25519Base25Point5_mul_reduce_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 - g0 g1 g2 g3 g4 g5 g6 g7 g8 g9)). - simpl in r. - unfold F25519Rep. - repeat let t' := (eval cbv beta delta [r] in r) in - lazymatch t' with Let_In ?arg ?f => - let x := fresh "x" in - refine (let x := arg in _); - let t'' := (eval cbv beta in (f x)) in - change (Let_In arg f) with t'' in r - end. - let t' := (eval cbv beta delta [r] in r) in - lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => - clear r; - exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) - end. -Time Defined. - -Lemma F25519_mul_OK : RepBinOpOK F25519RepConversions ModularArithmetic.mul F25519Rep_mul. - cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_mul toRep unRep F25519toRep F25519unRep]. - destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. - destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. - let E := constr:(GF25519Base25Point5_mul_reduce_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in - transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; - destruct E as [? r]; cbv [proj1_sig]. - cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. -Qed. - -Definition F25519Rep_add (f g:F25519Rep) : F25519Rep. - refine ( - let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in - let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). - let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_add_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 - g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in - lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => - exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) - end. -Defined. - -Definition F25519Rep_sub (f g:F25519Rep) : F25519Rep. - refine ( - let '(f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) := f in - let '(g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) := g in _). - let t' := (eval simpl in (proj1_sig (GF25519Base25Point5_sub_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 - g0 g1 g2 g3 g4 g5 g6 g7 g8 g9))) in - lazymatch t' with [?r0;?r1;?r2;?r3;?r4;?r5;?r6;?r7;?r8;?r9] => - exact (r9, r8, r7, r6, r5, r4, r3, r2, r1, r0) - end. + eexists. + intros x Hf. + pose proof (freeze_opt_preserves_rep _ c_subst freezePreconditions25519 _ _ Hf) as Hfreeze_rep. + compute_formula. + exact Hfreeze_rep. Defined. -Lemma F25519_add_OK : RepBinOpOK F25519RepConversions ModularArithmetic.add F25519Rep_add. - cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_add toRep unRep F25519toRep F25519unRep]. - destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. - destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. - let E := constr:(GF25519Base25Point5_add_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in - transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; - destruct E as [? r]; cbv [proj1_sig]. - cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. -Qed. - -Lemma F25519_sub_OK : RepBinOpOK F25519RepConversions ModularArithmetic.sub F25519Rep_sub. - cbv iota beta delta [RepBinOpOK F25519RepConversions F25519Rep_sub toRep unRep F25519toRep F25519unRep]. - destruct x as [[[[[[[[[x9 x8] x7] x6] x5] x4] x3] x2] x1] x0]. - destruct y as [[[[[[[[[y9 y8] y7] y6] y5] y4] y3] y2] y1] y0]. - let E := constr:(GF25519Base25Point5_sub_formula x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 y0 y1 y2 y3 y4 y5 y6 y7 y8 y9) in - transitivity (ModularBaseSystem.decode (proj1_sig E)); [|solve[simpl;f_equal]]; - destruct E as [? r]; cbv [proj1_sig]. - cbv [rep ModularBaseSystem.rep PseudoMersenneBase modulus] in r; edestruct r; eauto. -Qed.
\ No newline at end of file +(* Uncomment the below to see pretty-printed freeze function *) +(* +Set Printing Depth 1000. +Local Transparent Let_In. +Infix "<<" := Z.shiftr (at level 50). +Infix "&" := Z.land (at level 50). +Eval cbv beta iota delta [proj1_sig GF25519Base25Point5_freeze_formula Let_In] in + fun f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 => proj1_sig ( + GF25519Base25Point5_freeze_formula f0 f1 f2 f3 f4 f5 f6 f7 f8 f9). +*)
\ No newline at end of file diff --git a/src/Tactics/Nsatz.v b/src/Tactics/Nsatz.v new file mode 100644 index 000000000..8fa8c4a86 --- /dev/null +++ b/src/Tactics/Nsatz.v @@ -0,0 +1,127 @@ +(*** Tactics for manipulating polynomial equations *) +Require Coq.nsatz.Nsatz. +Require Import List. + +Generalizable All Variables. +Lemma cring_sub_diag_iff {R zero eq sub} `{cring:Cring.Cring (R:=R) (ring0:=zero) (ring_eq:=eq) (sub:=sub)} + : forall x y, eq (sub x y) zero <-> eq x y. +Proof. + split;intros Hx. + { eapply Nsatz.psos_r1b. eapply Hx. } + { eapply Nsatz.psos_r1. eapply Hx. } +Qed. + +Ltac get_goal := lazymatch goal with |- ?g => g end. + +Ltac nsatz_equation_implications_to_list eq zero g := + lazymatch g with + | eq ?p zero => constr:(p::nil) + | eq ?p zero -> ?g => let l := nsatz_equation_implications_to_list eq zero g in constr:(p::l) + end. + +Ltac nsatz_reify_equations eq zero := + let g := get_goal in + let lb := nsatz_equation_implications_to_list eq zero g in + lazymatch (eval red in (Ncring_tac.list_reifyl (lterm:=lb))) with + (?variables, ?le) => + lazymatch (eval compute in (List.rev le)) with + | ?reified_goal::?reified_givens => constr:((variables, reified_givens, reified_goal)) + end + end. + +Ltac nsatz_get_free_variables reified_package := + lazymatch reified_package with (?fv, _, _) => fv end. + +Ltac nsatz_get_reified_givens reified_package := + lazymatch reified_package with (_, ?givens, _) => givens end. + +Ltac nsatz_get_reified_goal reified_package := + lazymatch reified_package with (_, _, ?goal) => goal end. + +Require Import Coq.setoid_ring.Ring_polynom. +(* Kludge for 8.4/8.5 compatibility *) +Module Import mynsatz_compute. + Import Nsatz. + Global Ltac mynsatz_compute x := nsatz_compute x. +End mynsatz_compute. +Ltac nsatz_compute x := mynsatz_compute x. + +Ltac nsatz_compute_to_goal sugar nparams reified_goal power reified_givens := + nsatz_compute (PEc sugar :: PEc nparams :: PEpow reified_goal power :: reified_givens). + +Ltac nsatz_compute_get_leading_coefficient := + lazymatch goal with + |- Logic.eq ((?a :: _ :: ?b) :: ?c) _ -> _ => a + end. + +Ltac nsatz_compute_get_certificate := + lazymatch goal with + |- Logic.eq ((?a :: _ :: ?b) :: ?c) _ -> _ => constr:((c,b)) + end. + +Ltac nsatz_rewrite_and_revert domain := + lazymatch type of domain with + | @Integral_domain.Integral_domain ?F ?zero _ _ _ _ _ ?eq ?Fops ?FRing ?FCring => + lazymatch goal with + | |- eq _ zero => idtac + | |- eq _ _ => rewrite <-(cring_sub_diag_iff (cring:=FCring)) + end; + repeat match goal with + | [H : eq _ zero |- _ ] => revert H + | [H : eq _ _ |- _ ] => rewrite <-(cring_sub_diag_iff (cring:=FCring)) in H; revert H + end + end. + +Ltac nsatz_nonzero := + try solve [apply Integral_domain.integral_domain_one_zero + |apply Integral_domain.integral_domain_minus_one_zero + |trivial]. + +Ltac nsatz_domain_sugar_power domain sugar power := + let nparams := constr:(BinInt.Zneg BinPos.xH) in (* some symbols can be "parameters", treated as coefficients *) + lazymatch type of domain with + | @Integral_domain.Integral_domain ?F ?zero _ _ _ _ _ ?eq ?Fops ?FRing ?FCring => + nsatz_rewrite_and_revert domain; + let reified_package := nsatz_reify_equations eq zero in + let fv := nsatz_get_free_variables reified_package in + let interp := constr:(@Nsatz.PEevalR _ _ _ _ _ _ _ _ Fops fv) in + let reified_givens := nsatz_get_reified_givens reified_package in + let reified_goal := nsatz_get_reified_goal reified_package in + nsatz_compute_to_goal sugar nparams reified_goal power reified_givens; + let a := nsatz_compute_get_leading_coefficient in + let crt := nsatz_compute_get_certificate in + intros _ (* discard [nsatz_compute] output *); intros; + apply (fun Haa refl cond => @Integral_domain.Rintegral_domain_pow _ _ _ _ _ _ _ _ _ _ _ domain (interp a) _ (BinNat.N.to_nat power) Haa (@Nsatz.check_correct _ _ _ _ _ _ _ _ _ _ FCring fv reified_givens (PEmul a (PEpow reified_goal power)) crt refl cond)); + [ nsatz_nonzero; cbv iota beta delta [Nsatz.PEevalR PEeval InitialRing.gen_phiZ InitialRing.gen_phiPOS] + | solve [vm_compute; exact (eq_refl true)] (* exact_no_check (eq_refl true) *) + | solve [repeat (split; [assumption|]); exact I] ] + end. + +Ltac nsatz_guess_domain := + match goal with + | |- ?eq _ _ => constr:(_:Integral_domain.Integral_domain (ring_eq:=eq)) + | |- not (?eq _ _) => constr:(_:Integral_domain.Integral_domain (ring_eq:=eq)) + | [H: ?eq _ _ |- _ ] => constr:(_:Integral_domain.Integral_domain (ring_eq:=eq)) + | [H: not (?eq _ _) |- _] => constr:(_:Integral_domain.Integral_domain (ring_eq:=eq)) + end. + +Ltac nsatz_sugar_power sugar power := + let domain := nsatz_guess_domain in + nsatz_domain_sugar_power domain sugar power. + +Tactic Notation "nsatz" constr(n) := + let nn := (eval compute in (BinNat.N.of_nat n)) in + nsatz_sugar_power BinInt.Z0 nn. + +Tactic Notation "nsatz" := nsatz 1%nat || nsatz 2%nat || nsatz 3%nat || nsatz 4%nat || nsatz 5%nat. + +Ltac nsatz_contradict := + unfold not; + intros; + let domain := nsatz_guess_domain in + lazymatch type of domain with + | @Integral_domain.Integral_domain _ ?zero ?one _ _ _ _ ?eq ?Fops ?FRing ?FCring => + assert (eq one zero) as Hbad; + [nsatz; nsatz_nonzero + |destruct (Integral_domain.integral_domain_one_zero (Integral_domain:=domain) Hbad)] + end. diff --git a/src/Testbit.v b/src/Testbit.v index 264069587..2bfcc3df6 100644 --- a/src/Testbit.v +++ b/src/Testbit.v @@ -3,6 +3,7 @@ Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.BaseSystem Crypto.BaseSystemProofs. Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. +Import Nat. Local Open Scope Z. @@ -209,4 +210,4 @@ Lemma testbit_spec : forall n us base limb_width, (0 < limb_width)%nat -> Proof. intros. erewrite unfold_bits_testbit, unfold_bits_decode; eauto; omega. -Qed.
\ No newline at end of file +Qed. diff --git a/src/Util/CaseUtil.v b/src/Util/CaseUtil.v index cf3ebf29c..2d1ab6c58 100644 --- a/src/Util/CaseUtil.v +++ b/src/Util/CaseUtil.v @@ -1,12 +1,12 @@ -Require Import Coq.Arith.Arith. +Require Import Coq.Arith.Arith Coq.Arith.Max. Ltac case_max := match goal with [ |- context[max ?x ?y] ] => destruct (le_dec x y); match goal with - | [ H : (?x <= ?y)%nat |- context[max ?x ?y] ] => rewrite Max.max_r by + | [ H : (?x <= ?y)%nat |- context[max ?x ?y] ] => rewrite max_r by (exact H) - | [ H : ~ (?x <= ?y)%nat |- context[max ?x ?y] ] => rewrite Max.max_l by + | [ H : ~ (?x <= ?y)%nat |- context[max ?x ?y] ] => rewrite max_l by (exact (le_Sn_le _ _ (not_le _ _ H))) end end. diff --git a/src/Util/IterAssocOp.v b/src/Util/IterAssocOp.v index 6116312e1..82d22046d 100644 --- a/src/Util/IterAssocOp.v +++ b/src/Util/IterAssocOp.v @@ -1,5 +1,7 @@ Require Import Coq.Setoids.Setoid Coq.Classes.Morphisms Coq.Classes.Equivalence. Require Import Coq.NArith.NArith Coq.PArith.BinPosDef. +Require Import Coq.Numbers.Natural.Peano.NPeano. + Local Open Scope equiv_scope. Generalizable All Variables. @@ -147,7 +149,7 @@ Section IterAssocOp. destruct (funexp (test_and_op n a) (x, acc) y) as [i acc']. simpl in IHy. unfold test_and_op. - destruct i; rewrite NPeano.Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity. + destruct i; rewrite Nat.sub_succ_r; subst; rewrite <- IHy; simpl; reflexivity. Qed. Lemma iter_op_termination : forall sc a bound, diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v index 36d8a3ad3..0426c0834 100644 --- a/src/Util/ListUtil.v +++ b/src/Util/ListUtil.v @@ -18,7 +18,7 @@ Proof. intros. induction n; boring. Qed. -Ltac nth_tac' := +Ltac nth_tac' := intros; simpl in *; unfold error,value in *; repeat progress (match goal with | [ |- context[nth_error nil ?n] ] => rewrite nth_error_nil_error | [ H: ?x = Some _ |- context[match ?x with Some _ => ?a | None => ?a end ] ] => destruct x @@ -79,10 +79,10 @@ Proof. reflexivity. nth_tac'. pose proof (nth_error_error_length A n l H0). - omega. + omega. Qed. -Ltac nth_tac := +Ltac nth_tac := repeat progress (try nth_tac'; try (match goal with | [ H: nth_error (map _ _) _ = Some _ |- _ ] => destruct (nth_error_map _ _ _ _ _ _ H); clear H | [ H: nth_error (seq _ _) _ = Some _ |- _ ] => rewrite nth_error_seq in H @@ -191,7 +191,7 @@ Proof. Qed. Lemma set_nth_equiv_splice_nth: forall {T} n x (xs:list T), - set_nth n x xs = + set_nth n x xs = if lt_dec n (length xs) then splice_nth n x xs else xs. @@ -210,7 +210,7 @@ Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B), end. Proof. (* TODO(andreser): this proof can totally be automated, but requires writing ltac that vets multiple hypotheses at once *) - induction n, xs, ys; nth_tac; try rewrite IHn; nth_tac; + induction n, xs, ys; nth_tac; try rewrite IHn; nth_tac; try (f_equal; specialize (IHn x xs ys ); rewrite H in IHn; rewrite <- IHn); try (specialize (nth_error_value_length _ _ _ _ H); omega). assert (Some b0=Some b1) as HA by (rewrite <-H, <-H0; auto). @@ -258,13 +258,13 @@ Proof. destruct (lt_dec n (length xs)); auto. Qed. -Lemma combine_truncate_r : forall {A} (xs ys : list A), +Lemma combine_truncate_r : forall {A B} (xs : list A) (ys : list B), combine xs ys = combine xs (firstn (length xs) ys). Proof. induction xs; destruct ys; boring. Qed. -Lemma combine_truncate_l : forall {A} (xs ys : list A), +Lemma combine_truncate_l : forall {A B} (xs : list A) (ys : list B), combine xs ys = combine (firstn (length ys) xs) ys. Proof. induction xs; destruct ys; boring. @@ -330,7 +330,7 @@ Proof. intros. rewrite firstn_app_inleft; auto using firstn_all; omega. Qed. - + Lemma skipn_app_sharp : forall {A} n (l l': list A), length l = n -> skipn n (l ++ l') = l'. @@ -422,7 +422,7 @@ Proof. right; repeat eexists; auto. } Qed. - + Lemma nil_length0 : forall {T}, length (@nil T) = 0%nat. Proof. auto. @@ -512,7 +512,7 @@ Ltac nth_error_inbounds := match goal with | [ |- context[match nth_error ?xs ?i with Some _ => _ | None => _ end ] ] => case_eq (nth_error xs i); - match goal with + match goal with | [ |- forall _, nth_error xs i = Some _ -> _ ] => let x := fresh "x" in let H := fresh "H" in @@ -582,3 +582,46 @@ Lemma In_firstn : forall {T} n l (x : T), In x (firstn n l) -> In x l. Proof. induction n; destruct l; boring. Qed. + +Lemma firstn_firstn : forall {A} m n (l : list A), (n <= m)%nat -> + firstn n (firstn m l) = firstn n l. +Proof. + induction m; destruct n; intros; try omega; auto. + destruct l; auto. + simpl. + f_equal. + apply IHm; omega. +Qed. + +Lemma firstn_succ : forall {A} (d : A) n l, (n < length l)%nat -> + firstn (S n) l = (firstn n l) ++ nth_default d l n :: nil. +Proof. + induction n; destruct l; rewrite ?(@nil_length0 A); intros; try omega. + + rewrite nth_default_cons; auto. + + simpl. + rewrite nth_default_cons_S. + rewrite <-IHn by (rewrite cons_length in *; omega). + reflexivity. +Qed. + +Lemma firstn_all_strong : forall {A} (xs : list A) n, (length xs <= n)%nat -> + firstn n xs = xs. +Proof. + induction xs; intros; try apply firstn_nil. + destruct n; + match goal with H : (length (_ :: _) <= _)%nat |- _ => + simpl in H; try omega end. + simpl. + f_equal. + apply IHxs. + omega. +Qed. + +Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> + nth_default d (set_nth n x l) i = + if (eq_nat_dec i n) then x else nth_default d l i. +Proof. + induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; + destruct i; break_if; try omega; intros; try apply nth_default_cons; + rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. +Qed.
\ No newline at end of file diff --git a/src/Util/NatUtil.v b/src/Util/NatUtil.v index 1f69b04d2..bc2c8425b 100644 --- a/src/Util/NatUtil.v +++ b/src/Util/NatUtil.v @@ -1,4 +1,5 @@ Require Import Coq.Numbers.Natural.Peano.NPeano Coq.omega.Omega. +Import Nat. Local Open Scope nat_scope. @@ -57,7 +58,18 @@ Proof. } Qed. +Lemma lt_min_l : forall x a b, (x < min a b)%nat -> (x < a)%nat. +Proof. + intros ? ? ? lt_min. + apply Nat.min_glb_lt_iff in lt_min. + destruct lt_min; assumption. +Qed. +(* useful for hints *) +Lemma eq_le_incl_rev : forall a b, a = b -> b <= a. +Proof. + intros; omega. +Qed. Lemma beq_nat_eq_nat_dec {R} (x y : nat) (a b : R) : (if EqNat.beq_nat x y then a else b) = (if eq_nat_dec x y then a else b). @@ -66,4 +78,3 @@ Proof. [ rewrite (proj2 (@beq_nat_true_iff _ _) H); reflexivity | rewrite (proj2 (@beq_nat_false_iff _ _) H); reflexivity ]. Qed. - diff --git a/src/Util/Tuple.v b/src/Util/Tuple.v new file mode 100644 index 000000000..6802a86c3 --- /dev/null +++ b/src/Util/Tuple.v @@ -0,0 +1,81 @@ +Require Import Coq.Classes.Morphisms. +Require Import Relation_Definitions. + +Fixpoint tuple' T n : Type := + match n with + | O => T + | S n' => (tuple' T n' * T)%type + end. + +Definition tuple T n : Type := + match n with + | O => unit + | S n' => tuple' T n' + end. + +Fixpoint to_list' {T} (n:nat) {struct n} : tuple' T n -> list T := + match n with + | 0 => fun x => (x::nil)%list + | S n' => fun xs : tuple' T (S n') => let (xs', x) := xs in (x :: to_list' n' xs')%list + end. + +Definition to_list {T} (n:nat) : tuple T n -> list T := + match n with + | 0 => fun _ => nil + | S n' => fun xs : tuple T (S n') => to_list' n' xs + end. + +Fixpoint from_list' {T} (x:T) (xs:list T) : tuple' T (length xs) := + match xs with + | nil => x + | (y :: xs')%list => (from_list' y xs', x) + end. + +Definition from_list {T} (xs:list T) : tuple T (length xs) := + match xs as l return (tuple T (length l)) with + | nil => tt + | (t :: xs')%list => from_list' t xs' + end. + +Lemma to_list_from_list : forall {T} (xs:list T), to_list (length xs) (from_list xs) = xs. +Proof. + destruct xs; auto; simpl. + generalize dependent t. + induction xs; auto; simpl; intros; f_equal; auto. +Qed. + +Lemma length_to_list : forall {T} {n} (xs:tuple T n), length (to_list n xs) = n. +Proof. + destruct n; auto; intros; simpl in *. + induction n; auto; intros; simpl in *. + destruct xs; simpl in *; eauto. +Qed. + +Fixpoint fieldwise' {A B} (n:nat) (R:A->B->Prop) (a:tuple' A n) (b:tuple' B n) {struct n} : Prop. + destruct n; simpl @tuple' in *. + { exact (R a b). } + { exact (R (snd a) (snd b) /\ fieldwise' _ _ n R (fst a) (fst b)). } +Defined. + +Definition fieldwise {A B} (n:nat) (R:A->B->Prop) (a:tuple A n) (b:tuple B n) : Prop. + destruct n; simpl @tuple in *. + { exact True. } + { exact (fieldwise' _ R a b). } +Defined. + +Global Instance Equivalence_fieldwise' {A} {R:relation A} {R_equiv:Equivalence R} {n:nat}: + Equivalence (fieldwise' n R). +Proof. + induction n as [|? IHn]; [solve [auto]|]. + (* could use [dintuition] in 8.5 only, and remove the [destruct] *) + destruct IHn, R_equiv; simpl; constructor; repeat intro; intuition eauto. +Qed. + +Global Instance Equivalence_fieldwise {A} {R:relation A} {R_equiv:Equivalence R} {n:nat}: + Equivalence (fieldwise n R). +Proof. + destruct n; (repeat constructor || apply Equivalence_fieldwise'). +Qed. + +Arguments fieldwise' {A B n} _ _ _. +Arguments fieldwise {A B n} _ _ _. diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index 1b7cfafdc..a5716fca4 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -2,6 +2,7 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.Znumtheory Coq.ZArith.ZArith Coq.ZAr Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. Require Import Crypto.Util.NatUtil. Require Import Coq.Lists.List. +Import Nat. Local Open Scope Z. Lemma gt_lt_symmetry: forall n m, n > m <-> m < n. @@ -208,7 +209,7 @@ Proof. rewrite (le_plus_minus n m) at 1 by assumption. rewrite Nat2Z.inj_add. rewrite Z.pow_add_r by apply Nat2Z.is_nonneg. - rewrite <- Z.div_div by first + rewrite <- Z.div_div by first [ pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega | apply Z.pow_pos_nonneg; omega ]. rewrite Z.div_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega). @@ -332,7 +333,7 @@ Qed. replace (2 ^ (Z.pos p)) with (2 ^ (Z.pos p - 1)* 2). rewrite Z.div_add_l by omega. reflexivity. - replace 2 with (2 ^ 1) at 2 by auto. + change 2 with (2 ^ 1) at 2. rewrite <-Z.pow_add_r by (pose proof (Pos2Z.is_pos p); omega). f_equal. omega. Qed. @@ -345,7 +346,7 @@ Qed. + unfold Z.ones. rewrite Z.shiftr_0_r, Z.shiftl_1_l, Z.sub_0_r. omega. - + intros. + + intros. destruct (Z_lt_le_dec x n); try omega. intuition. left. @@ -360,7 +361,7 @@ Qed. Z.shiftr a i <= Z.ones (n - i) . Proof. intros a n i G G0 G1. - destruct (Z_le_lt_eq_dec i n G1). + destruct (Z_le_lt_eq_dec i n G1). + destruct (Z_shiftr_ones' a n G i G0); omega. + subst; rewrite Z.sub_diag. destruct (Z_eq_dec a 0). @@ -385,34 +386,91 @@ Qed. omega. Qed. -(* prove that known nonnegative numbers are nonnegative *) +(* prove that combinations of known positive/nonnegative numbers are positive/nonnegative *) Ltac zero_bounds' := repeat match goal with | [ |- 0 <= _ + _] => apply Z.add_nonneg_nonneg - | [ |- 0 <= _ - _] => apply Zle_minus_le_0 + | [ |- 0 <= _ - _] => apply Z.le_0_sub | [ |- 0 <= _ * _] => apply Z.mul_nonneg_nonneg | [ |- 0 <= _ / _] => apply Z.div_pos - | [ |- 0 < _ + _] => apply Z.add_pos_nonneg - (* TODO : this apply is not good: it can make a true goal false. Actually, - * we would want this tactic to explore two branches: - * - apply Z.add_pos_nonneg and continue - * - apply Z.add_nonneg_pos and continue - * Keep whichever one solves all subgoals. If neither does, don't apply. *) - - | [ |- 0 < _ - _] => apply Zlt_minus_lt_0 + | [ |- 0 <= _ ^ _ ] => apply Z.pow_nonneg + | [ |- 0 <= Z.shiftr _ _] => apply Z.shiftr_nonneg + | [ |- 0 < _ + _] => try solve [apply Z.add_pos_nonneg; zero_bounds']; + try solve [apply Z.add_nonneg_pos; zero_bounds'] + | [ |- 0 < _ - _] => apply Z.lt_0_sub | [ |- 0 < _ * _] => apply Z.lt_0_mul; left; split | [ |- 0 < _ / _] => apply Z.div_str_pos + | [ |- 0 < _ ^ _ ] => apply Z.pow_pos_nonneg end; try omega; try prime_bound; auto. Ltac zero_bounds := try omega; try prime_bound; zero_bounds'. - Lemma Z_ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i. - Proof. - apply natlike_ind. - + unfold Z.ones. simpl; omega. - + intros. - rewrite Z_ones_succ by assumption. - zero_bounds. - apply Z.pow_nonneg; omega. - Qed. +Lemma Z_ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i. +Proof. + apply natlike_ind. + + unfold Z.ones. simpl; omega. + + intros. + rewrite Z_ones_succ by assumption. + zero_bounds. +Qed. + +Lemma Z_ones_pos_pos : forall i, (0 < i) -> 0 < Z.ones i. +Proof. + intros. + unfold Z.ones. + rewrite Z.shiftl_1_l. + apply Z.lt_succ_lt_pred. + apply Z.pow_gt_1; omega. +Qed. + +Lemma N_le_1_l : forall p, (1 <= N.pos p)%N. +Proof. + destruct p; cbv; congruence. +Qed. +Lemma Pos_land_upper_bound_l : forall a b, (Pos.land a b <= N.pos a)%N. +Proof. + induction a; destruct b; intros; try solve [cbv; congruence]; + simpl; specialize (IHa b); case_eq (Pos.land a b); intro; simpl; + try (apply N_le_1_l || apply N.le_0_l); intro land_eq; + rewrite land_eq in *; unfold N.le, N.compare in *; + rewrite ?Pos.compare_xI_xI, ?Pos.compare_xO_xI, ?Pos.compare_xO_xO; + try assumption. + destruct (p ?=a)%positive; cbv; congruence. +Qed. + +Lemma Z_land_upper_bound_l : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= a. +Proof. + intros. + destruct a, b; try solve [exfalso; auto]; try solve [cbv; congruence]. + cbv [Z.land]. + rewrite <-N2Z.inj_pos, <-N2Z.inj_le. + auto using Pos_land_upper_bound_l. +Qed. + +Lemma Z_land_upper_bound_r : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= b. +Proof. + intros. + rewrite Z.land_comm. + auto using Z_land_upper_bound_l. +Qed. + +Lemma Z_le_fold_right_max : forall low l x, (forall y, In y l -> low <= y) -> + In x l -> x <= fold_right Z.max low l. +Proof. + induction l; intros ? lower_bound In_list; [cbv [In] in *; intuition | ]. + simpl. + destruct (in_inv In_list); subst. + + apply Z.le_max_l. + + etransitivity. + - apply IHl; auto; intuition. + - apply Z.le_max_r. +Qed. + +Lemma Z_le_fold_right_max_initial : forall low l, low <= fold_right Z.max low l. +Proof. + induction l; intros; try reflexivity. + etransitivity; [ apply IHl | apply Z.le_max_r ]. +Qed. |