diff options
author | Jason Gross <jgross@mit.edu> | 2016-10-17 18:20:33 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2016-10-19 12:37:32 -0400 |
commit | 077a20a0018c9823c2568eb624122f48ab35c1d5 (patch) | |
tree | 1352f75bb2c65c135e7e6dbb082b5a1f3830f0df /src | |
parent | c1f4f952a034a4a9b8677a5b4823f97a7fab2252 (diff) |
Add opt versions of add, sub, opp
Diffstat (limited to 'src')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 21 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 107 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 2 | ||||
-rw-r--r-- | src/Specific/GF25519.v | 89 |
4 files changed, 205 insertions, 14 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 1769f86c4..ff2c23ab8 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -60,9 +60,22 @@ Section ModularBaseSystem. (* Placeholder *) Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)). + Definition carry_ (carry_chain : list nat) (us : digits) : digits := + from_list (carry_sequence carry_chain [[us]]) (length_carry_sequence length_to_list). + + Definition carry_add (carry_chain : list nat) (us vs : digits) : digits := + carry_ carry_chain (add us vs). Definition carry_mul (carry_chain : list nat) (us vs : digits) : digits := - from_list (carry_sequence carry_chain [[mul us vs]]) (length_carry_sequence length_to_list). - + carry_ carry_chain (mul us vs). + Definition carry_sub (carry_chain : list nat) (modulus_multiple: digits) + (modulus_multiple_correct : decode modulus_multiple = 0%F) + (us vs : digits) : digits := + carry_ carry_chain (sub modulus_multiple modulus_multiple_correct us vs). + Definition carry_opp (carry_chain : list nat) (modulus_multiple : digits) + (modulus_multiple_correct : decode modulus_multiple = 0%F) + (x : digits) : digits := + carry_ carry_chain (opp modulus_multiple modulus_multiple_correct x). + Definition rep (us : digits) (x : F modulus) := decode us = x. Local Notation "u ~= x" := (rep u x). Local Hint Unfold rep. @@ -102,8 +115,8 @@ Section ModularBaseSystem. Definition pack (x : digits) : target_digits := from_list (pack target_widths_nonneg bits_eq [[x]]) length_pack. - + Definition unpack (x : target_digits) : digits := from_list (unpack target_widths_nonneg bits_eq [[x]]) length_unpack. -End ModularBaseSystem.
\ No newline at end of file +End ModularBaseSystem. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 3b50662a6..f1b8c601b 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -410,8 +410,32 @@ Section Carries. End Carries. +Section CarryChain. + Context `{prm : PseudoMersenneBaseParams} {cc : CarryChain limb_widths}. + Local Notation digits := (tuple Z (length limb_widths)). + + Definition carry__opt_sig {T} (f : digits -> T) (us : digits) + : { x | x = f (carry_ carry_chain us) }. + Proof. + eexists. + cbv [carry_]. + rewrite <- from_list_default_eq with (d := 0%Z). + change @from_list_default with @from_list_default_opt. + erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list. + cbv [carry_sequence_opt_cps]. + reflexivity. + Defined. + + Definition carry__opt_cps {T} (f:digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry__opt_sig] in proj1_sig (carry__opt_sig f us). + + Definition carry__opt_cps_correct {T} (f:digits -> T) (us : digits) + : carry__opt_cps f us = f (carry_ carry_chain us) + := proj2_sig (carry__opt_sig f us). +End CarryChain. + Section Addition. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient}. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}. Local Notation digits := (tuple Z (length limb_widths)). Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }. @@ -426,10 +450,34 @@ Section Addition. Definition add_opt_correct us vs : add_opt us vs = add us vs := proj2_sig (add_opt_sig us vs). + + Definition carry_add_opt_sig {T} (f:digits -> T) + (us vs : digits) : { x | x = f (carry_add carry_chain us vs) }. + Proof. + eexists. + cbv [carry_add]. + rewrite <-carry__opt_cps_correct, <-add_opt_correct. + cbv [carry_sequence_opt_cps carry__opt_cps add_opt add]. + rewrite to_list_from_list. + reflexivity. + Defined. + + Definition carry_add_opt_cps {T} (f:digits -> T) (us vs : digits) : T + := Eval cbv [proj1_sig carry_add_opt_sig] in proj1_sig (carry_add_opt_sig f us vs). + + Definition carry_add_opt_cps_correct {T} (f:digits -> T) (us vs : digits) + : carry_add_opt_cps f us vs = f (carry_add carry_chain us vs) + := proj2_sig (carry_add_opt_sig f us vs). + + Definition carry_add_opt := carry_add_opt_cps id. + + Definition carry_add_opt_correct (us vs : digits) + : carry_add_opt us vs = carry_add carry_chain us vs := + carry_add_opt_cps_correct id us vs. End Addition. Section Subtraction. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient}. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}. Local Notation digits := (tuple Z (length limb_widths)). Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff coeff_mod us vs }. @@ -446,6 +494,30 @@ Section Subtraction. : sub_opt us vs = sub coeff coeff_mod us vs := proj2_sig (sub_opt_sig us vs). + Definition carry_sub_opt_sig {T} (f:digits -> T) + (us vs : digits) : { x | x = f (carry_sub carry_chain coeff coeff_mod us vs) }. + Proof. + eexists. + cbv [carry_sub]. + rewrite <-carry__opt_cps_correct, <-sub_opt_correct. + cbv [carry_sequence_opt_cps carry__opt_cps sub_opt]. + rewrite to_list_from_list. + reflexivity. + Defined. + + Definition carry_sub_opt_cps {T} (f:digits -> T) (us vs : digits) : T + := Eval cbv [proj1_sig carry_sub_opt_sig] in proj1_sig (carry_sub_opt_sig f us vs). + + Definition carry_sub_opt_cps_correct {T} (f:digits -> T) (us vs : digits) + : carry_sub_opt_cps f us vs = f (carry_sub carry_chain coeff coeff_mod us vs) + := proj2_sig (carry_sub_opt_sig f us vs). + + Definition carry_sub_opt := carry_sub_opt_cps id. + + Definition carry_sub_opt_correct (us vs : digits) + : carry_sub_opt us vs = carry_sub carry_chain coeff coeff_mod us vs := + carry_sub_opt_cps_correct id us vs. + Definition opp_opt_sig (us : digits) : { b : digits | b = opp coeff coeff_mod us }. Proof. eexists. @@ -461,6 +533,30 @@ Section Subtraction. : opp_opt us = opp coeff coeff_mod us := proj2_sig (opp_opt_sig us). + Definition carry_opp_opt_sig {T} (f:digits -> T) + (us : digits) : { x | x = f (carry_opp carry_chain coeff coeff_mod us) }. + Proof. + eexists. + cbv [carry_opp]. + rewrite <-carry__opt_cps_correct, <-opp_opt_correct. + cbv [carry_sequence_opt_cps carry__opt_cps opp_opt opp sub_opt]. + rewrite to_list_from_list. + reflexivity. + Defined. + + Definition carry_opp_opt_cps {T} (f:digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry_opp_opt_sig] in proj1_sig (carry_opp_opt_sig f us). + + Definition carry_opp_opt_cps_correct {T} (f:digits -> T) (us : digits) + : carry_opp_opt_cps f us = f (carry_opp carry_chain coeff coeff_mod us) + := proj2_sig (carry_opp_opt_sig f us). + + Definition carry_opp_opt := carry_opp_opt_cps id. + + Definition carry_opp_opt_correct (us : digits) + : carry_opp_opt us = carry_opp carry_chain coeff coeff_mod us := + carry_opp_opt_cps_correct id us. + End Subtraction. Section Multiplication. @@ -616,11 +712,8 @@ Section Multiplication. Proof. eexists. cbv [carry_mul]. - rewrite <- from_list_default_eq with (d := 0%Z). - change @from_list_default with @from_list_default_opt. - erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list. - erewrite <-mul_opt_correct. - cbv [carry_sequence_opt_cps mul_opt]. + rewrite <-carry__opt_cps_correct, <-mul_opt_correct. + cbv [carry_sequence_opt_cps carry__opt_cps mul_opt]. erewrite from_list_default_eq. rewrite to_list_from_list. reflexivity. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index c160eca7f..9cef5710d 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -360,7 +360,7 @@ Section FieldOperationProofs. + eapply _zero_neq_one. + trivial. Qed. - End FieldProofs. +End FieldProofs. End FieldOperationProofs. Opaque encode add mul sub inv pow. diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v index d5dc43a7a..faf8b0519 100644 --- a/src/Specific/GF25519.v +++ b/src/Specific/GF25519.v @@ -111,7 +111,7 @@ Defined. Arguments chain {_ _ _} _. -(* END precomputation *) +(* END precomputation *) (* Precompute constants *) Definition k_ := Eval compute in k. @@ -192,6 +192,26 @@ Definition add_correct (f g : fe25519) Eval cbv beta iota delta [proj1_sig add_sig] in proj2_sig (add_sig f g). +Definition carry_add_sig (f g : fe25519) : + { fg : fe25519 | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe25519). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe25519) : fe25519 := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe25519) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + Definition sub_sig (f g : fe25519) : { fg : fe25519 | fg = sub_opt f g}. Proof. @@ -210,6 +230,26 @@ Definition sub_correct (f g : fe25519) Eval cbv beta iota delta [proj1_sig sub_sig] in proj2_sig (sub_sig f g). +Definition carry_sub_sig (f g : fe25519) : + { fg : fe25519 | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe25519). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe25519) : fe25519 := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe25519) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + (* For multiplication, we add another layer of definition so that we can rewrite under the [let] binders. *) Definition mul_simpl_sig (f g : fe25519) : @@ -249,6 +289,8 @@ Proof. rewrite <-mul_simpl_correct. rewrite <-(@appify2_correct fe25519). cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) + autorewrite with zsimplify_Z_to_pos; cbv. reflexivity. Defined. @@ -279,6 +321,24 @@ Definition opp_correct (f : fe25519) : opp f = opp_opt f := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). +Definition carry_opp_sig (f : fe25519) : + { g : fe25519 | g = carry_opp_opt f }. +Proof. + eexists. + rewrite <-(@app_10_correct fe25519). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_opp (f : fe25519) : fe25519 + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe25519) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + Definition pow (f : fe25519) chain := fold_chain_opt one_ mul chain [f]. Lemma pow_correct (f : fe25519) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. @@ -345,6 +405,14 @@ Proof. + intros; rewrite opp_correct, opp_opt_correct; reflexivity. Qed. + +(** TODO(jadep from jgross): Fill me in *) +Lemma carry_field25519 : @field fe25519 eq zero one carry_opp carry_add carry_sub mul inv div. +Proof. + pose proof (Equivalence_Reflexive : Reflexive eq). + (*eapply (Field.equivalent_operations_field (fieldR := mbs_field)).*) +Admitted. + Lemma homomorphism_F25519 : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul @@ -361,6 +429,23 @@ Proof. + reflexivity. Qed. +(** TODO(jadep from jgross): Remove admits in this proof *) +Lemma homomorphism_carry_F25519 : + @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe25519 eq one carry_add mul encode. +Proof. + econstructor. + + econstructor; [ | apply encode_Proper]. + intros; cbv [eq]. + rewrite carry_add_correct, carry_add_opt_correct; admit; rewrite add_rep; apply encode_rep. + + intros; cbv [eq]. + rewrite mul_correct, carry_mul_opt_correct, carry_mul_rep + by auto using k_subst, c_subst, encode_rep. + apply encode_rep. + + reflexivity. +Admitted. + Definition ge_modulus_sig (f : fe25519) : { b : bool | b = ge_modulus_opt (to_list 10 f) }. Proof. @@ -563,4 +648,4 @@ Definition unpack (f : wire_digits) : fe25519 := Definition unpack_correct (f : wire_digits) : unpack f = unpack_opt params25519 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f).
\ No newline at end of file + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). |