diff options
author | 2016-09-05 12:35:38 -0400 | |
---|---|---|
committer | 2016-09-06 11:20:59 -0400 | |
commit | ebb83ddb57aa8da5dbaae11de69c2fdc1a3e8c97 (patch) | |
tree | f595a933abd65fde2632e7929e3c341cceba9bd9 /src | |
parent | c00aa881d043c40f6dda4c304c28ef199064f143 (diff) |
Pushed [freeze] through to GF25519 in preparation for defining [sqrt], cleaning up freeze-related organization and definitions along the way
Diffstat (limited to 'src')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 12 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemList.v | 4 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemListProofs.v | 8 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 143 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 10 | ||||
-rw-r--r-- | src/Specific/GF25519.v | 69 |
6 files changed, 147 insertions, 99 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index b8256f5a8..71851ac5d 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -67,16 +67,11 @@ Section ModularBaseSystem. Local Notation "u ~= x" := (rep u x). Local Hint Unfold rep. - Definition carry_full (us : digits) : digits := from_list (carry_full [[us]]) - (length_carry_full length_to_list). - - Definition freeze (us : digits) : digits := - let us' := carry_full (carry_full (carry_full us)) in - from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']])) - (length_conditional_subtract_modulus length_to_list). - Definition eq (x y : digits) : Prop := decode x = decode y. + Definition freeze (x : digits) : digits := + from_list (freeze [[x]]) (length_freeze length_to_list). + Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y). (* Note : both of the following square root definitions will produce garbage output if the input is @@ -91,6 +86,7 @@ Section ModularBaseSystem. (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1)) (x : digits) : digits := pow x chain. + Import Morphisms. Global Instance eq_Equivalence : Equivalence eq. Proof. diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v index d117635cd..836c7644c 100644 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ b/src/ModularArithmetic/ModularBaseSystemList.v @@ -69,6 +69,10 @@ Section Defs. Otherwise, it's all zeroes, and the subtractions do nothing. *) map2 (fun x y => x - y) us (map (Z.land and_term) modulus_digits). + Definition freeze (us : digits) : digits := + let us' := carry_full (carry_full (carry_full us)) in + conditional_subtract_modulus us' (ge_modulus us'). + Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)). diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v index a12d88f9c..16699b8a2 100644 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v @@ -116,6 +116,14 @@ Section LengthProofs. rewrite map2_length, map_length, length_modulus_digits. apply Min.min_case; omega. Qed. + Hint Rewrite @length_conditional_subtract_modulus : distr_length. + + Lemma length_freeze {u} : + length u = length limb_widths + -> length (freeze u) = length limb_widths. + Proof. + intros; unfold freeze; repeat autorewrite with distr_length; congruence. + Qed. Lemma length_pack : forall {target_widths} {target_widths_nonneg : forall x, In x target_widths -> 0 <= x} diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index cb79ff868..878b62abe 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -160,6 +160,23 @@ Ltac kill_precondition H := forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; subst_precondition. +Lemma Let_In_push : forall {A B C} (g : A -> B) (f : B -> C) x, + f (Let_In x g) = Let_In x (fun y => f (g y)). +Proof. + intros. + cbv [Let_In]. + reflexivity. +Qed. + +Lemma Let_In_ext : forall {A B} (f g : A -> B) x, + (forall x, f x = g x) -> + Let_In x g = Let_In x f. +Proof. + intros. + cbv [Let_In]. + congruence. +Qed. + Section Carries. Context `{prm : PseudoMersenneBaseParams} (* allows caller to precompute k and c *) @@ -365,44 +382,49 @@ Section Carries. apply Pow2BaseProofs.make_chain_lt; auto. Qed. - Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. + Definition carry_full_opt_sig (us : list Z) : + { b : list Z | (length us = length limb_widths) + -> b = carry_full us }. Proof. - eexists. - cbv [carry_full ModularBaseSystemList.carry_full]. - rewrite <-from_list_default_eq with (d := 0). - rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds). + eexists; cbv [carry_full]; intros. + match goal with |- ?LHS = ?RHS => change (LHS = id RHS) end. + rewrite <-carry_sequence_opt_cps_correct with (f := id) by (auto; apply full_carry_chain_bounds). change @Pow2Base.full_carry_chain with full_carry_chain_opt. reflexivity. Defined. - Definition carry_full_opt (us : digits) : digits + Definition carry_full_opt (us : list Z) : list Z := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us). - Definition carry_full_opt_correct us : carry_full_opt us = carry_full us := - proj2_sig (carry_full_opt_sig us). + Definition carry_full_opt_correct us + : length us = length limb_widths + -> carry_full_opt us = carry_full us + := proj2_sig (carry_full_opt_sig us). Definition carry_full_opt_cps_sig {T} - (f : digits -> T) - (us : digits) - : { d : T | d = f (carry_full us) }. + (f : list Z -> T) + (us : list Z) + : { d : T | length us = length limb_widths + -> d = f (carry_full us) }. Proof. - eexists. - rewrite <- carry_full_opt_correct. + eexists; intros. + rewrite <- carry_full_opt_correct by auto. cbv beta iota delta [carry_full_opt]. - rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). + rewrite carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) => change (LHS = (fun x => f (g x)) (carry_sequence is us)) end. - rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). + rewrite <-carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). reflexivity. Defined. - Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T + Definition carry_full_opt_cps {T} (f : list Z -> T) (us : list Z) : T := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us). - Definition carry_full_opt_cps_correct {T} us (f : digits -> T) : - carry_full_opt_cps f us = f (carry_full us) := - proj2_sig (carry_full_opt_cps_sig f us). + Definition carry_full_opt_cps_correct {T} us (f : list Z -> T) + : length us = length limb_widths + -> carry_full_opt_cps f us = f (carry_full us) + := proj2_sig (carry_full_opt_cps_sig f us). End Carries. @@ -844,44 +866,28 @@ Section Canonicalization. {int_width} (preconditions : freezePreconditions prm int_width). Local Notation digits := (tuple Z (length limb_widths)). - Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ. - - Definition modulus_digits_opt_sig : - { b : list Z | b = modulus_digits }. - Proof. - eexists. - cbv beta iota delta [modulus_digits]. - change Pow2Base.encodeZ with encodeZ_opt. - reflexivity. - Defined. - - Definition modulus_digits_opt : list Z - := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig). - - Definition modulus_digits_opt_correct - : modulus_digits_opt = modulus_digits - := proj2_sig (modulus_digits_opt_sig). - Definition carry_full_3_opt_cps_sig - {T} (f : digits -> T) - (us : digits) - : { d : T | d = f (carry_full (carry_full (carry_full us))) }. + {T} (f : list Z -> T) + (us : list Z) + : { d : T | length us = length limb_widths + -> d = f (carry_full (carry_full (carry_full us))) }. Proof. eexists. transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us). Focus 2. { - rewrite !carry_full_opt_cps_correct by assumption; reflexivity. + rewrite !carry_full_opt_cps_correct; repeat (autorewrite with distr_length; rewrite ?length_carry_full; auto). } Unfocus. reflexivity. Defined. - Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T + Definition carry_full_3_opt_cps {T} (f : list Z -> T) (us : list Z) : T := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us). - Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us : - carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) := - proj2_sig (carry_full_3_opt_cps_sig f us). + Definition carry_full_3_opt_cps_correct {T} (f : list Z -> T) us + : length us = length limb_widths + -> carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) + := proj2_sig (carry_full_3_opt_cps_sig f us). Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : bool) : { g | g = conditional_subtract_modulus f cond}. @@ -894,7 +900,6 @@ Section Canonicalization. let RHSf := match (eval pattern (if cond then a else b) in RHS) with ?RHSf _ => RHSf end in change (LHS = Let_In (if cond then a else b) RHSf) end. cbv [map2 map]. - change modulus_digits with modulus_digits_opt. change @max_ones with max_ones_opt. reflexivity. Defined. @@ -906,35 +911,32 @@ Section Canonicalization. : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus f cond := Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f cond). - Definition freeze_opt_sig (us : digits) : - { b : digits | b = freeze us }. + Definition freeze_opt_sig (us : list Z) : + { b : list Z | length us = length limb_widths + -> b = ModularBaseSystemList.freeze us }. Proof. eexists. - cbv [freeze]. - rewrite <-from_list_default_eq with (d := 0%Z). - change (@from_list_default Z) with (@from_list_default_opt Z). + cbv [ModularBaseSystemList.freeze]. rewrite <-conditional_subtract_modulus_opt_correct. - let LHS := match goal with |- ?LHS = ?RHS => LHS end in - let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) RHSf). + intros. let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in - rewrite <-carry_full_3_opt_cps_correct with (f := RHSf). + rewrite <-carry_full_3_opt_cps_correct with (f := RHSf) by auto. cbv beta iota delta [ge_modulus ge_modulus']. change length with length_opt. - change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt). + change nth_default with @nth_default_opt. change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. change minus with minus_opt. reflexivity. Defined. - Definition freeze_opt (us : digits) : digits + Definition freeze_opt (us : list Z) : list Z := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us). Definition freeze_opt_correct us - : freeze_opt us = freeze us + : length us = length limb_widths + -> freeze_opt us = ModularBaseSystemList.freeze us := proj2_sig (freeze_opt_sig us). End Canonicalization. @@ -947,21 +949,6 @@ Section SquareRoots. Context (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) (one_ : digits) (one_subst : one = one_). - Definition eqb_opt_sig (us vs : digits) : - { b | b = ModularBaseSystem.eqb us vs}. - Proof. - eexists; cbv [ModularBaseSystem.eqb]. - cbv [fieldwiseb fieldwiseb']. - erewrite <-!freeze_opt_correct by eassumption. - reflexivity. - Defined. - - Definition eqb_opt us vs := Eval cbv [proj1_sig eqb_opt_sig] in proj1_sig (eqb_opt_sig us vs). - - Definition eqb_opt_correct us vs - : eqb_opt us vs = ModularBaseSystem.eqb us vs - := Eval cbv [proj2_sig eqb_opt_sig] in proj2_sig (eqb_opt_sig us vs). - (* TODO : where should this lemma go? Alternatively, is there a standard-library tactic/lemma for this? *) Lemma if_equiv : forall {A} (eqA : A -> A -> Prop) (x0 x1 : bool) y0 y1 z0 z1, @@ -1004,6 +991,8 @@ Section SquareRoots. etransitivity. Focus 2. { apply if_equiv. { + etransitivity. + Focus 2. { apply eqb_Proper; [ | reflexivity ]. transitivity (carry_mul_opt k_ c_ (pow_opt k_ c_ one_ us chain) (pow_opt k_ c_ one_ us chain)); [ reflexivity | ]. cbv [eq]. @@ -1011,6 +1000,11 @@ Section SquareRoots. rewrite carry_mul_rep by reflexivity. rewrite mul_rep by reflexivity. f_equal; apply pow_opt_correct; auto. + } Unfocus. + cbv [ModularBaseSystem.eqb freeze]. + rewrite <-!from_list_default_eq with (d := 0). + erewrite <-!freeze_opt_correct by eauto using length_to_list. + reflexivity. } { apply pow_opt_correct; auto. } { @@ -1024,7 +1018,6 @@ Section SquareRoots. reflexivity. } } Unfocus. - rewrite <-eqb_opt_correct. rewrite k_subst, c_subst, one_subst. let LHS := match goal with |- eq ?LHS ?RHS => LHS end in let RHS := match goal with |- eq ?LHS ?RHS => RHS end in diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 3b04eda2f..561c1ae81 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -447,16 +447,6 @@ Section CarryProofs. apply IHis; auto using in_cons. Qed. - Lemma carry_full_preserves_rep : forall us x, - rep us x -> rep (carry_full us) x. - Proof. - unfold carry_full; intros. - apply carry_sequence_rep; auto. - unfold full_carry_chain; apply make_chain_lt. - Qed. - - Opaque carry_full. - Context `{cc : CarryChain limb_widths}. Lemma carry_mul_rep : forall us vs x y, rep us x -> rep vs y -> diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v index fe4428ca7..bf7811b86 100644 --- a/src/Specific/GF25519.v +++ b/src/Specific/GF25519.v @@ -37,7 +37,7 @@ Instance subCoeff : SubtractionCoefficient. Defined. Instance carryChain : CarryChain limb_widths. - apply Build_CarryChain with (carry_chain := ([0;1;2;3;4;5;6;7;8;9;0;1])%nat). + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;0;1])%nat). intros. repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). contradiction H. @@ -111,17 +111,23 @@ Arguments chain {_ _ _} _. (* END precomputation *) -(* Precompute k, c, zero, and one *) +(* Precompute constants *) Definition k_ := Eval compute in k. -Definition c_ := Eval compute in c. -Definition one_ := Eval compute in one. -Definition zero_ := Eval compute in zero. Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. Definition zero_subst : zero = zero_ := eq_refl zero_. -Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In. +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. Definition app_7 {T} (f : wire_digits) (P : wire_digits -> T) : T. Proof. @@ -281,6 +287,23 @@ Proof. intros; subst; apply mul_correct. Qed. +(* Now that we have [pow], we can compute sqrt of -1 for use + in sqrt function (this is not needed unless the prime is + 5 mod 8) *) +Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). + +Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). +Proof. + cbv [rep]. + apply F.eq_to_Z_iff. + vm_compute. + reflexivity. +Qed. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + Definition inv_sig (f : fe25519) : { g : fe25519 | g = inv_opt k_ c_ one_ f }. Proof. @@ -336,6 +359,40 @@ Proof. + reflexivity. Qed. +Definition freeze_sig (f : fe25519) : + { f' : fe25519 | f' = from_list_default 0 10 (freeze_opt c_ (to_list 10 f)) }. +Proof. + cbv [fe25519] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + rewrite Let_In_push. + repeat (erewrite Let_In_ext; [ | + repeat match goal with + | |- _ => progress intros; try apply Let_In_ext + | |- _ = from_list_default _ _ (Let_In _ _) => etransitivity; try (rewrite Let_In_push; reflexivity) + | |- from_list_default _ _ (Let_In _ _) = _ => etransitivity; try (rewrite Let_In_push; reflexivity) + end; reflexivity ]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe25519) : fe25519 := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition freeze_correct (f : fe25519) + : freeze f = from_list_default 0 10 (freeze_opt c_ (to_list 10 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe25519] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. Definition pack_simpl_sig (f : fe25519) : { f' | f' = pack_opt params25519 wire_widths_nonneg bits_eq f }. |