diff options
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystem.v | 14 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemList.v | 36 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemListProofs.v | 148 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 52 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 28 |
5 files changed, 167 insertions, 111 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 5c0d143c2..615bd832b 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -82,10 +82,10 @@ Section ModularBaseSystem. Definition eq (x y : digits) : Prop := decode x = decode y. - Definition freeze (x : digits) : digits := - from_list (freeze [[x]]) (length_freeze length_to_list). + Definition freeze B (x : digits) : digits := + from_list (freeze B [[x]]) (length_freeze length_to_list). - Definition eqb (x y : digits) : bool := fieldwiseb Z.eqb (freeze x) (freeze y). + Definition eqb B (x y : digits) : bool := fieldwiseb Z.eqb (freeze B x) (freeze B y). (* Note : both of the following square root definitions will produce garbage output if the input is not square mod [modulus]. The caller should either provably only call them with square input, @@ -97,10 +97,10 @@ Section ModularBaseSystem. (* sqrt_5mod8 is parameterized over implementation of [mul] and [pow] because it relies on bounds-checking for these two functions, which is much easier for simplified implementations than the more generalized ones defined here. *) - Definition sqrt_5mod8 mul_ pow_ (chain : list (nat * nat)) + Definition sqrt_5mod8 B mul_ pow_ (chain : list (nat * nat)) (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1)) (sqrt_minus1 x : digits) : digits := - let b := pow_ x chain in if eqb (mul_ b b) x then b else mul_ sqrt_minus1 b. + let b := pow_ x chain in if eqb B (mul_ b b) x then b else mul_ sqrt_minus1 b. Import Morphisms. Global Instance eq_Equivalence : Equivalence eq. @@ -108,6 +108,10 @@ Section ModularBaseSystem. split; cbv [eq]; repeat intro; congruence. Qed. + Definition select B (b : Z) (x y : digits) := + add (map (Z.land (neg B b)) x) + (map (Z.land (neg B (Z.lxor b 1))) x). + Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)). diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v index 6d0848151..e64ed5d0f 100644 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ b/src/ModularArithmetic/ModularBaseSystemList.v @@ -8,6 +8,7 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Conversion. @@ -51,28 +52,35 @@ Section Defs. Definition modulus_digits := encodeZ limb_widths modulus. - (* compute at compile time *) - Definition max_ones := Z.ones (fold_right Z.max 0 limb_widths). - (* Constant-time comparison with modulus; only works if all digits of [us] are less than 2 ^ their respective limb width. *) - Fixpoint ge_modulus' us acc i := - match i with - | O => andb (Z.leb (modulus_digits [0]) (us [0])) acc - | S i' => ge_modulus' us (andb (Z.eqb (modulus_digits [i]) (us [i])) acc) i' + Fixpoint ge_modulus' {A} (f : Z -> A) us (result : Z) i := + dlet r := result in + match i return A with + | O => dlet x := if Z.leb (modulus_digits [0]) (us [0]) + then r + else 0 in f x + | S i' => ge_modulus' f us + (if Z.eqb (modulus_digits [i]) (us [i]) + then r + else 0) i' end. - Definition ge_modulus us := ge_modulus' us true (length limb_widths - 1)%nat. + Definition ge_modulus us := ge_modulus' id us 1 (length limb_widths - 1)%nat. + + (* analagous to NEG assembly instruction on an integer that is 0 or 1: + neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.) + neg 0 = 0 *) + Definition neg (int_width : Z) (b : Z) := if b =? 1 then Z.ones int_width else 0. - Definition conditional_subtract_modulus (us : digits) (cond : bool) := - let and_term := if cond then max_ones else 0 in + Definition conditional_subtract_modulus int_width (us : digits) (cond : Z) := (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. Otherwise, it's all zeroes, and the subtractions do nothing. *) - map2 (fun x y => x - y) us (map (Z.land and_term) modulus_digits). + map2 (fun x y => x - y) us (map (Z.land (neg int_width cond)) modulus_digits). - Definition freeze (us : digits) : digits := + Definition freeze int_width (us : digits) : digits := let us' := carry_full (carry_full (carry_full us)) in - conditional_subtract_modulus us' (ge_modulus us'). + conditional_subtract_modulus int_width us' (ge_modulus us'). Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = @@ -86,4 +94,4 @@ Section Defs. limb_widths limb_widths_nonneg (Z.eq_le_incl _ _ (Z.eq_sym bits_eq)). -End Defs. +End Defs.
\ No newline at end of file diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v index 93b39e89a..23ec30cf7 100644 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v @@ -110,10 +110,11 @@ Section LengthProofs. rewrite encode'_spec, encode'_length; auto using encode'_length, limb_widths_nonneg, Nat.eq_le_incl, base_from_limb_widths_length. Qed. + Hint Rewrite @length_modulus_digits : distr_length. - Lemma length_conditional_subtract_modulus {u cond} : + Lemma length_conditional_subtract_modulus {int_width u cond} : length u = length limb_widths - -> length (conditional_subtract_modulus u cond) = length limb_widths. + -> length (conditional_subtract_modulus int_width u cond) = length limb_widths. Proof. intros; unfold conditional_subtract_modulus. rewrite map2_length, map_length, length_modulus_digits. @@ -121,9 +122,9 @@ Section LengthProofs. Qed. Hint Rewrite @length_conditional_subtract_modulus : distr_length. - Lemma length_freeze {u} : + Lemma length_freeze {int_width u} : length u = length limb_widths - -> length (freeze u) = length limb_widths. + -> length (freeze int_width u) = length limb_widths. Proof. intros; unfold freeze; repeat autorewrite with distr_length; congruence. Qed. @@ -285,27 +286,41 @@ Section ModulusComparisonProofs. reflexivity. Qed. - Lemma ge_modulus'_false : forall us i, - ge_modulus' us false i = false. + Lemma ge_modulus'_0 : forall {A} f us i, + ge_modulus' (A := A) f us 0 i = f 0. Proof. - induction i; intros; simpl; rewrite Bool.andb_false_r; auto. + induction i; intros; simpl; break_if; auto. + Qed. + + Lemma ge_modulus'_01 : forall {A} f us i b, + (b = 0 \/ b = 1) -> + (ge_modulus' (A := A) f us b i = f 0 \/ ge_modulus' (A := A) f us b i = f 1). + Proof. + induction i; intros; + try intuition (subst; cbv [ge_modulus' LetIn.Let_In]; break_if; tauto). + simpl; cbv [LetIn.Let_In]. + break_if; apply IHi; tauto. + Qed. + + Lemma ge_modulus_01 : forall us, + (ge_modulus us = 0 \/ ge_modulus us = 1). + Proof. + cbv [ge_modulus]; intros; apply ge_modulus'_01; tauto. Qed. Lemma ge_modulus'_true_digitwise : forall us, length us = length limb_widths -> - forall i, (i < length us)%nat -> ge_modulus' us true i = true -> + forall i, (i < length us)%nat -> ge_modulus' id us 1 i = 1 -> forall j, (j <= i)%nat -> nth_default 0 modulus_digits j <= nth_default 0 us j. Proof. induction i; repeat match goal with | |- _ => progress intros; simpl in * - | |- _ => rewrite ge_modulus'_false in * + | |- _ => progress cbv [LetIn.Let_In] in * + | |- _ =>erewrite (ge_modulus'_0 (@id Z)) in * | H : (?x <= 0)%nat |- _ => progress replace x with 0%nat in * by omega - | H : (?b && true)%bool = true |- _ => let A:= fresh "H" in - rewrite Bool.andb_true_r in H; case_eq b; intro A; rewrite A in H - | H : ge_modulus' _ (?b && true)%bool _ = true |- _ => let A:= fresh "H" in - rewrite Bool.andb_true_r in H; case_eq b; intro A; rewrite A in H + | |- _ => break_if | |- _ => discriminate | |- _ => solve [rewrite ?Z.leb_le, ?Z.eqb_eq in *; omega] end. @@ -316,35 +331,39 @@ Section ModulusComparisonProofs. Lemma ge_modulus'_compare' : forall us, length us = length limb_widths -> bounded limb_widths us -> forall i, (i < length limb_widths)%nat -> - (ge_modulus' us true i = false <-> compare' us modulus_digits (S i) = Lt). + (ge_modulus' id us 1 i = 0 <-> compare' us modulus_digits (S i) = Lt). Proof. induction i; repeat match goal with - | |- _ => progress intros - | |- _ => progress (simpl compare' in *) + | |- _ => progress (intros; cbv [LetIn.Let_In id]) + | |- _ => progress (simpl compare' in * ) | |- _ => progress specialize_by omega - | |- _ => progress rewrite ?Z.compare_eq_iff, - ?Z.compare_gt_iff, ?Z.compare_lt_iff in * - | |- appcontext[ge_modulus' _ _ 0] => + | |- _ => (progress rewrite ?Z.compare_eq_iff, + ?Z.compare_gt_iff, ?Z.compare_lt_iff in * ) + | |- appcontext[ge_modulus' _ _ _ 0] => cbv [ge_modulus'] - | |- appcontext[ge_modulus' _ _ (S _)] => - unfold ge_modulus'; fold ge_modulus' + | |- appcontext[ge_modulus' _ _ _ (S _)] => + unfold ge_modulus'; fold (ge_modulus' (@id Z)) | |- _ => break_if | |- _ => rewrite Nat.sub_0_r - | |- _ => rewrite ge_modulus'_false + | |- _ => rewrite (ge_modulus'_0 (@id Z)) | |- _ => rewrite Bool.andb_true_r | |- _ => rewrite Z.leb_compare; break_match | |- _ => rewrite Z.eqb_compare; break_match + | |- _ => (rewrite Z.leb_le in * ) + | |- _ => (rewrite Z.leb_gt in * ) + | |- _ => (rewrite Z.eqb_eq in * ) + | |- _ => (rewrite Z.eqb_neq in * ) | |- _ => split; (congruence || omega) | |- _ => assumption - end; - pose proof (bounded_le_modulus_digits c_upper_bound us (S i)); - specialize_by (auto || omega); omega. + end; + pose proof (bounded_le_modulus_digits c_upper_bound us (S i)); + specialize_by (auto || omega); split; (congruence || omega). Qed. Lemma ge_modulus_spec : forall u, length u = length limb_widths -> bounded limb_widths u -> - (ge_modulus u = false <-> 0 <= BaseSystem.decode base u < modulus). + (ge_modulus u = 0 <-> 0 <= BaseSystem.decode base u < modulus). Proof. cbv [ge_modulus]; intros. assert (0 < length limb_widths)%nat @@ -365,6 +384,8 @@ End ModulusComparisonProofs. Section ConditionalSubtractModulusProofs. Context `{prm :PseudoMersenneBaseParams} + (* B is machine integer width (e.g. 32, 64) *) + {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) (c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0) (lt_1_length_limb_widths : (1 < length limb_widths)%nat). Local Notation base := (base_from_limb_widths limb_widths). @@ -386,24 +407,33 @@ Section ConditionalSubtractModulusProofs. simpl; f_equal; auto using in_eq, in_cons. Qed. - Lemma map_land_max_ones : forall us, length us = length limb_widths -> - bounded limb_widths us -> map (Z.land max_ones) us = us. + Lemma bounded_digit_fits : forall us, + length us = length limb_widths -> bounded limb_widths us -> + forall x, In x us -> 0 <= x < 2 ^ B. Proof. - intros; apply map_id_strong; intros ? HIn. - rewrite Z.land_comm. - cbv [max_ones]. - rewrite Z.land_ones by apply Z.le_fold_right_max_initial. - apply Z.mod_small. - apply In_nth with (d := 0) in HIn. - destruct HIn as [i HIn]; destruct HIn; subst. - rewrite bounded_iff in H0. - specialize (H0 i). + intros. + let i := fresh "i" in + match goal with H : In ?x ?us, Hb : bounded _ _ |- _ => + apply In_nth with (d := 0) in H; destruct H as [i [? ?] ]; + rewrite bounded_iff in Hb; specialize (Hb i); + assert (2 ^ nth i limb_widths 0 <= 2 ^ B) by + (apply Z.pow_le_mono_r; try apply B_compat, nth_In; omega) end. rewrite !nth_default_eq in *. - split; try omega. - eapply Z.lt_le_trans; try intuition eassumption. - apply Z.pow_le_mono_r; try omega. - apply Z.le_fold_right_max; eauto. - apply nth_In. omega. + omega. + Qed. + + Lemma map_land_max_ones : forall us, + length us = length limb_widths -> + bounded limb_widths us -> map (Z.land (Z.ones B)) us = us. + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => apply map_id_strong + | |- appcontext[Z.ones ?n &' ?x] => rewrite (Z.land_comm _ x); + rewrite Z.land_ones by omega + | |- _ => apply Z.mod_small + | |- _ => solve [eauto using bounded_digit_fits] + end. Qed. Lemma map_land_zero : forall us, map (Z.land 0) us = zeros (length us). @@ -411,32 +441,35 @@ Section ConditionalSubtractModulusProofs. induction us; boring. Qed. - Lemma conditional_subtract_modulus_spec : forall u cond, + Hint Rewrite @length_modulus_digits @length_zeros : distr_length. + Lemma conditional_subtract_modulus_spec : forall u cond + (cond_01 : cond = 0 \/ cond = 1), length u = length limb_widths -> - BaseSystem.decode base (conditional_subtract_modulus u cond) = - BaseSystem.decode base u - (if cond then 1 else 0) * modulus. + BaseSystem.decode base (conditional_subtract_modulus B u cond) = + BaseSystem.decode base u - cond * modulus. Proof. repeat match goal with - | |- _ => progress (cbv [conditional_subtract_modulus]; intros) + | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros) + | |- _ => destruct cond_01; subst | |- _ => break_if | |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits | |- _ => rewrite map_land_zero - | |- _ => rewrite map2_sub_eq - by (rewrite ?length_modulus_digits, ?length_zeros; congruence) + | |- _ => rewrite map2_sub_eq by distr_length | |- _ => rewrite sub_rep by auto | |- _ => rewrite zeros_rep | |- _ => rewrite decode_modulus_digits by auto | |- _ => f_equal; ring + | |- _ => discriminate end. Qed. Lemma conditional_subtract_modulus_preserves_bounded : forall u, length u = length limb_widths -> bounded limb_widths u -> - bounded limb_widths (conditional_subtract_modulus u (ge_modulus u)). + bounded limb_widths (conditional_subtract_modulus B u (ge_modulus u)). Proof. repeat match goal with - | |- _ => progress (cbv [conditional_subtract_modulus]; intros) + | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros) | |- _ => unique pose proof bounded_modulus_digits | |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits | |- _ => rewrite map_land_zero @@ -455,11 +488,12 @@ Section ConditionalSubtractModulusProofs. | |- _ => omega end. cbv [ge_modulus] in Heqb. + rewrite Z.eqb_eq in *. apply ge_modulus'_true_digitwise with (j := i) in Heqb; auto; omega. Qed. Lemma bounded_mul2_modulus : forall u, length u = length limb_widths -> - bounded limb_widths u -> ge_modulus u = true -> + bounded limb_widths u -> ge_modulus u = 1 -> modulus <= BaseSystem.decode base u < 2 * modulus. Proof. intros. @@ -494,16 +528,16 @@ Section ConditionalSubtractModulusProofs. Lemma conditional_subtract_lt_modulus : forall u, length u = length limb_widths -> bounded limb_widths u -> - ge_modulus (conditional_subtract_modulus u (ge_modulus u)) = false. + ge_modulus (conditional_subtract_modulus B u (ge_modulus u)) = 0. Proof. intros. - rewrite ge_modulus_spec by auto using length_conditional_subtract_modulus, - conditional_subtract_modulus_preserves_bounded. + rewrite ge_modulus_spec by auto using length_conditional_subtract_modulus, conditional_subtract_modulus_preserves_bounded. + pose proof (ge_modulus_01 u) as Hgm01. rewrite conditional_subtract_modulus_spec by auto. - break_if. - + pose proof (bounded_mul2_modulus u); specialize_by auto. + destruct Hgm01 as [Hgm0 | Hgm1]; rewrite ?Hgm0, ?Hgm1. + + apply ge_modulus_spec in Hgm0; auto. omega. - + apply ge_modulus_spec in Heqb; auto. + + pose proof (bounded_mul2_modulus u); specialize_by auto. omega. Qed. End ConditionalSubtractModulusProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index ff1dd87dd..d0af83e11 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -49,7 +49,6 @@ Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain. Definition length_opt := Eval compute in length. Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths. Definition minus_opt := Eval compute in minus. -Definition max_ones_opt := Eval compute in @max_ones. Definition from_list_default_opt {A} := Eval compute in (@from_list_default A). Definition sum_firstn_opt {A} := Eval compute in (@sum_firstn A). Definition zeros_opt := Eval compute in (@zeros). @@ -954,18 +953,16 @@ Section Canonicalization. -> carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) := proj2_sig (carry_full_3_opt_cps_sig f us). - Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : bool) : - { g | g = conditional_subtract_modulus f cond}. + Definition conditional_subtract_modulus_opt_sig (f : list Z) (cond : Z) : + { g | g = conditional_subtract_modulus int_width f cond}. Proof. eexists. cbv [conditional_subtract_modulus]. - match goal with |- appcontext[if cond then ?a else ?b] => let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (if cond then a else b) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (if cond then a else b) RHSf) end. + let RHSf := match (eval pattern (neg int_width cond) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (neg int_width cond) RHSf). cbv [map2 map]. - change @max_ones with max_ones_opt. reflexivity. Defined. @@ -973,39 +970,51 @@ Section Canonicalization. := Eval cbv [proj1_sig conditional_subtract_modulus_opt_sig] in proj1_sig (conditional_subtract_modulus_opt_sig f cond). Definition conditional_subtract_modulus_opt_correct f cond - : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus f cond + : conditional_subtract_modulus_opt f cond = conditional_subtract_modulus int_width f cond := Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f cond). - Definition ge_modulus_opt_sig (us : list Z) : - { b : bool | b = ModularBaseSystemList.ge_modulus us }. + Lemma ge_modulus'_cps : forall {A} (f : Z -> A) (us : list Z) i b, + f (ge_modulus' id us b i) = ge_modulus' f us b i. + Proof. + induction i; intros; simpl; cbv [Let_In]; break_if; try reflexivity; + apply IHi. + Qed. +(* + Definition ge_modulus'_opt_sig {A} (f : Z -> A) (us : list Z) b i : + { a : A | a = ModularBaseSystemList.ge_modulus' f us b i}. Proof. eexists. - cbv beta iota delta [ge_modulus ge_modulus']. + cbv [ge_modulus ge_modulus']. change length with length_opt. change nth_default with @nth_default_opt. change minus with minus_opt. reflexivity. Defined. - Definition ge_modulus_opt us : bool - := Eval cbv [proj1_sig ge_modulus_opt_sig] in proj1_sig (ge_modulus_opt_sig us). + Definition ge_modulus'_opt {A} f us b i : Z + := Eval cbv [proj1_sig ge_modulus'_opt_sig] in proj1_sig (@ge_modulus'_opt_sig A f us b i). - Definition ge_modulus_opt_correct us : - ge_modulus_opt us = ge_modulus us - := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (ge_modulus_opt_sig us). + Definition ge_modulus'_opt_correct {A} f us : + @ge_modulus'_opt A f us b i = @ge_modulus' A f us b i + := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (@ge_modulus'_opt_sig A f us). +*) Definition freeze_opt_sig (us : list Z) : { b : list Z | length us = length limb_widths - -> b = ModularBaseSystemList.freeze us }. + -> b = ModularBaseSystemList.freeze int_width us }. Proof. eexists. cbv [ModularBaseSystemList.freeze]. rewrite <-conditional_subtract_modulus_opt_correct. intros. + cbv [ge_modulus]. + rewrite ge_modulus'_cps. let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in rewrite <-carry_full_3_opt_cps_correct with (f := RHSf) by auto. + cbv [carry_full_3_opt_cps carry_full_opt_cps carry_sequence_opt_cps]. + cbv [ge_modulus']. cbv beta iota delta [ge_modulus ge_modulus']. change length with length_opt. change nth_default with @nth_default_opt. @@ -1019,7 +1028,7 @@ Section Canonicalization. Definition freeze_opt_correct us : length us = length limb_widths - -> freeze_opt us = ModularBaseSystemList.freeze us + -> freeze_opt us = ModularBaseSystemList.freeze int_width us := proj2_sig (freeze_opt_sig us). End Canonicalization. @@ -1061,15 +1070,16 @@ Section SquareRoots. End SquareRoot3mod4. Import Morphisms. - Global Instance eqb_Proper : Proper (eq ==> eq ==> Logic.eq) ModularBaseSystem.eqb. Admitted. + Global Instance eqb_Proper : Proper (Logic.eq ==> eq ==> eq ==> Logic.eq) ModularBaseSystem.eqb. Admitted. Section SquareRoot5mod8. Context {ec : ExponentiationChain (modulus / 8 + 1)}. Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)). + Context {int_width} (preconditions : freezePreconditions prm int_width). Definition sqrt_5mod8_opt_sig (us : digits) : { vs : digits | - eq vs (sqrt_5mod8 (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}. + eq vs (sqrt_5mod8 int_width (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}. Proof. eexists; cbv [sqrt_5mod8]. let LHS := match goal with |- eq ?LHS ?RHS => LHS end in @@ -1083,7 +1093,7 @@ Section SquareRoots. proj1_sig (sqrt_5mod8_opt_sig us). Definition sqrt_5mod8_opt_correct us - : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 _ _ chain chain_correct sqrt_m1 us) + : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us) := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig us). End SquareRoot5mod8. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index f8ad0969d..9a07e8ec0 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -833,7 +833,7 @@ Section CanonicalizationProofs. then 0 else (2 ^ B) >> (limb_widths [pred n]))). Local Notation minimal_rep u := ((bounded limb_widths (to_list (length limb_widths) u)) - /\ (ge_modulus (to_list _ u) = false)). + /\ (ge_modulus (to_list _ u) = 0)). Lemma decode_bitwise_eq_iff : forall u v, minimal_rep u -> minimal_rep v -> (fieldwise Logic.eq u v <-> @@ -896,7 +896,7 @@ Section CanonicalizationProofs. Qed. Lemma minimal_rep_freeze : forall u, initial_bounds u -> - minimal_rep (freeze u). + minimal_rep (freeze B u). Proof. repeat match goal with | |- _ => progress (cbv [freeze ModularBaseSystemList.freeze]) @@ -907,12 +907,12 @@ Section CanonicalizationProofs. | |- _ => apply conditional_subtract_lt_modulus | |- _ => apply conditional_subtract_modulus_preserves_bounded | |- bounded _ (carry_full _) => apply bounded_iff - | |- _ => solve [auto using lt_1_length_limb_widths, length_carry_full, length_to_list] + | |- _ => solve [auto using B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list] end. Qed. Lemma freeze_decode : forall u, - BaseSystem.decode base (to_list _ (freeze u)) mod modulus = + BaseSystem.decode base (to_list _ (freeze B u)) mod modulus = BaseSystem.decode base (to_list _ u) mod modulus. Proof. repeat match goal with @@ -922,7 +922,7 @@ Section CanonicalizationProofs. | |- _ => rewrite Z.mod_add by (pose proof prime_modulus; prime_bound) | |- _ => rewrite to_list_from_list | |- _ => rewrite conditional_subtract_modulus_spec by - auto using lt_1_length_limb_widths, length_carry_full, length_to_list + auto using B_pos, B_compat, lt_1_length_limb_widths, length_carry_full, length_to_list, ge_modulus_01 end. rewrite !decode_mod_Fdecode by auto using length_carry_full, length_to_list. cbv [carry_full]. @@ -941,7 +941,7 @@ Section CanonicalizationProofs. rewrite from_list_to_list; reflexivity. Qed. - Lemma freeze_rep : forall u x, rep u x -> rep (freeze u) x. + Lemma freeze_rep : forall u x, rep u x -> rep (freeze B u) x. Proof. cbv [rep]; intros. apply F.eq_to_Z_iff. @@ -952,7 +952,7 @@ Section CanonicalizationProofs. Lemma freeze_canonical : forall u v x y, rep u x -> rep v y -> initial_bounds u -> initial_bounds v -> - (x = y <-> fieldwise Logic.eq (freeze u) (freeze v)). + (x = y <-> fieldwise Logic.eq (freeze B u) (freeze B v)). Proof. intros; apply bounded_canonical; auto using freeze_rep, minimal_rep_freeze. Qed. @@ -977,7 +977,7 @@ Section SquareRootProofs. Lemma eqb_true_iff : forall u v x y, bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x = y <-> eqb u v = true). + u ~= x -> v ~= y -> (x = y <-> eqb B u v = true). Proof. cbv [eqb freeze_input_bounds]. intros. rewrite fieldwiseb_fieldwise by (apply Z.eqb_eq). @@ -986,10 +986,10 @@ Section SquareRootProofs. Lemma eqb_false_iff : forall u v x y, bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x <> y <-> eqb u v = false). + u ~= x -> v ~= y -> (x <> y <-> eqb B u v = false). Proof. intros. - case_eq (eqb u v). + case_eq (eqb B u v). + rewrite <-eqb_true_iff by eassumption; split; intros; congruence || contradiction. + split; intros; auto. @@ -1032,24 +1032,24 @@ Section SquareRootProofs. Lemma sqrt_5mod8_correct : forall u x, u ~= x -> bounded_by u pow_input_bounds -> bounded_by u freeze_input_bounds -> - (sqrt_5mod8 mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. + (sqrt_5mod8 B mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. Proof. repeat match goal with | |- _ => progress (cbv [sqrt_5mod8 F.sqrt_5mod8]; intros) | |- _ => rewrite @F.pow_2_r in * | |- _ => rewrite eqb_correct in * by eassumption - | |- (if eqb ?a ?b then _ else _) ~= + | |- (if eqb _ ?a ?b then _ else _) ~= (if dec (?c = _) then _ else _) => assert (a ~= c); rewrite !mul_equiv, pow_equiv in *; repeat break_if | |- _ => apply mul_rep; try reflexivity; rewrite <-chain_correct; apply pow_rep; eassumption | |- _ => rewrite <-chain_correct; apply pow_rep; eassumption - | H : eqb ?a ?b = true |- _ => + | H : eqb _ ?a ?b = true |- _ => rewrite <-(eqb_true_iff a b) in H by (eassumption || rewrite <-mul_equiv, <-pow_equiv; apply mul_bounded, pow_bounded; auto) - | H : eqb ?a ?b = false |- _ => + | H : eqb _ ?a ?b = false |- _ => rewrite <-(eqb_false_iff a b) in H by (eassumption || rewrite <-mul_equiv, <-pow_equiv; apply mul_bounded, pow_bounded; auto) |