aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-10-21 18:47:26 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-10-22 00:10:53 -0400
commit31d24dcb9e53cd21d619d403de8933b8fc451ed8 (patch)
treee40c363a60cd861847f686535af6bd8801fff62d /src/ModularArithmetic
parent1ec6ade7fa92912adffdb815eef5f6cac31ab078 (diff)
Modified [freeze] to be more reifyable
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v14
-rw-r--r--src/ModularArithmetic/ModularBaseSystemList.v36
-rw-r--r--src/ModularArithmetic/ModularBaseSystemListProofs.v148
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v52
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v28
5 files changed, 167 insertions, 111 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 5c0d143c2..615bd832b 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -82,10 +82,10 @@ Section ModularBaseSystem.
Definition eq (x y : digits) : Prop := decode x = decode y.
- Definition freeze (x : digits) : digits :=
- from_list (freeze [[x]]) (length_freeze length_to_list).
+ Definition freeze B (x : digits) : digits :=
+ from_list (freeze B [[x]]) (length_freeze length_to_list).
- Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y).
+ Definition eqb B (x y : digits) : bool := fieldwiseb Z.eqb (freeze B x) (freeze B y).
(* Note : both of the following square root definitions will produce garbage output if the input is
not square mod [modulus]. The caller should either provably only call them with square input,
@@ -97,10 +97,10 @@ Section ModularBaseSystem.
(* sqrt_5mod8 is parameterized over implementation of [mul] and [pow] because it relies on bounds-checking
for these two functions, which is much easier for simplified implementations than the more generalized
ones defined here. *)
- Definition sqrt_5mod8 mul_ pow_ (chain : list (nat * nat))
+ Definition sqrt_5mod8 B mul_ pow_ (chain : list (nat * nat))
(chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1))
(sqrt_minus1 x : digits) : digits :=
- let b := pow_ x chain in if eqb (mul_ b b) x then b else mul_ sqrt_minus1 b.
+ let b := pow_ x chain in if eqb B (mul_ b b) x then b else mul_ sqrt_minus1 b.
Import Morphisms.
Global Instance eq_Equivalence : Equivalence eq.
@@ -108,6 +108,10 @@ Section ModularBaseSystem.
split; cbv [eq]; repeat intro; congruence.
Qed.
+ Definition select B (b : Z) (x y : digits) :=
+ add (map (Z.land (neg B b)) x)
+ (map (Z.land (neg B (Z.lxor b 1))) x).
+
Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
(bits_eq : sum_firstn limb_widths (length limb_widths) =
sum_firstn target_widths (length target_widths)).
diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v
index 6d0848151..e64ed5d0f 100644
--- a/src/ModularArithmetic/ModularBaseSystemList.v
+++ b/src/ModularArithmetic/ModularBaseSystemList.v
@@ -8,6 +8,7 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Conversion.
@@ -51,28 +52,35 @@ Section Defs.
Definition modulus_digits := encodeZ limb_widths modulus.
- (* compute at compile time *)
- Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths).
-
(* Constant-time comparison with modulus; only works if all digits of [us]
are less than 2 ^ their respective limb width. *)
- Fixpoint ge_modulus' us acc i :=
- match i with
- | O => andb (Z.leb (modulus_digits [0]) (us [0])) acc
- | S i' => ge_modulus' us (andb (Z.eqb (modulus_digits [i]) (us [i])) acc) i'
+ Fixpoint ge_modulus' {A} (f : Z -> A) us (result : Z) i :=
+ dlet r := result in
+ match i return A with
+ | O => dlet x := if Z.leb (modulus_digits [0]) (us [0])
+ then r
+ else 0 in f x
+ | S i' => ge_modulus' f us
+ (if Z.eqb (modulus_digits [i]) (us [i])
+ then r
+ else 0) i'
end.
- Definition ge_modulus us := ge_modulus' us true (length limb_widths - 1)%nat.
+ Definition ge_modulus us := ge_modulus' id us 1 (length limb_widths - 1)%nat.
+
+ (* analagous to NEG assembly instruction on an integer that is 0 or 1:
+ neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.)
+ neg 0 = 0 *)
+ Definition neg (int_width : Z) (b : Z) := if b =? 1 then Z.ones int_width else 0.
- Definition conditional_subtract_modulus (us : digits) (cond : bool) :=
- let and_term := if cond then max_ones else 0 in
+ Definition conditional_subtract_modulus int_width (us : digits) (cond : Z) :=
(* [and_term] is all ones if us' is full, so the subtractions subtract q overall.
Otherwise, it's all zeroes, and the subtractions do nothing. *)
- map2 (fun x y => x - y) us (map (Z.land and_term) modulus_digits).
+ map2 (fun x y => x - y) us (map (Z.land (neg int_width cond)) modulus_digits).
- Definition freeze (us : digits) : digits :=
+ Definition freeze int_width (us : digits) : digits :=
let us' := carry_full (carry_full (carry_full us)) in
- conditional_subtract_modulus us' (ge_modulus us').
+ conditional_subtract_modulus int_width us' (ge_modulus us').
Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x)
(bits_eq : sum_firstn limb_widths (length limb_widths) =
@@ -86,4 +94,4 @@ Section Defs.
limb_widths limb_widths_nonneg
(Z.eq_le_incl _ _ (Z.eq_sym bits_eq)).
-End Defs.
+End Defs. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v
index 93b39e89a..23ec30cf7 100644
--- a/src/ModularArithmetic/ModularBaseSystemListProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v
@@ -110,10 +110,11 @@ Section LengthProofs.
rewrite encode'_spec, encode'_length;
auto using encode'_length, limb_widths_nonneg, Nat.eq_le_incl, base_from_limb_widths_length.
Qed.
+ Hint Rewrite @length_modulus_digits : distr_length.
- Lemma length_conditional_subtract_modulus {u cond} :
+ Lemma length_conditional_subtract_modulus {int_width u cond} :
length u = length limb_widths
- -> length (conditional_subtract_modulus u cond) = length limb_widths.
+ -> length (conditional_subtract_modulus int_width u cond) = length limb_widths.
Proof.
intros; unfold conditional_subtract_modulus.
rewrite map2_length, map_length, length_modulus_digits.
@@ -121,9 +122,9 @@ Section LengthProofs.
Qed.
Hint Rewrite @length_conditional_subtract_modulus : distr_length.
- Lemma length_freeze {u} :
+ Lemma length_freeze {int_width u} :
length u = length limb_widths
- -> length (freeze u) = length limb_widths.
+ -> length (freeze int_width u) = length limb_widths.
Proof.
intros; unfold freeze; repeat autorewrite with distr_length; congruence.
Qed.
@@ -285,27 +286,41 @@ Section ModulusComparisonProofs.
reflexivity.
Qed.
- Lemma ge_modulus'_false : forall us i,
- ge_modulus' us false i = false.
+ Lemma ge_modulus'_0 : forall {A} f us i,
+ ge_modulus' (A := A) f us 0 i = f 0.
Proof.
- induction i; intros; simpl; rewrite Bool.andb_false_r; auto.
+ induction i; intros; simpl; break_if; auto.
+ Qed.
+
+ Lemma ge_modulus'_01 : forall {A} f us i b,
+ (b = 0 \/ b = 1) ->
+ (ge_modulus' (A := A) f us b i = f 0 \/ ge_modulus' (A := A) f us b i = f 1).
+ Proof.
+ induction i; intros;
+ try intuition (subst; cbv [ge_modulus' LetIn.Let_In]; break_if; tauto).
+ simpl; cbv [LetIn.Let_In].
+ break_if; apply IHi; tauto.
+ Qed.
+
+ Lemma ge_modulus_01 : forall us,
+ (ge_modulus us = 0 \/ ge_modulus us = 1).
+ Proof.
+ cbv [ge_modulus]; intros; apply ge_modulus'_01; tauto.
Qed.
Lemma ge_modulus'_true_digitwise : forall us,
length us = length limb_widths ->
- forall i, (i < length us)%nat -> ge_modulus' us true i = true ->
+ forall i, (i < length us)%nat -> ge_modulus' id us 1 i = 1 ->
forall j, (j <= i)%nat ->
nth_default 0 modulus_digits j <= nth_default 0 us j.
Proof.
induction i;
repeat match goal with
| |- _ => progress intros; simpl in *
- | |- _ => rewrite ge_modulus'_false in *
+ | |- _ => progress cbv [LetIn.Let_In] in *
+ | |- _ =>erewrite (ge_modulus'_0 (@id Z)) in *
| H : (?x <= 0)%nat |- _ => progress replace x with 0%nat in * by omega
- | H : (?b && true)%bool = true |- _ => let A:= fresh "H" in
- rewrite Bool.andb_true_r in H; case_eq b; intro A; rewrite A in H
- | H : ge_modulus' _ (?b && true)%bool _ = true |- _ => let A:= fresh "H" in
- rewrite Bool.andb_true_r in H; case_eq b; intro A; rewrite A in H
+ | |- _ => break_if
| |- _ => discriminate
| |- _ => solve [rewrite ?Z.leb_le, ?Z.eqb_eq in *; omega]
end.
@@ -316,35 +331,39 @@ Section ModulusComparisonProofs.
Lemma ge_modulus'_compare' : forall us, length us = length limb_widths -> bounded limb_widths us ->
forall i, (i < length limb_widths)%nat ->
- (ge_modulus' us true i = false <-> compare' us modulus_digits (S i) = Lt).
+ (ge_modulus' id us 1 i = 0 <-> compare' us modulus_digits (S i) = Lt).
Proof.
induction i;
repeat match goal with
- | |- _ => progress intros
- | |- _ => progress (simpl compare' in *)
+ | |- _ => progress (intros; cbv [LetIn.Let_In id])
+ | |- _ => progress (simpl compare' in * )
| |- _ => progress specialize_by omega
- | |- _ => progress rewrite ?Z.compare_eq_iff,
- ?Z.compare_gt_iff, ?Z.compare_lt_iff in *
- | |- appcontext[ge_modulus' _ _ 0] =>
+ | |- _ => (progress rewrite ?Z.compare_eq_iff,
+ ?Z.compare_gt_iff, ?Z.compare_lt_iff in * )
+ | |- appcontext[ge_modulus' _ _ _ 0] =>
cbv [ge_modulus']
- | |- appcontext[ge_modulus' _ _ (S _)] =>
- unfold ge_modulus'; fold ge_modulus'
+ | |- appcontext[ge_modulus' _ _ _ (S _)] =>
+ unfold ge_modulus'; fold (ge_modulus' (@id Z))
| |- _ => break_if
| |- _ => rewrite Nat.sub_0_r
- | |- _ => rewrite ge_modulus'_false
+ | |- _ => rewrite (ge_modulus'_0 (@id Z))
| |- _ => rewrite Bool.andb_true_r
| |- _ => rewrite Z.leb_compare; break_match
| |- _ => rewrite Z.eqb_compare; break_match
+ | |- _ => (rewrite Z.leb_le in * )
+ | |- _ => (rewrite Z.leb_gt in * )
+ | |- _ => (rewrite Z.eqb_eq in * )
+ | |- _ => (rewrite Z.eqb_neq in * )
| |- _ => split; (congruence || omega)
| |- _ => assumption
- end;
- pose proof (bounded_le_modulus_digits c_upper_bound us (S i));
- specialize_by (auto || omega); omega.
+ end;
+ pose proof (bounded_le_modulus_digits c_upper_bound us (S i));
+ specialize_by (auto || omega); split; (congruence || omega).
Qed.
Lemma ge_modulus_spec : forall u, length u = length limb_widths ->
bounded limb_widths u ->
- (ge_modulus u = false <-> 0 <= BaseSystem.decode base u < modulus).
+ (ge_modulus u = 0 <-> 0 <= BaseSystem.decode base u < modulus).
Proof.
cbv [ge_modulus]; intros.
assert (0 < length limb_widths)%nat
@@ -365,6 +384,8 @@ End ModulusComparisonProofs.
Section ConditionalSubtractModulusProofs.
Context `{prm :PseudoMersenneBaseParams}
+ (* B is machine integer width (e.g. 32, 64) *)
+ {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B)
(c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0)
(lt_1_length_limb_widths : (1 < length limb_widths)%nat).
Local Notation base := (base_from_limb_widths limb_widths).
@@ -386,24 +407,33 @@ Section ConditionalSubtractModulusProofs.
simpl; f_equal; auto using in_eq, in_cons.
Qed.
- Lemma map_land_max_ones : forall us, length us = length limb_widths ->
- bounded limb_widths us -> map (Z.land max_ones) us = us.
+ Lemma bounded_digit_fits : forall us,
+ length us = length limb_widths -> bounded limb_widths us ->
+ forall x, In x us -> 0 <= x < 2 ^ B.
Proof.
- intros; apply map_id_strong; intros ? HIn.
- rewrite Z.land_comm.
- cbv [max_ones].
- rewrite Z.land_ones by apply Z.le_fold_right_max_initial.
- apply Z.mod_small.
- apply In_nth with (d := 0) in HIn.
- destruct HIn as [i HIn]; destruct HIn; subst.
- rewrite bounded_iff in H0.
- specialize (H0 i).
+ intros.
+ let i := fresh "i" in
+ match goal with H : In ?x ?us, Hb : bounded _ _ |- _ =>
+ apply In_nth with (d := 0) in H; destruct H as [i [? ?] ];
+ rewrite bounded_iff in Hb; specialize (Hb i);
+ assert (2 ^ nth i limb_widths 0 <= 2 ^ B) by
+ (apply Z.pow_le_mono_r; try apply B_compat, nth_In; omega) end.
rewrite !nth_default_eq in *.
- split; try omega.
- eapply Z.lt_le_trans; try intuition eassumption.
- apply Z.pow_le_mono_r; try omega.
- apply Z.le_fold_right_max; eauto.
- apply nth_In. omega.
+ omega.
+ Qed.
+
+ Lemma map_land_max_ones : forall us,
+ length us = length limb_widths ->
+ bounded limb_widths us -> map (Z.land (Z.ones B)) us = us.
+ Proof.
+ repeat match goal with
+ | |- _ => progress intros
+ | |- _ => apply map_id_strong
+ | |- appcontext[Z.ones ?n &' ?x] => rewrite (Z.land_comm _ x);
+ rewrite Z.land_ones by omega
+ | |- _ => apply Z.mod_small
+ | |- _ => solve [eauto using bounded_digit_fits]
+ end.
Qed.
Lemma map_land_zero : forall us, map (Z.land 0) us = zeros (length us).
@@ -411,32 +441,35 @@ Section ConditionalSubtractModulusProofs.
induction us; boring.
Qed.
- Lemma conditional_subtract_modulus_spec : forall u cond,
+ Hint Rewrite @length_modulus_digits @length_zeros : distr_length.
+ Lemma conditional_subtract_modulus_spec : forall u cond
+ (cond_01 : cond = 0 \/ cond = 1),
length u = length limb_widths ->
- BaseSystem.decode base (conditional_subtract_modulus u cond) =
- BaseSystem.decode base u - (if cond then 1 else 0) * modulus.
+ BaseSystem.decode base (conditional_subtract_modulus B u cond) =
+ BaseSystem.decode base u - cond * modulus.
Proof.
repeat match goal with
- | |- _ => progress (cbv [conditional_subtract_modulus]; intros)
+ | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros)
+ | |- _ => destruct cond_01; subst
| |- _ => break_if
| |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits
| |- _ => rewrite map_land_zero
- | |- _ => rewrite map2_sub_eq
- by (rewrite ?length_modulus_digits, ?length_zeros; congruence)
+ | |- _ => rewrite map2_sub_eq by distr_length
| |- _ => rewrite sub_rep by auto
| |- _ => rewrite zeros_rep
| |- _ => rewrite decode_modulus_digits by auto
| |- _ => f_equal; ring
+ | |- _ => discriminate
end.
Qed.
Lemma conditional_subtract_modulus_preserves_bounded : forall u,
length u = length limb_widths ->
bounded limb_widths u ->
- bounded limb_widths (conditional_subtract_modulus u (ge_modulus u)).
+ bounded limb_widths (conditional_subtract_modulus B u (ge_modulus u)).
Proof.
repeat match goal with
- | |- _ => progress (cbv [conditional_subtract_modulus]; intros)
+ | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros)
| |- _ => unique pose proof bounded_modulus_digits
| |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits
| |- _ => rewrite map_land_zero
@@ -455,11 +488,12 @@ Section ConditionalSubtractModulusProofs.
| |- _ => omega
end.
cbv [ge_modulus] in Heqb.
+ rewrite Z.eqb_eq in *.
apply ge_modulus'_true_digitwise with (j := i) in Heqb; auto; omega.
Qed.
Lemma bounded_mul2_modulus : forall u, length u = length limb_widths ->
- bounded limb_widths u -> ge_modulus u = true ->
+ bounded limb_widths u -> ge_modulus u = 1 ->
modulus <= BaseSystem.decode base u < 2 * modulus.
Proof.
intros.
@@ -494,16 +528,16 @@ Section ConditionalSubtractModulusProofs.
Lemma conditional_subtract_lt_modulus : forall u,
length u = length limb_widths ->
bounded limb_widths u ->
- ge_modulus (conditional_subtract_modulus u (ge_modulus u)) = false.
+ ge_modulus (conditional_subtract_modulus B u (ge_modulus u)) = 0.
Proof.
intros.
- rewrite ge_modulus_spec by auto using length_conditional_subtract_modulus,
- conditional_subtract_modulus_preserves_bounded.
+ rewrite ge_modulus_spec by auto using length_conditional_subtract_modulus, conditional_subtract_modulus_preserves_bounded.
+ pose proof (ge_modulus_01 u) as Hgm01.
rewrite conditional_subtract_modulus_spec by auto.
- break_if.
- + pose proof (bounded_mul2_modulus u); specialize_by auto.
+ destruct Hgm01 as [Hgm0 | Hgm1]; rewrite ?Hgm0, ?Hgm1.
+ + apply ge_modulus_spec in Hgm0; auto.
omega.
- + apply ge_modulus_spec in Heqb; auto.
+ + pose proof (bounded_mul2_modulus u); specialize_by auto.
omega.
Qed.
End ConditionalSubtractModulusProofs. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index ff1dd87dd..d0af83e11 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -49,7 +49,6 @@ Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain.
Definition length_opt := Eval compute in length.
Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths.
Definition minus_opt := Eval compute in minus.
-Definition max_ones_opt := Eval compute in @max_ones.
Definition from_list_default_opt {A} := Eval compute in (@from_list_default A).
Definition sum_firstn_opt {A} := Eval compute in (@sum_firstn A).
Definition zeros_opt := Eval compute in (@zeros).
@@ -954,18 +953,16 @@ Section Canonicalization.
-> carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us)))
:= proj2_sig (carry_full_3_opt_cps_sig f us).
- Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : bool) :
- { g | g = conditional_subtract_modulus f cond}.
+ Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : Z) :
+ { g | g = conditional_subtract_modulus int_width f cond}.
Proof.
eexists.
cbv [conditional_subtract_modulus].
- match goal with |- appcontext[if cond then ?a else ?b] =>
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
- let RHSf := match (eval pattern (if cond then a else b) in RHS) with ?RHSf _ => RHSf end in
- change (LHS = Let_In (if cond then a else b) RHSf) end.
+ let RHSf := match (eval pattern (neg int_width cond) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (neg int_width cond) RHSf).
cbv [map2 map].
- change @max_ones with max_ones_opt.
reflexivity.
Defined.
@@ -973,39 +970,51 @@ Section Canonicalization.
:= Eval cbv [proj1_sig conditional_subtract_modulus_opt_sig] in proj1_sig (conditional_subtract_modulus_opt_sig f cond).
Definition conditional_subtract_modulus_opt_correct f cond
- : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus f cond
+ : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus int_width f cond
:= Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f cond).
- Definition ge_modulus_opt_sig (us : list Z) :
- { b : bool | b = ModularBaseSystemList.ge_modulus us }.
+ Lemma ge_modulus'_cps : forall {A} (f : Z -> A) (us : list Z) i b,
+ f (ge_modulus' id us b i) = ge_modulus' f us b i.
+ Proof.
+ induction i; intros; simpl; cbv [Let_In]; break_if; try reflexivity;
+ apply IHi.
+ Qed.
+(*
+ Definition ge_modulus'_opt_sig {A} (f : Z -> A) (us : list Z) b i :
+ { a : A | a = ModularBaseSystemList.ge_modulus' f us b i}.
Proof.
eexists.
- cbv beta iota delta [ge_modulus ge_modulus'].
+ cbv [ge_modulus ge_modulus'].
change length with length_opt.
change nth_default with @nth_default_opt.
change minus with minus_opt.
reflexivity.
Defined.
- Definition ge_modulus_opt us : bool
- := Eval cbv [proj1_sig ge_modulus_opt_sig] in proj1_sig (ge_modulus_opt_sig us).
+ Definition ge_modulus'_opt {A} f us b i : Z
+ := Eval cbv [proj1_sig ge_modulus'_opt_sig] in proj1_sig (@ge_modulus'_opt_sig A f us b i).
- Definition ge_modulus_opt_correct us :
- ge_modulus_opt us = ge_modulus us
- := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (ge_modulus_opt_sig us).
+ Definition ge_modulus'_opt_correct {A} f us :
+ @ge_modulus'_opt A f us b i = @ge_modulus' A f us b i
+ := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (@ge_modulus'_opt_sig A f us).
+*)
Definition freeze_opt_sig (us : list Z) :
{ b : list Z | length us = length limb_widths
- -> b = ModularBaseSystemList.freeze us }.
+ -> b = ModularBaseSystemList.freeze int_width us }.
Proof.
eexists.
cbv [ModularBaseSystemList.freeze].
rewrite <-conditional_subtract_modulus_opt_correct.
intros.
+ cbv [ge_modulus].
+ rewrite ge_modulus'_cps.
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in
rewrite <-carry_full_3_opt_cps_correct with (f := RHSf) by auto.
+ cbv [carry_full_3_opt_cps carry_full_opt_cps carry_sequence_opt_cps].
+ cbv [ge_modulus'].
cbv beta iota delta [ge_modulus ge_modulus'].
change length with length_opt.
change nth_default with @nth_default_opt.
@@ -1019,7 +1028,7 @@ Section Canonicalization.
Definition freeze_opt_correct us
: length us = length limb_widths
- -> freeze_opt us = ModularBaseSystemList.freeze us
+ -> freeze_opt us = ModularBaseSystemList.freeze int_width us
:= proj2_sig (freeze_opt_sig us).
End Canonicalization.
@@ -1061,15 +1070,16 @@ Section SquareRoots.
End SquareRoot3mod4.
Import Morphisms.
- Global Instance eqb_Proper : Proper (eq ==> eq ==> Logic.eq) ModularBaseSystem.eqb. Admitted.
+ Global Instance eqb_Proper : Proper (Logic.eq ==> eq ==> eq ==> Logic.eq) ModularBaseSystem.eqb. Admitted.
Section SquareRoot5mod8.
Context {ec : ExponentiationChain (modulus / 8 + 1)}.
Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)).
+ Context {int_width} (preconditions : freezePreconditions prm int_width).
Definition sqrt_5mod8_opt_sig (us : digits) :
{ vs : digits |
- eq vs (sqrt_5mod8 (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}.
+ eq vs (sqrt_5mod8 int_width (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}.
Proof.
eexists; cbv [sqrt_5mod8].
let LHS := match goal with |- eq ?LHS ?RHS => LHS end in
@@ -1083,7 +1093,7 @@ Section SquareRoots.
proj1_sig (sqrt_5mod8_opt_sig us).
Definition sqrt_5mod8_opt_correct us
- : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 _ _ chain chain_correct sqrt_m1 us)
+ : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us)
:= Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig us).
End SquareRoot5mod8.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index f8ad0969d..9a07e8ec0 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -833,7 +833,7 @@ Section CanonicalizationProofs.
then 0
else (2 ^ B) >> (limb_widths [pred n]))).
Local Notation minimal_rep u := ((bounded limb_widths (to_list (length limb_widths) u))
- /\ (ge_modulus (to_list _ u) = false)).
+ /\ (ge_modulus (to_list _ u) = 0)).
Lemma decode_bitwise_eq_iff : forall u v, minimal_rep u -> minimal_rep v ->
(fieldwise Logic.eq u v <->
@@ -896,7 +896,7 @@ Section CanonicalizationProofs.
Qed.
Lemma minimal_rep_freeze : forall u, initial_bounds u ->
- minimal_rep (freeze u).
+ minimal_rep (freeze B u).
Proof.
repeat match goal with
| |- _ => progress (cbv [freeze ModularBaseSystemList.freeze])
@@ -907,12 +907,12 @@ Section CanonicalizationProofs.
| |- _ => apply conditional_subtract_lt_modulus
| |- _ => apply conditional_subtract_modulus_preserves_bounded
| |- bounded _ (carry_full _) => apply bounded_iff
- | |- _ => solve [auto using lt_1_length_limb_widths, length_carry_full, length_to_list]
+ | |- _ => solve [auto using B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list]
end.
Qed.
Lemma freeze_decode : forall u,
- BaseSystem.decode base (to_list _ (freeze u)) mod modulus =
+ BaseSystem.decode base (to_list _ (freeze B u)) mod modulus =
BaseSystem.decode base (to_list _ u) mod modulus.
Proof.
repeat match goal with
@@ -922,7 +922,7 @@ Section CanonicalizationProofs.
| |- _ => rewrite Z.mod_add by (pose proof prime_modulus; prime_bound)
| |- _ => rewrite to_list_from_list
| |- _ => rewrite conditional_subtract_modulus_spec by
- auto using lt_1_length_limb_widths, length_carry_full, length_to_list
+ auto using B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list, ge_modulus_01
end.
rewrite !decode_mod_Fdecode by auto using length_carry_full, length_to_list.
cbv [carry_full].
@@ -941,7 +941,7 @@ Section CanonicalizationProofs.
rewrite from_list_to_list; reflexivity.
Qed.
- Lemma freeze_rep : forall u x, rep u x -> rep (freeze u) x.
+ Lemma freeze_rep : forall u x, rep u x -> rep (freeze B u) x.
Proof.
cbv [rep]; intros.
apply F.eq_to_Z_iff.
@@ -952,7 +952,7 @@ Section CanonicalizationProofs.
Lemma freeze_canonical : forall u v x y, rep u x -> rep v y ->
initial_bounds u ->
initial_bounds v ->
- (x = y <-> fieldwise Logic.eq (freeze u) (freeze v)).
+ (x = y <-> fieldwise Logic.eq (freeze B u) (freeze B v)).
Proof.
intros; apply bounded_canonical; auto using freeze_rep, minimal_rep_freeze.
Qed.
@@ -977,7 +977,7 @@ Section SquareRootProofs.
Lemma eqb_true_iff : forall u v x y,
bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds ->
- u ~= x -> v ~= y -> (x = y <-> eqb u v = true).
+ u ~= x -> v ~= y -> (x = y <-> eqb B u v = true).
Proof.
cbv [eqb freeze_input_bounds]. intros.
rewrite fieldwiseb_fieldwise by (apply Z.eqb_eq).
@@ -986,10 +986,10 @@ Section SquareRootProofs.
Lemma eqb_false_iff : forall u v x y,
bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds ->
- u ~= x -> v ~= y -> (x <> y <-> eqb u v = false).
+ u ~= x -> v ~= y -> (x <> y <-> eqb B u v = false).
Proof.
intros.
- case_eq (eqb u v).
+ case_eq (eqb B u v).
+ rewrite <-eqb_true_iff by eassumption; split; intros;
congruence || contradiction.
+ split; intros; auto.
@@ -1032,24 +1032,24 @@ Section SquareRootProofs.
Lemma sqrt_5mod8_correct : forall u x, u ~= x ->
bounded_by u pow_input_bounds -> bounded_by u freeze_input_bounds ->
- (sqrt_5mod8 mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x.
+ (sqrt_5mod8 B mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x.
Proof.
repeat match goal with
| |- _ => progress (cbv [sqrt_5mod8 F.sqrt_5mod8]; intros)
| |- _ => rewrite @F.pow_2_r in *
| |- _ => rewrite eqb_correct in * by eassumption
- | |- (if eqb ?a ?b then _ else _) ~=
+ | |- (if eqb _ ?a ?b then _ else _) ~=
(if dec (?c = _) then _ else _) =>
assert (a ~= c); rewrite !mul_equiv, pow_equiv in *;
repeat break_if
| |- _ => apply mul_rep; try reflexivity;
rewrite <-chain_correct; apply pow_rep; eassumption
| |- _ => rewrite <-chain_correct; apply pow_rep; eassumption
- | H : eqb ?a ?b = true |- _ =>
+ | H : eqb _ ?a ?b = true |- _ =>
rewrite <-(eqb_true_iff a b) in H
by (eassumption || rewrite <-mul_equiv, <-pow_equiv;
apply mul_bounded, pow_bounded; auto)
- | H : eqb ?a ?b = false |- _ =>
+ | H : eqb _ ?a ?b = false |- _ =>
rewrite <-(eqb_false_iff a b) in H
by (eassumption || rewrite <-mul_equiv, <-pow_equiv;
apply mul_bounded, pow_bounded; auto)