diff options
author | 2016-09-05 12:35:38 -0400 | |
---|---|---|
committer | 2016-09-06 11:20:59 -0400 | |
commit | ebb83ddb57aa8da5dbaae11de69c2fdc1a3e8c97 (patch) | |
tree | f595a933abd65fde2632e7929e3c341cceba9bd9 /src/ModularArithmetic | |
parent | c00aa881d043c40f6dda4c304c28ef199064f143 (diff) |
Pushed [freeze] through to GF25519 in preparation for defining [sqrt], cleaning up freeze-related organization and definitions along the way
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 12 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemList.v | 4 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemListProofs.v | 8 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 143 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 10 |
5 files changed, 84 insertions, 93 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index b8256f5a8..71851ac5d 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -67,16 +67,11 @@ Section ModularBaseSystem. Local Notation "u ~= x" := (rep u x). Local Hint Unfold rep. - Definition carry_full (us : digits) : digits := from_list (carry_full [[us]]) - (length_carry_full length_to_list). - - Definition freeze (us : digits) : digits := - let us' := carry_full (carry_full (carry_full us)) in - from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']])) - (length_conditional_subtract_modulus length_to_list). - Definition eq (x y : digits) : Prop := decode x = decode y. + Definition freeze (x : digits) : digits := + from_list (freeze [[x]]) (length_freeze length_to_list). + Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y). (* Note : both of the following square root definitions will produce garbage output if the input is @@ -91,6 +86,7 @@ Section ModularBaseSystem. (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1)) (x : digits) : digits := pow x chain. + Import Morphisms. Global Instance eq_Equivalence : Equivalence eq. Proof. diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v index d117635cd..836c7644c 100644 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ b/src/ModularArithmetic/ModularBaseSystemList.v @@ -69,6 +69,10 @@ Section Defs. Otherwise, it's all zeroes, and the subtractions do nothing. *) map2 (fun x y => x - y) us (map (Z.land and_term) modulus_digits). + Definition freeze (us : digits) : digits := + let us' := carry_full (carry_full (carry_full us)) in + conditional_subtract_modulus us' (ge_modulus us'). + Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)). diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v index a12d88f9c..16699b8a2 100644 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v @@ -116,6 +116,14 @@ Section LengthProofs. rewrite map2_length, map_length, length_modulus_digits. apply Min.min_case; omega. Qed. + Hint Rewrite @length_conditional_subtract_modulus : distr_length. + + Lemma length_freeze {u} : + length u = length limb_widths + -> length (freeze u) = length limb_widths. + Proof. + intros; unfold freeze; repeat autorewrite with distr_length; congruence. + Qed. Lemma length_pack : forall {target_widths} {target_widths_nonneg : forall x, In x target_widths -> 0 <= x} diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index cb79ff868..878b62abe 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -160,6 +160,23 @@ Ltac kill_precondition H := forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; subst_precondition. +Lemma Let_In_push : forall {A B C} (g : A -> B) (f : B -> C) x, + f (Let_In x g) = Let_In x (fun y => f (g y)). +Proof. + intros. + cbv [Let_In]. + reflexivity. +Qed. + +Lemma Let_In_ext : forall {A B} (f g : A -> B) x, + (forall x, f x = g x) -> + Let_In x g = Let_In x f. +Proof. + intros. + cbv [Let_In]. + congruence. +Qed. + Section Carries. Context `{prm : PseudoMersenneBaseParams} (* allows caller to precompute k and c *) @@ -365,44 +382,49 @@ Section Carries. apply Pow2BaseProofs.make_chain_lt; auto. Qed. - Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. + Definition carry_full_opt_sig (us : list Z) : + { b : list Z | (length us = length limb_widths) + -> b = carry_full us }. Proof. - eexists. - cbv [carry_full ModularBaseSystemList.carry_full]. - rewrite <-from_list_default_eq with (d := 0). - rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds). + eexists; cbv [carry_full]; intros. + match goal with |- ?LHS = ?RHS => change (LHS = id RHS) end. + rewrite <-carry_sequence_opt_cps_correct with (f := id) by (auto; apply full_carry_chain_bounds). change @Pow2Base.full_carry_chain with full_carry_chain_opt. reflexivity. Defined. - Definition carry_full_opt (us : digits) : digits + Definition carry_full_opt (us : list Z) : list Z := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us). - Definition carry_full_opt_correct us : carry_full_opt us = carry_full us := - proj2_sig (carry_full_opt_sig us). + Definition carry_full_opt_correct us + : length us = length limb_widths + -> carry_full_opt us = carry_full us + := proj2_sig (carry_full_opt_sig us). Definition carry_full_opt_cps_sig {T} - (f : digits -> T) - (us : digits) - : { d : T | d = f (carry_full us) }. + (f : list Z -> T) + (us : list Z) + : { d : T | length us = length limb_widths + -> d = f (carry_full us) }. Proof. - eexists. - rewrite <- carry_full_opt_correct. + eexists; intros. + rewrite <- carry_full_opt_correct by auto. cbv beta iota delta [carry_full_opt]. - rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). + rewrite carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) => change (LHS = (fun x => f (g x)) (carry_sequence is us)) end. - rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds). + rewrite <-carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). reflexivity. Defined. - Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T + Definition carry_full_opt_cps {T} (f : list Z -> T) (us : list Z) : T := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us). - Definition carry_full_opt_cps_correct {T} us (f : digits -> T) : - carry_full_opt_cps f us = f (carry_full us) := - proj2_sig (carry_full_opt_cps_sig f us). + Definition carry_full_opt_cps_correct {T} us (f : list Z -> T) + : length us = length limb_widths + -> carry_full_opt_cps f us = f (carry_full us) + := proj2_sig (carry_full_opt_cps_sig f us). End Carries. @@ -844,44 +866,28 @@ Section Canonicalization. {int_width} (preconditions : freezePreconditions prm int_width). Local Notation digits := (tuple Z (length limb_widths)). - Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ. - - Definition modulus_digits_opt_sig : - { b : list Z | b = modulus_digits }. - Proof. - eexists. - cbv beta iota delta [modulus_digits]. - change Pow2Base.encodeZ with encodeZ_opt. - reflexivity. - Defined. - - Definition modulus_digits_opt : list Z - := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig). - - Definition modulus_digits_opt_correct - : modulus_digits_opt = modulus_digits - := proj2_sig (modulus_digits_opt_sig). - Definition carry_full_3_opt_cps_sig - {T} (f : digits -> T) - (us : digits) - : { d : T | d = f (carry_full (carry_full (carry_full us))) }. + {T} (f : list Z -> T) + (us : list Z) + : { d : T | length us = length limb_widths + -> d = f (carry_full (carry_full (carry_full us))) }. Proof. eexists. transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us). Focus 2. { - rewrite !carry_full_opt_cps_correct by assumption; reflexivity. + rewrite !carry_full_opt_cps_correct; repeat (autorewrite with distr_length; rewrite ?length_carry_full; auto). } Unfocus. reflexivity. Defined. - Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T + Definition carry_full_3_opt_cps {T} (f : list Z -> T) (us : list Z) : T := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us). - Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us : - carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) := - proj2_sig (carry_full_3_opt_cps_sig f us). + Definition carry_full_3_opt_cps_correct {T} (f : list Z -> T) us + : length us = length limb_widths + -> carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) + := proj2_sig (carry_full_3_opt_cps_sig f us). Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : bool) : { g | g = conditional_subtract_modulus f cond}. @@ -894,7 +900,6 @@ Section Canonicalization. let RHSf := match (eval pattern (if cond then a else b) in RHS) with ?RHSf _ => RHSf end in change (LHS = Let_In (if cond then a else b) RHSf) end. cbv [map2 map]. - change modulus_digits with modulus_digits_opt. change @max_ones with max_ones_opt. reflexivity. Defined. @@ -906,35 +911,32 @@ Section Canonicalization. : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus f cond := Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f cond). - Definition freeze_opt_sig (us : digits) : - { b : digits | b = freeze us }. + Definition freeze_opt_sig (us : list Z) : + { b : list Z | length us = length limb_widths + -> b = ModularBaseSystemList.freeze us }. Proof. eexists. - cbv [freeze]. - rewrite <-from_list_default_eq with (d := 0%Z). - change (@from_list_default Z) with (@from_list_default_opt Z). + cbv [ModularBaseSystemList.freeze]. rewrite <-conditional_subtract_modulus_opt_correct. - let LHS := match goal with |- ?LHS = ?RHS => LHS end in - let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) RHSf). + intros. let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in - rewrite <-carry_full_3_opt_cps_correct with (f := RHSf). + rewrite <-carry_full_3_opt_cps_correct with (f := RHSf) by auto. cbv beta iota delta [ge_modulus ge_modulus']. change length with length_opt. - change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt). + change nth_default with @nth_default_opt. change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. change minus with minus_opt. reflexivity. Defined. - Definition freeze_opt (us : digits) : digits + Definition freeze_opt (us : list Z) : list Z := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us). Definition freeze_opt_correct us - : freeze_opt us = freeze us + : length us = length limb_widths + -> freeze_opt us = ModularBaseSystemList.freeze us := proj2_sig (freeze_opt_sig us). End Canonicalization. @@ -947,21 +949,6 @@ Section SquareRoots. Context (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) (one_ : digits) (one_subst : one = one_). - Definition eqb_opt_sig (us vs : digits) : - { b | b = ModularBaseSystem.eqb us vs}. - Proof. - eexists; cbv [ModularBaseSystem.eqb]. - cbv [fieldwiseb fieldwiseb']. - erewrite <-!freeze_opt_correct by eassumption. - reflexivity. - Defined. - - Definition eqb_opt us vs := Eval cbv [proj1_sig eqb_opt_sig] in proj1_sig (eqb_opt_sig us vs). - - Definition eqb_opt_correct us vs - : eqb_opt us vs = ModularBaseSystem.eqb us vs - := Eval cbv [proj2_sig eqb_opt_sig] in proj2_sig (eqb_opt_sig us vs). - (* TODO : where should this lemma go? Alternatively, is there a standard-library tactic/lemma for this? *) Lemma if_equiv : forall {A} (eqA : A -> A -> Prop) (x0 x1 : bool) y0 y1 z0 z1, @@ -1004,6 +991,8 @@ Section SquareRoots. etransitivity. Focus 2. { apply if_equiv. { + etransitivity. + Focus 2. { apply eqb_Proper; [ | reflexivity ]. transitivity (carry_mul_opt k_ c_ (pow_opt k_ c_ one_ us chain) (pow_opt k_ c_ one_ us chain)); [ reflexivity | ]. cbv [eq]. @@ -1011,6 +1000,11 @@ Section SquareRoots. rewrite carry_mul_rep by reflexivity. rewrite mul_rep by reflexivity. f_equal; apply pow_opt_correct; auto. + } Unfocus. + cbv [ModularBaseSystem.eqb freeze]. + rewrite <-!from_list_default_eq with (d := 0). + erewrite <-!freeze_opt_correct by eauto using length_to_list. + reflexivity. } { apply pow_opt_correct; auto. } { @@ -1024,7 +1018,6 @@ Section SquareRoots. reflexivity. } } Unfocus. - rewrite <-eqb_opt_correct. rewrite k_subst, c_subst, one_subst. let LHS := match goal with |- eq ?LHS ?RHS => LHS end in let RHS := match goal with |- eq ?LHS ?RHS => RHS end in diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 3b04eda2f..561c1ae81 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -447,16 +447,6 @@ Section CarryProofs. apply IHis; auto using in_cons. Qed. - Lemma carry_full_preserves_rep : forall us x, - rep us x -> rep (carry_full us) x. - Proof. - unfold carry_full; intros. - apply carry_sequence_rep; auto. - unfold full_carry_chain; apply make_chain_lt. - Qed. - - Opaque carry_full. - Context `{cc : CarryChain limb_widths}. Lemma carry_mul_rep : forall us vs x y, rep us x -> rep vs y -> |