aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-09-05 12:35:38 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-09-06 11:20:59 -0400
commitebb83ddb57aa8da5dbaae11de69c2fdc1a3e8c97 (patch)
treef595a933abd65fde2632e7929e3c341cceba9bd9 /src/ModularArithmetic
parentc00aa881d043c40f6dda4c304c28ef199064f143 (diff)
Pushed [freeze] through to GF25519 in preparation for defining [sqrt], cleaning up freeze-related organization and definitions along the way
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v12
-rw-r--r--src/ModularArithmetic/ModularBaseSystemList.v4
-rw-r--r--src/ModularArithmetic/ModularBaseSystemListProofs.v8
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v143
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v10
5 files changed, 84 insertions, 93 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index b8256f5a8..71851ac5d 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -67,16 +67,11 @@ Section ModularBaseSystem.
Local Notation "u ~= x" := (rep u x).
Local Hint Unfold rep.
- Definition carry_full (us : digits) : digits := from_list (carry_full [[us]])
- (length_carry_full length_to_list).
-
- Definition freeze (us : digits) : digits :=
- let us' := carry_full (carry_full (carry_full us)) in
- from_list (conditional_subtract_modulus [[us']] (ge_modulus [[us']]))
- (length_conditional_subtract_modulus length_to_list).
-
Definition eq (x y : digits) : Prop := decode x = decode y.
+ Definition freeze (x : digits) : digits :=
+ from_list (freeze [[x]]) (length_freeze length_to_list).
+
Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y).
(* Note : both of the following square root definitions will produce garbage output if the input is
@@ -91,6 +86,7 @@ Section ModularBaseSystem.
(chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1))
(x : digits) : digits := pow x chain.
+
Import Morphisms.
Global Instance eq_Equivalence : Equivalence eq.
Proof.
diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v
index d117635cd..836c7644c 100644
--- a/src/ModularArithmetic/ModularBaseSystemList.v
+++ b/src/ModularArithmetic/ModularBaseSystemList.v
@@ -69,6 +69,10 @@ Section Defs.
Otherwise, it's all zeroes, and the subtractions do nothing. *)
map2 (fun x y => x - y) us (map (Z.land and_term) modulus_digits).
+ Definition freeze (us : digits) : digits :=
+ let us' := carry_full (carry_full (carry_full us)) in
+ conditional_subtract_modulus us' (ge_modulus us').
+
Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
(bits_eq : sum_firstn limb_widths (length limb_widths) =
sum_firstn target_widths (length target_widths)).
diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v
index a12d88f9c..16699b8a2 100644
--- a/src/ModularArithmetic/ModularBaseSystemListProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v
@@ -116,6 +116,14 @@ Section LengthProofs.
rewrite map2_length, map_length, length_modulus_digits.
apply Min.min_case; omega.
Qed.
+ Hint Rewrite @length_conditional_subtract_modulus : distr_length.
+
+ Lemma length_freeze {u} :
+ length u = length limb_widths
+ -> length (freeze u) = length limb_widths.
+ Proof.
+ intros; unfold freeze; repeat autorewrite with distr_length; congruence.
+ Qed.
Lemma length_pack : forall {target_widths}
{target_widths_nonneg : forall x, In x target_widths -> 0 <= x}
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index cb79ff868..878b62abe 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -160,6 +160,23 @@ Ltac kill_precondition H :=
forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
subst_precondition.
+Lemma Let_In_push : forall {A B C} (g : A -> B) (f : B -> C) x,
+ f (Let_In x g) = Let_In x (fun y => f (g y)).
+Proof.
+ intros.
+ cbv [Let_In].
+ reflexivity.
+Qed.
+
+Lemma Let_In_ext : forall {A B} (f g : A -> B) x,
+ (forall x, f x = g x) ->
+ Let_In x g = Let_In x f.
+Proof.
+ intros.
+ cbv [Let_In].
+ congruence.
+Qed.
+
Section Carries.
Context `{prm : PseudoMersenneBaseParams}
(* allows caller to precompute k and c *)
@@ -365,44 +382,49 @@ Section Carries.
apply Pow2BaseProofs.make_chain_lt; auto.
Qed.
- Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
+ Definition carry_full_opt_sig (us : list Z) :
+ { b : list Z | (length us = length limb_widths)
+ -> b = carry_full us }.
Proof.
- eexists.
- cbv [carry_full ModularBaseSystemList.carry_full].
- rewrite <-from_list_default_eq with (d := 0).
- rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds).
+ eexists; cbv [carry_full]; intros.
+ match goal with |- ?LHS = ?RHS => change (LHS = id RHS) end.
+ rewrite <-carry_sequence_opt_cps_correct with (f := id) by (auto; apply full_carry_chain_bounds).
change @Pow2Base.full_carry_chain with full_carry_chain_opt.
reflexivity.
Defined.
- Definition carry_full_opt (us : digits) : digits
+ Definition carry_full_opt (us : list Z) : list Z
:= Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us).
- Definition carry_full_opt_correct us : carry_full_opt us = carry_full us :=
- proj2_sig (carry_full_opt_sig us).
+ Definition carry_full_opt_correct us
+ : length us = length limb_widths
+ -> carry_full_opt us = carry_full us
+ := proj2_sig (carry_full_opt_sig us).
Definition carry_full_opt_cps_sig
{T}
- (f : digits -> T)
- (us : digits)
- : { d : T | d = f (carry_full us) }.
+ (f : list Z -> T)
+ (us : list Z)
+ : { d : T | length us = length limb_widths
+ -> d = f (carry_full us) }.
Proof.
- eexists.
- rewrite <- carry_full_opt_correct.
+ eexists; intros.
+ rewrite <- carry_full_opt_correct by auto.
cbv beta iota delta [carry_full_opt].
- rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
+ rewrite carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds).
match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
- rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
+ rewrite <-carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds).
reflexivity.
Defined.
- Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T
+ Definition carry_full_opt_cps {T} (f : list Z -> T) (us : list Z) : T
:= Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us).
- Definition carry_full_opt_cps_correct {T} us (f : digits -> T) :
- carry_full_opt_cps f us = f (carry_full us) :=
- proj2_sig (carry_full_opt_cps_sig f us).
+ Definition carry_full_opt_cps_correct {T} us (f : list Z -> T)
+ : length us = length limb_widths
+ -> carry_full_opt_cps f us = f (carry_full us)
+ := proj2_sig (carry_full_opt_cps_sig f us).
End Carries.
@@ -844,44 +866,28 @@ Section Canonicalization.
{int_width} (preconditions : freezePreconditions prm int_width).
Local Notation digits := (tuple Z (length limb_widths)).
- Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ.
-
- Definition modulus_digits_opt_sig :
- { b : list Z | b = modulus_digits }.
- Proof.
- eexists.
- cbv beta iota delta [modulus_digits].
- change Pow2Base.encodeZ with encodeZ_opt.
- reflexivity.
- Defined.
-
- Definition modulus_digits_opt : list Z
- := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig).
-
- Definition modulus_digits_opt_correct
- : modulus_digits_opt = modulus_digits
- := proj2_sig (modulus_digits_opt_sig).
-
Definition carry_full_3_opt_cps_sig
- {T} (f : digits -> T)
- (us : digits)
- : { d : T | d = f (carry_full (carry_full (carry_full us))) }.
+ {T} (f : list Z -> T)
+ (us : list Z)
+ : { d : T | length us = length limb_widths
+ -> d = f (carry_full (carry_full (carry_full us))) }.
Proof.
eexists.
transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us).
Focus 2. {
- rewrite !carry_full_opt_cps_correct by assumption; reflexivity.
+ rewrite !carry_full_opt_cps_correct; repeat (autorewrite with distr_length; rewrite ?length_carry_full; auto).
}
Unfocus.
reflexivity.
Defined.
- Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T
+ Definition carry_full_3_opt_cps {T} (f : list Z -> T) (us : list Z) : T
:= Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us).
- Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us :
- carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) :=
- proj2_sig (carry_full_3_opt_cps_sig f us).
+ Definition carry_full_3_opt_cps_correct {T} (f : list Z -> T) us
+ : length us = length limb_widths
+ -> carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us)))
+ := proj2_sig (carry_full_3_opt_cps_sig f us).
Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : bool) :
{ g | g = conditional_subtract_modulus f cond}.
@@ -894,7 +900,6 @@ Section Canonicalization.
let RHSf := match (eval pattern (if cond then a else b) in RHS) with ?RHSf _ => RHSf end in
change (LHS = Let_In (if cond then a else b) RHSf) end.
cbv [map2 map].
- change modulus_digits with modulus_digits_opt.
change @max_ones with max_ones_opt.
reflexivity.
Defined.
@@ -906,35 +911,32 @@ Section Canonicalization.
: conditional_subtract_modulus_opt f cond = conditional_subtract_modulus f cond
:= Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f cond).
- Definition freeze_opt_sig (us : digits) :
- { b : digits | b = freeze us }.
+ Definition freeze_opt_sig (us : list Z) :
+ { b : list Z | length us = length limb_widths
+ -> b = ModularBaseSystemList.freeze us }.
Proof.
eexists.
- cbv [freeze].
- rewrite <-from_list_default_eq with (d := 0%Z).
- change (@from_list_default Z) with (@from_list_default_opt Z).
+ cbv [ModularBaseSystemList.freeze].
rewrite <-conditional_subtract_modulus_opt_correct.
- let LHS := match goal with |- ?LHS = ?RHS => LHS end in
- let RHS := match goal with |- ?LHS = ?RHS => RHS end in
- let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in
- change (LHS = Let_In (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) RHSf).
+ intros.
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in
- rewrite <-carry_full_3_opt_cps_correct with (f := RHSf).
+ rewrite <-carry_full_3_opt_cps_correct with (f := RHSf) by auto.
cbv beta iota delta [ge_modulus ge_modulus'].
change length with length_opt.
- change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt).
+ change nth_default with @nth_default_opt.
change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
change minus with minus_opt.
reflexivity.
Defined.
- Definition freeze_opt (us : digits) : digits
+ Definition freeze_opt (us : list Z) : list Z
:= Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us).
Definition freeze_opt_correct us
- : freeze_opt us = freeze us
+ : length us = length limb_widths
+ -> freeze_opt us = ModularBaseSystemList.freeze us
:= proj2_sig (freeze_opt_sig us).
End Canonicalization.
@@ -947,21 +949,6 @@ Section SquareRoots.
Context (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
(one_ : digits) (one_subst : one = one_).
- Definition eqb_opt_sig (us vs : digits) :
- { b | b = ModularBaseSystem.eqb us vs}.
- Proof.
- eexists; cbv [ModularBaseSystem.eqb].
- cbv [fieldwiseb fieldwiseb'].
- erewrite <-!freeze_opt_correct by eassumption.
- reflexivity.
- Defined.
-
- Definition eqb_opt us vs := Eval cbv [proj1_sig eqb_opt_sig] in proj1_sig (eqb_opt_sig us vs).
-
- Definition eqb_opt_correct us vs
- : eqb_opt us vs = ModularBaseSystem.eqb us vs
- := Eval cbv [proj2_sig eqb_opt_sig] in proj2_sig (eqb_opt_sig us vs).
-
(* TODO : where should this lemma go? Alternatively, is there a standard-library
tactic/lemma for this? *)
Lemma if_equiv : forall {A} (eqA : A -> A -> Prop) (x0 x1 : bool) y0 y1 z0 z1,
@@ -1004,6 +991,8 @@ Section SquareRoots.
etransitivity.
Focus 2. {
apply if_equiv. {
+ etransitivity.
+ Focus 2. {
apply eqb_Proper; [ | reflexivity ].
transitivity (carry_mul_opt k_ c_ (pow_opt k_ c_ one_ us chain) (pow_opt k_ c_ one_ us chain)); [ reflexivity | ].
cbv [eq].
@@ -1011,6 +1000,11 @@ Section SquareRoots.
rewrite carry_mul_rep by reflexivity.
rewrite mul_rep by reflexivity.
f_equal; apply pow_opt_correct; auto.
+ } Unfocus.
+ cbv [ModularBaseSystem.eqb freeze].
+ rewrite <-!from_list_default_eq with (d := 0).
+ erewrite <-!freeze_opt_correct by eauto using length_to_list.
+ reflexivity.
} {
apply pow_opt_correct; auto.
} {
@@ -1024,7 +1018,6 @@ Section SquareRoots.
reflexivity.
}
} Unfocus.
- rewrite <-eqb_opt_correct.
rewrite k_subst, c_subst, one_subst.
let LHS := match goal with |- eq ?LHS ?RHS => LHS end in
let RHS := match goal with |- eq ?LHS ?RHS => RHS end in
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index 3b04eda2f..561c1ae81 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -447,16 +447,6 @@ Section CarryProofs.
apply IHis; auto using in_cons.
Qed.
- Lemma carry_full_preserves_rep : forall us x,
- rep us x -> rep (carry_full us) x.
- Proof.
- unfold carry_full; intros.
- apply carry_sequence_rep; auto.
- unfold full_carry_chain; apply make_chain_lt.
- Qed.
-
- Opaque carry_full.
-
Context `{cc : CarryChain limb_widths}.
Lemma carry_mul_rep : forall us vs x y,
rep us x -> rep vs y ->