diff options
author | Andres Erbsen <andreser@mit.edu> | 2017-04-06 15:47:12 -0400 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2017-04-06 15:47:12 -0400 |
commit | f8cc64c7ca411828cac5cad2958959b0d779d683 (patch) | |
tree | 1c74a6bf6bc92522db7451e3c0dd748c57a2ece3 | |
parent | 33b1e92e1a71f284461e0c8d1d22b1d28b29bd7a (diff) |
start removing BaseSystem
31 files changed, 11 insertions, 5125 deletions
diff --git a/_CoqProject b/_CoqProject index de6c1cd42..d545cc6b2 100644 --- a/_CoqProject +++ b/_CoqProject @@ -3,8 +3,6 @@ Bedrock/Nomega.v Bedrock/Word.v src/Algebra.v -src/BaseSystem.v -src/BaseSystemProofs.v src/EdDSARepChange.v src/Karatsuba.v src/MontgomeryCurve.v @@ -14,7 +12,6 @@ src/MontgomeryXProofs.v src/MxDHRepChange.v src/NewBaseSystem.v src/SaturatedBaseSystem.v -src/Testbit.v src/Algebra/Field.v src/Algebra/Field_test.v src/Algebra/Group.v @@ -24,9 +21,14 @@ src/Algebra/Ring.v src/Algebra/ScalarMult.v src/BoundedArithmetic/ArchitectureToZLike.v src/BoundedArithmetic/ArchitectureToZLikeProofs.v +src/BoundedArithmetic/BaseSystem.v +src/BoundedArithmetic/BaseSystemProofs.v +src/BoundedArithmetic/CaseUtil.v src/BoundedArithmetic/Eta.v src/BoundedArithmetic/Interface.v src/BoundedArithmetic/InterfaceProofs.v +src/BoundedArithmetic/Pow2Base.v +src/BoundedArithmetic/Pow2BaseProofs.v src/BoundedArithmetic/StripCF.v src/BoundedArithmetic/Double/Core.v src/BoundedArithmetic/Double/Proofs/BitwiseOr.v @@ -46,24 +48,9 @@ src/CompleteEdwardsCurve/ExtendedCoordinates.v src/CompleteEdwardsCurve/Pre.v src/Experiments/ExtrHaskellNats.v src/Experiments/GenericFieldPow.v -src/ModularArithmetic/Conversion.v -src/ModularArithmetic/ExtPow2BaseMulProofs.v -src/ModularArithmetic/ExtendedBaseVector.v src/ModularArithmetic/ModularArithmeticTheorems.v -src/ModularArithmetic/ModularBaseSystem.v -src/ModularArithmetic/ModularBaseSystemList.v -src/ModularArithmetic/ModularBaseSystemListProofs.v -src/ModularArithmetic/ModularBaseSystemListZOperations.v -src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v -src/ModularArithmetic/ModularBaseSystemOpt.v -src/ModularArithmetic/ModularBaseSystemProofs.v -src/ModularArithmetic/ModularBaseSystemWord.v -src/ModularArithmetic/Pow2Base.v -src/ModularArithmetic/Pow2BaseProofs.v src/ModularArithmetic/Pre.v src/ModularArithmetic/PrimeFieldTheorems.v -src/ModularArithmetic/PseudoMersenneBaseParamProofs.v -src/ModularArithmetic/PseudoMersenneBaseParams.v src/ModularArithmetic/ZBounded.v src/ModularArithmetic/ZBoundedZ.v src/ModularArithmetic/BarrettReduction/Z.v @@ -201,8 +188,6 @@ src/Spec/ModularArithmetic.v src/Spec/MontgomeryCurve.v src/Spec/MxDH.v src/Spec/WeierstrassCurve.v -src/Specific/GF1305.v -src/Specific/GF25519.v src/Specific/IntegrationTest.v src/Specific/NewBaseSystemTest.v src/Specific/SC25519.v @@ -217,7 +202,6 @@ src/Util/AutoRewrite.v src/Util/Bool.v src/Util/BoundedWord.v src/Util/CPSUtil.v -src/Util/CaseUtil.v src/Util/ChangeInAll.v src/Util/Curry.v src/Util/Decidable.v diff --git a/src/BaseSystem.v b/src/BoundedArithmetic/BaseSystem.v index 5d48c0977..5d48c0977 100644 --- a/src/BaseSystem.v +++ b/src/BoundedArithmetic/BaseSystem.v diff --git a/src/BaseSystemProofs.v b/src/BoundedArithmetic/BaseSystemProofs.v index 409d8b7db..409d8b7db 100644 --- a/src/BaseSystemProofs.v +++ b/src/BoundedArithmetic/BaseSystemProofs.v diff --git a/src/Util/CaseUtil.v b/src/BoundedArithmetic/CaseUtil.v index 2d1ab6c58..2d1ab6c58 100644 --- a/src/Util/CaseUtil.v +++ b/src/BoundedArithmetic/CaseUtil.v diff --git a/src/BoundedArithmetic/Double/Core.v b/src/BoundedArithmetic/Double/Core.v index 82b450e76..6b6726f77 100644 --- a/src/BoundedArithmetic/Double/Core.v +++ b/src/BoundedArithmetic/Double/Core.v @@ -2,7 +2,7 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.InterfaceProofs. -Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.BoundedArithmetic.Pow2Base. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. diff --git a/src/BoundedArithmetic/Double/Proofs/Decode.v b/src/BoundedArithmetic/Double/Proofs/Decode.v index e3d57bdfc..a84e84acf 100644 --- a/src/BoundedArithmetic/Double/Proofs/Decode.v +++ b/src/BoundedArithmetic/Double/Proofs/Decode.v @@ -2,8 +2,8 @@ Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.InterfaceProofs. Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. +Require Import Crypto.BoundedArithmetic.Pow2Base. +Require Import Crypto.BoundedArithmetic.Pow2BaseProofs. Require Import Crypto.BoundedArithmetic.Double.Core. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ZUtil. diff --git a/src/ModularArithmetic/Pow2Base.v b/src/BoundedArithmetic/Pow2Base.v index 0175018f8..0175018f8 100644 --- a/src/ModularArithmetic/Pow2Base.v +++ b/src/BoundedArithmetic/Pow2Base.v diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/BoundedArithmetic/Pow2BaseProofs.v index 7a5bb4255..fd92db37d 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/BoundedArithmetic/Pow2BaseProofs.v @@ -8,11 +8,11 @@ Require Import Crypto.Util.Tactics.SpecializeBy. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Tactics.UniquePose. Require Import Crypto.Util.Tactics.RewriteHyp. -Require Import Crypto.ModularArithmetic.Pow2Base Crypto.BaseSystemProofs. +Require Import Crypto.BoundedArithmetic.Pow2Base. +Require Import Crypto.BoundedArithmetic.BaseSystemProofs. Require Import Crypto.Util.Notations. Require Export Crypto.Util.Bool. Require Export Crypto.Util.FixCoqMistakes. -Require Crypto.BaseSystem. Local Open Scope Z_scope. Create HintDb simpl_add_to_nth discriminated. diff --git a/src/ModularArithmetic/Conversion.v b/src/ModularArithmetic/Conversion.v deleted file mode 100644 index 3e8436f43..000000000 --- a/src/ModularArithmetic/Conversion.v +++ /dev/null @@ -1,318 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Lists.List. -Require Import Coq.funind.Recdef. -Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Tactics.SpecializeBy. -Require Import Crypto.Util.Tactics.SubstLet. -Require Import Crypto.Util.Tactics.Forward. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs Crypto.BaseSystemProofs. -Require Import Crypto.Util.Notations. -Require Export Crypto.Util.FixCoqMistakes. -Require Crypto.BaseSystem. -Local Open Scope Z_scope. - -Section ConversionHelper. - Local Hint Resolve in_eq in_cons. - - (* concatenates first n bits of a with all bits of b *) - Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n). - - Lemma concat_bits_spec : forall a b n i, 0 <= n -> - Z.testbit (concat_bits n a b) i = - if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n). - Proof. - repeat match goal with - | |- _ => progress cbv [concat_bits]; intros - | |- _ => progress autorewrite with Ztestbit - | |- _ => rewrite Z.testbit_pow2_mod by omega - | |- _ => rewrite Z.testbit_neg_r by omega - | |- _ => break_if - | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i) - | |- (?a || ?b)%bool = ?a => replace b with false - | |- _ => reflexivity - end. - Qed. - - Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits. - -End ConversionHelper. - -Section Conversion. - Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) - {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w). - Local Notation bitsIn lw := (sum_firstn lw (length lw)). - Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB). - Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)). - Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)). - Local Notation "u # i" := (nth_default 0 u i). - Local Hint Resolve in_eq in_cons nth_default_limb_widths_nonneg sum_firstn_limb_widths_nonneg Nat2Z.is_nonneg. - Local Opaque bounded. - - Function convert' inp i out - {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}:= - if Z_le_dec (bitsIn limb_widthsA) (Z.of_nat i) - then out - else - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof. - generalize limb_widthsA_nonneg; intros _. (* don't drop this from the proof in 8.4 *) - generalize limb_widthsB_nonneg; intros _. (* don't drop this from the proof in 8.4 *) - repeat match goal with - | |- _ => progress intros - | |- appcontext [bit_index (Z.of_nat ?i)] => - unique pose proof (Nat2Z.is_nonneg i) - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique pose proof (bit_index_not_done lw i) - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique assert (0 <= i < bitsIn lw -> i + ((lw # digit_index lw i) - bit_index lw i) <= bitsIn lw) by auto using rem_bits_in_digit_le_rem_bits - | |- _ => rewrite Z2Nat.id - | |- _ => rewrite Nat2Z.inj_add - | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt - | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l - | |- appcontext [Z.min ?a ?b] => unique assert (0 < Z.min a b) by (specialize_by lia; lia) - | |- _ => lia - end. - Defined. - - Definition convert'_invariant inp i out := - length out = length limb_widthsB - /\ bounded limb_widthsB out - /\ Z.of_nat i <= bitsIn limb_widthsA - /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false. - - Ltac subst_lia := subst_let; subst; lia. - - Lemma convert'_bounded_step : forall inp i out, - bounded limb_widthsB out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof using limb_widthsB_nonneg. - repeat match goal with - | |- _ => progress intros - | |- _ => progress autorewrite with Ztestbit - | |- _ => rewrite update_nth_nth_default_full - | |- _ => rewrite Z.testbit_pow2_mod - | |- _ => break_if - | |- _ => progress cbv [update_by_concat_bits]; - rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) - | |- bounded _ _ => apply pow2_mod_bounded_iff - | |- Z.pow2_mod _ _ = _ => apply Z.bits_inj' - | |- false = Z.testbit _ _ => symmetry - | x := _ |- Z.testbit ?x _ = _ => subst x - | |- Z.testbit _ _ = false => eapply testbit_bounded_high; eauto; lia - | |- _ => solve [auto] - | |- _ => subst_lia - end. - Qed. - - Lemma convert'_index_step : forall inp i out, - bounded limb_widthsB out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - Z.of_nat i < bitsIn limb_widthsA -> - Z.of_nat i + dist <= bitsIn limb_widthsA. - Proof using limb_widthsA_nonneg. - pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). - pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). - repeat match goal with - | |- _ => progress intros - | H : forall x : Z, In x ?lw -> x = ?y, H0 : 0 < ?y |- _ => - unique pose proof (uniform_limb_widths_nonneg H0 lw H) - | |- _ => progress specialize_by assumption - | H : _ /\ _ |- _ => destruct H - | |- _ => break_if - | |- _ => split - | a := digit_index _ ?i, H : forall x, 0 <= x < bitsIn _ -> _ |- _ => specialize (H i); forward H - | |- _ => subst_lia - | |- _ => apply bit_index_pos_iff; auto - | |- _ => apply Nat2Z.is_nonneg - end. - Qed. - - Lemma convert'_invariant_step : forall inp i out, - length inp = length limb_widthsA -> - bounded limb_widthsA inp -> - convert'_invariant inp i out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - Z.of_nat i < bitsIn limb_widthsA -> - convert'_invariant inp (i + Z.to_nat dist)%nat - (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof using Type*. - Time - repeat match goal with - | |- _ => progress intros; cbv [convert'_invariant] in * - | |- _ => progress autorewrite with Ztestbit - | H : forall x, In x ?lw -> 0 <= x |- appcontext[digit_index ?lw ?i] => - unique pose proof (digit_index_lt_length lw H i) - | |- _ => rewrite Nat2Z.inj_add - | |- _ => rewrite Z2Nat.id in * - | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ => - specialize (H n) - | H0 : ?n < ?i, H1 : ?n < ?i + ?d, - H : Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n |- _ = Z.testbit (decodeA _) ?n => - rewrite <-H - | H : _ /\ _ |- _ => destruct H - | |- _ => break_if - | |- _ => split - | |- _ => rewrite testbit_decode_full - | |- _ => rewrite update_nth_nth_default_full - | |- _ => rewrite nth_default_out_of_bounds by omega - | H : ~ (0 <= ?n ) |- appcontext[Z.testbit ?a ?n] => rewrite (Z.testbit_neg_r a n) by omega - | |- _ => progress cbv [update_by_concat_bits]; - rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) - | |- _ => solve [distr_length] - | |- _ => eapply convert'_bounded_step; solve [auto] - | |- _ => etransitivity; [ | eapply convert'_index_step]; subst_let; eauto; lia - | H : digit_index limb_widthsB ?i = digit_index limb_widthsB ?j |- _ => - unique assert (digit_index limb_widthsA i = digit_index limb_widthsA j) by - (symmetry; apply same_digit; assumption || lia); - pose proof (same_digit_bit_index_sub limb_widthsA j i) as X; - forward X; [ | lia | lia | lia ] - | d := digit_index ?lw ?j, - H : digit_index ?lw ?i <> ?d |- _ => - exfalso; apply H; symmetry; apply same_digit; assumption || subst_lia - | d := digit_index ?lw ?j, - H : digit_index ?lw ?i = ?d |- _ => - let X := fresh "H" in - ((pose proof (same_digit_bit_index_sub lw i j) as X; - forward X; [ subst_let | subst_lia | lia | lia ]) || - (pose proof (same_digit_bit_index_sub lw j i) as X; - forward X; [ subst_let | subst_lia | lia | lia ])) - | |- Z.testbit _ (bit_index ?lw _ - bit_index ?lw ?i + _) = false => - apply (@testbit_bounded_high limb_widthsA); auto; - rewrite (same_digit_bit_index_sub) by subst_lia; - rewrite <-(split_index_eqn limb_widthsA i) at 2 by lia - | |- ?lw # ?b <= ?a - ((sum_firstn ?lw ?b) + ?c) + ?c => replace (a - (sum_firstn lw b + c) + c) with (a - sum_firstn lw b) by ring; apply Z.le_add_le_sub_r - | |- (?lw # ?n) + sum_firstn ?lw ?n <= _ => - rewrite <-sum_firstn_succ_default; transitivity (bitsIn lw); [ | lia]; - apply sum_firstn_prefix_le; auto; lia - | |- _ => lia - | |- _ => assumption - | |- _ => solve [auto] - | |- _ => rewrite <-testbit_decode by (assumption || lia || auto); assumption - | |- _ => repeat (f_equal; try congruence); lia - end. - Qed. - - Lemma convert'_invariant_holds : forall inp i out, - length inp = length limb_widthsA -> - bounded limb_widthsA inp -> - convert'_invariant inp i out -> - convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out). - Proof using Type. - intros until 2; functional induction (convert' inp i out); - repeat match goal with - | |- _ => progress intros - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique pose proof (bit_index_not_done lw i) - | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) => - eapply convert'_invariant_step in H; solve [auto; specialize_by lia; lia] - | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in * - | H : _ /\ _ |- _ => destruct H - | |- _ => rewrite Z2Nat.id - | |- _ => split - | |- _ => assumption - | |- _ => lia - | |- _ => solve [eauto] - | |- _ => replace (bitsIn limb_widthsA) with (Z.of_nat i) by (apply Z.le_antisymm; assumption) - end. - Qed. - - Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)). - - Lemma convert_correct : forall us, length us = length limb_widthsA -> - bounded limb_widthsA us -> - decodeA us = decodeB (convert us). - Proof using Type. - repeat match goal with - | |- _ => progress intros - | |- _ => progress cbv [convert convert'_invariant] in * - | |- _ => progress change (Z.of_nat 0) with 0 in * - | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l - | H : length _ = length limb_widthsA |- _ => rewrite H - | |- _ => rewrite Z.testbit_neg_r by omega - | |- _ => rewrite nth_default_zeros - | |- _ => break_if - | |- _ => split - | H : _ /\ _ |- _ => destruct H - | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H - | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c) - | |- _ => apply testbit_decode_high - | |- _ => assumption - | |- _ => reflexivity - | |- _ => lia - | |- _ => solve [auto using sum_firstn_limb_widths_nonneg] - | |- _ => solve [apply nth_default_preserves_properties; auto; lia] - | |- _ => rewrite Z2Nat.id in * - | |- bounded _ _ => apply bounded_iff - | |- 0 < 2 ^ _ => zero_bounds - end. - Qed. - - (* This is part of convert'_invariant, but proving it separately strips preconditions *) - Lemma convert'_bounded : forall inp i out, - bounded limb_widthsB out -> - bounded limb_widthsB (convert' inp i out). - Proof using Type. - intros; functional induction (convert' inp i out); auto. - apply IHl. - apply convert'_bounded_step; auto. - clear IHl. - pose proof (bit_index_not_done limb_widthsA (Z.of_nat i)). - pose proof (bit_index_not_done limb_widthsB (Z.of_nat i)). - specialize_by lia. - lia. - Qed. - - Lemma convert_bounded : forall us, bounded limb_widthsB (convert us). - Proof using Type. - intros; apply convert'_bounded. - apply bounded_iff; intros. - rewrite nth_default_zeros. - split; zero_bounds. - Qed. - - (* This is part of convert'_invariant, but proving it separately strips preconditions *) - Lemma length_convert' : forall inp i out, - length (convert' inp i out) = length out. - Proof using Type. - intros; functional induction (convert' inp i out); distr_length. - Qed. - - Lemma length_convert : forall us, length (convert us) = length limb_widthsB. - Proof using Type. - cbv [convert]; intros. - rewrite length_convert', length_zeros. - reflexivity. - Qed. -End Conversion. diff --git a/src/ModularArithmetic/ExtPow2BaseMulProofs.v b/src/ModularArithmetic/ExtPow2BaseMulProofs.v deleted file mode 100644 index 38e9cf634..000000000 --- a/src/ModularArithmetic/ExtPow2BaseMulProofs.v +++ /dev/null @@ -1,34 +0,0 @@ -Require Import Coq.ZArith.ZArith Coq.Lists.List. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.Util.ListUtil. - -Local Open Scope Z_scope. - -Section ext_mul. - Context (limb_widths : list Z) - (limb_widths_nonnegative : forall x, In x limb_widths -> 0 <= x). - Local Notation k := (sum_firstn limb_widths (length limb_widths)). - Local Notation base := (base_from_limb_widths limb_widths). - Context (bv : BaseVector base) - (limb_widths_match_modulus : forall i j, - (i < length limb_widths)%nat -> - (j < length limb_widths)%nat -> - (i + j >= length limb_widths)%nat -> - let w_sum := sum_firstn limb_widths in - k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). - - Local Hint Resolve firstn_us_base_ext_base ExtBaseVector bv. - - Lemma mul_rep_extended : forall (us vs : BaseSystem.digits), - (length us <= length base)%nat -> - (length vs <= length base)%nat -> - (BaseSystem.decode base us) * (BaseSystem.decode base vs) = BaseSystem.decode (ext_base limb_widths) (BaseSystem.mul (ext_base limb_widths) us vs). - Proof using Type*. - intros; apply mul_rep_two_base; auto; - distr_length. - Qed. -End ext_mul. diff --git a/src/ModularArithmetic/ExtendedBaseVector.v b/src/ModularArithmetic/ExtendedBaseVector.v deleted file mode 100644 index 2236461ce..000000000 --- a/src/ModularArithmetic/ExtendedBaseVector.v +++ /dev/null @@ -1,200 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.BaseSystemProofs. -Require Crypto.BaseSystem. -Local Open Scope Z_scope. - -Section ExtendedBaseVector. - Context (limb_widths : list Z) - (limb_widths_nonnegative : forall x, In x limb_widths -> 0 <= x). - Local Notation k := (sum_firstn limb_widths (length limb_widths)). - Local Notation base := (base_from_limb_widths limb_widths). - - (* This section defines a new BaseVector that has double the length of the BaseVector - * used to construct [params]. The coefficients of the new vector are as follows: - * - * ext_base[i] = if (i < length base) then base[i] else 2^k * base[i] - * - * The purpose of this construction is that it allows us to multiply numbers expressed - * using [base], obtaining a number expressed using [ext_base]. (Numbers are "expressed" as - * vectors of digits; the value of a digit vector is obtained by doing a dot product with - * the base vector.) So if x, y are digit vectors: - * - * (x \dot base) * (y \dot base) = (z \dot ext_base) - * - * Then we can separate z into its first and second halves: - * - * (z \dot ext_base) = (z1 \dot base) + (2 ^ k) * (z2 \dot base) - * - * Now, if we want to reduce the product modulo 2 ^ k - c: - * - * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + (2 ^ k) * (z2 \dot base) mod (2^k-c) - * (z \dot ext_base) mod (2^k-c)= (z1 \dot base) + c * (z2 \dot base) mod (2^k-c) - * - * This sum may be short enough to express using base; if not, we can reduce again. - *) - Definition ext_limb_widths := limb_widths ++ limb_widths. - Definition ext_base := base_from_limb_widths ext_limb_widths. - Lemma ext_base_alt : ext_base = base ++ (map (Z.mul (2^k)) base). - Proof using Type*. - unfold ext_base, ext_limb_widths. - rewrite base_from_limb_widths_app by auto. - rewrite two_p_equiv. - reflexivity. - Qed. - - Lemma ext_base_positive : forall b, In b ext_base -> b > 0. - Proof using Type*. - apply base_positive; unfold ext_limb_widths. - intros ? H. apply in_app_or in H; destruct H; auto. - Qed. - - Lemma b0_1 : forall x, nth_default x base 0 = 1 -> nth_default x ext_base 0 = 1. - Proof using Type*. - intros. rewrite ext_base_alt, nth_default_app. - destruct base; assumption. - Qed. - - Lemma map_nth_default_base_high : forall n, (n < (length base))%nat -> - nth_default 0 (map (Z.mul (2 ^ k)) base) n = - (2 ^ k) * (nth_default 0 base n). - Proof using Type. - intros. - erewrite map_nth_default; auto. - Qed. - - Lemma ext_limb_widths_nonneg - (limb_widths_nonneg : forall w : Z, In w limb_widths -> 0 <= w) - : forall w : Z, In w ext_limb_widths -> 0 <= w. - Proof using Type*. - unfold ext_limb_widths; setoid_rewrite in_app_iff. - intros ? [?|?]; auto. - Qed. - - Lemma ext_limb_widths_upper_bound - : upper_bound ext_limb_widths = upper_bound limb_widths * upper_bound limb_widths. - Proof using Type*. - unfold ext_limb_widths. - autorewrite with push_upper_bound; reflexivity. - Qed. - - Section base_good. - Context (two_k_nonzero : 2^k <> 0) - (base_good : forall i j, (i+j < length base)%nat -> - let b := nth_default 0 base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat) - (limb_widths_match_modulus : forall i j, - (i < length limb_widths)%nat -> - (j < length limb_widths)%nat -> - (i + j >= length limb_widths)%nat -> - let w_sum := sum_firstn limb_widths in - k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). - - Lemma base_good_over_boundary - : forall (i : nat) - (l : (i < length base)%nat) - (j' : nat) - (Hj': (i + j' < length base)%nat), - 2 ^ k * (nth_default 0 base i * nth_default 0 base j') = - (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) - / (2 ^ k * nth_default 0 base (i + j')) * - (2 ^ k * nth_default 0 base (i + j')). - Proof using base_good two_k_nonzero. - clear limb_widths_match_modulus. - intros. - remember (nth_default 0 base) as b. - rewrite Zdiv_mult_cancel_l by (exact two_k_nonzero). - replace (b i * b j' / b (i + j')%nat * (2 ^ k * b (i + j')%nat)) - with ((2 ^ k * (b (i + j')%nat * (b i * b j' / b (i + j')%nat)))) by ring. - rewrite Z.mul_cancel_l by (exact two_k_nonzero). - replace (b (i + j')%nat * (b i * b j' / b (i + j')%nat)) - with ((b i * b j' / b (i + j')%nat) * b (i + j')%nat) by ring. - subst b. - apply (base_good i j'); omega. - Qed. - - Lemma ext_base_good : - forall i j, (i+j < length ext_base)%nat -> - let b := nth_default 0 ext_base in - let r := (b i * b j) / b (i+j)%nat in - b i * b j = r * b (i+j)%nat. - Proof using Type*. - intros. - subst b. subst r. - rewrite ext_base_alt in *. - rewrite app_length in H; rewrite map_length in H. - repeat rewrite nth_default_app. - repeat break_if; try omega. - { (* i < length base, j < length base, i + j < length base *) - auto using BaseSystem.base_good. - } { (* i < length base, j < length base, i + j >= length base *) - rewrite (map_nth_default _ _ _ _ 0) by omega. - apply base_matches_modulus; auto using limb_widths_nonnegative, limb_widths_match_modulus; - distr_length. - assumption. - } { (* i < length base, j >= length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (j - length base)%nat as j'. - replace (i + j - length base)%nat with (i + j')%nat by omega. - replace (nth_default 0 base i * (2 ^ k * nth_default 0 base j')) - with (2 ^ k * (nth_default 0 base i * nth_default 0 base j')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } { (* i >= length base, j < length base, i + j >= length base *) - do 2 rewrite map_nth_default_base_high by omega. - remember (i - length base)%nat as i'. - replace (i + j - length base)%nat with (j + i')%nat by omega. - replace (2 ^ k * nth_default 0 base i' * nth_default 0 base j) - with (2 ^ k * (nth_default 0 base j * nth_default 0 base i')) - by ring. - eapply base_good_over_boundary; eauto; omega. - } - Qed. - End base_good. - - Lemma extended_base_length: - length ext_base = (length base + length base)%nat. - Proof using Type. - clear limb_widths_nonnegative. - unfold ext_base, ext_limb_widths; autorewrite with distr_length; reflexivity. - Qed. - - Lemma firstn_us_base_ext_base : forall (us : BaseSystem.digits), - (length us <= length base)%nat - -> firstn (length us) base = firstn (length us) ext_base. - Proof using Type*. - rewrite ext_base_alt; intros. - rewrite firstn_app_inleft; auto; omega. - Qed. - - Lemma decode_short : forall (us : BaseSystem.digits), - (length us <= length base)%nat -> - BaseSystem.decode base us = BaseSystem.decode ext_base us. - Proof using Type*. auto using decode_short_initial, firstn_us_base_ext_base. Qed. - - Section BaseVector. - Context {bv : BaseSystem.BaseVector base} - (limb_widths_match_modulus : forall i j, - (i < length limb_widths)%nat -> - (j < length limb_widths)%nat -> - (i + j >= length limb_widths)%nat -> - let w_sum := sum_firstn limb_widths in - k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j). - - Instance ExtBaseVector : BaseSystem.BaseVector ext_base := - { base_positive := ext_base_positive; - b0_1 x := b0_1 x (BaseSystem.b0_1 _); - base_good := ext_base_good (two_sum_firstn_limb_widths_nonzero limb_widths_nonnegative _) BaseSystem.base_good limb_widths_match_modulus }. - End BaseVector. -End ExtendedBaseVector. - -Hint Rewrite @extended_base_length : distr_length. -Hint Resolve ext_limb_widths_nonneg : znonzero. -Hint Rewrite @ext_limb_widths_upper_bound using solve [ eauto with znonzero ] : push_upper_bound. -Hint Rewrite <- @ext_limb_widths_upper_bound using solve [ eauto with znonzero ] : pull_upper_bound. diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v deleted file mode 100644 index 0e09386f5..000000000 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ /dev/null @@ -1,124 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. -Require Import Crypto.ModularArithmetic.ModularBaseSystemList. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.AdditionChainExponentiation. -Require Import Crypto.Util.Notations. -Require Import Crypto.Tactics.VerdiTactics. -Local Open Scope Z_scope. - -Section ModularBaseSystem. - Context `{prm :PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation digits := (tuple Z (length limb_widths)). - Local Arguments to_list {_ _} _. - Local Arguments from_list {_ _} _ _. - Local Arguments length_to_list {_ _ _}. - Local Notation "[[ u ]]" := (to_list u). - - Definition decode (us : digits) : F modulus := decode [[us]]. - - Definition encode (x : F modulus) : digits := from_list (encode x) length_encode. - - Definition add (us vs : digits) : digits := from_list (add [[us]] [[vs]]) - (add_same_length _ _ _ length_to_list length_to_list). - - Definition mul (us vs : digits) : digits := from_list (mul [[us]] [[vs]]) - (length_mul length_to_list length_to_list). - - Definition sub (modulus_multiple: digits) - (modulus_multiple_correct : decode modulus_multiple = 0%F) - (us vs : digits) : digits := - from_list (sub [[modulus_multiple]] [[us]] [[vs]]) - (length_sub length_to_list length_to_list length_to_list). - - Definition zero : digits := encode (F.of_Z _ 0). - - Definition one : digits := encode (F.of_Z _ 1). - - Definition opp (modulus_multiple : digits) - (modulus_multiple_correct : decode modulus_multiple = 0%F) - (x : digits) : - digits := sub modulus_multiple modulus_multiple_correct zero x. - - Definition pow (x : digits) (chain : list (nat * nat)) : digits := - fold_chain one mul chain (x :: nil). - - Definition inv (chain : list (nat * nat)) - (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus - 2)) - (x : digits) : digits := pow x chain. - - (* Placeholder *) - Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)). - - Definition carry_ (carry_chain : list nat) (us : digits) : digits := - from_list (carry_sequence carry_chain [[us]]) (length_carry_sequence length_to_list). - - Definition carry_add (carry_chain : list nat) (us vs : digits) : digits := - carry_ carry_chain (add us vs). - Definition carry_mul (carry_chain : list nat) (us vs : digits) : digits := - carry_ carry_chain (mul us vs). - Definition carry_sub (carry_chain : list nat) (modulus_multiple: digits) - (modulus_multiple_correct : decode modulus_multiple = 0%F) - (us vs : digits) : digits := - carry_ carry_chain (sub modulus_multiple modulus_multiple_correct us vs). - Definition carry_opp (carry_chain : list nat) (modulus_multiple : digits) - (modulus_multiple_correct : decode modulus_multiple = 0%F) - (x : digits) : digits := - carry_sub carry_chain modulus_multiple modulus_multiple_correct zero x. - - Definition rep (us : digits) (x : F modulus) := decode us = x. - Local Notation "u ~= x" := (rep u x). - Local Hint Unfold rep. - - Definition eq (x y : digits) : Prop := decode x = decode y. - - Definition freeze int_width (x : digits) : digits := - from_list (freeze int_width [[x]]) (length_freeze length_to_list). - - Definition eqb int_width (x y : digits) : bool := fieldwiseb Z.eqb (freeze int_width x) (freeze int_width y). - - (* Note : both of the following square root definitions will produce garbage output if the input is - not square mod [modulus]. The caller should either provably only call them with square input, - or test that the output squared is in fact equal to the input and case split. *) - Definition sqrt_3mod4 (chain : list (nat * nat)) - (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1)) - (x : digits) : digits := pow x chain. - - Definition sqrt_5mod8 int_width powx powx_squared (chain : list (nat * nat)) - (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1)) - (sqrt_minus1 x : digits) : digits := - if eqb int_width powx_squared x then powx else mul sqrt_minus1 powx. - - Import Morphisms. - Global Instance eq_Equivalence : Equivalence eq. - Proof using Type. - split; cbv [eq]; repeat intro; congruence. - Qed. - - Definition select int_width (b : Z) (x y : digits) := - add (map (Z.land (neg int_width b)) x) - (map (Z.land (neg int_width (Z.lxor b 1))) x). - - Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) - (bits_eq : sum_firstn limb_widths (length limb_widths) = - sum_firstn target_widths (length target_widths)). - Local Notation target_digits := (tuple Z (length target_widths)). - - Definition pack (x : digits) : target_digits := - from_list (pack target_widths_nonneg bits_eq [[x]]) length_pack. - - Definition unpack (x : target_digits) : digits := - from_list (unpack target_widths_nonneg bits_eq [[x]]) length_unpack. - -End ModularBaseSystem. diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v deleted file mode 100644 index 8cce5481c..000000000 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ /dev/null @@ -1,90 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Notations. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Conversion. -Local Open Scope Z_scope. - -Section Defs. - Context `{prm :PseudoMersenneBaseParams} (modulus_multiple : digits). - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation "u [ i ]" := (nth_default 0 u i). - - Definition decode (us : digits) := F.of_Z modulus (BaseSystem.decode base us). - - Definition encode (x : F modulus) := encodeZ limb_widths (F.to_Z x). - - (* Converts from length of extended base to length of base by reduction modulo M.*) - Definition reduce (us : digits) : digits := - let high := skipn (length limb_widths) us in - let low := firstn (length limb_widths) us in - let wrap := map (Z.mul c) high in - BaseSystem.add low wrap. - - Definition mul (us vs : digits) := reduce (BaseSystem.mul (ext_base limb_widths) us vs). - - (* In order to subtract without underflowing, we add a multiple of the modulus first. *) - Definition sub (us vs : digits) := BaseSystem.sub (add modulus_multiple us) vs. - - (* [carry_and_reduce] multiplies the carried value by c, and, if carrying - from index [i] in a list [us], adds the value to the digit with index - [(S i) mod (length us)] *) - Definition carry_and_reduce := - carry_gen limb_widths (fun ci => c * ci) (fun Si => (Si mod (length limb_widths))%nat). - - Definition carry i : digits -> digits := - if eq_nat_dec i (pred (length limb_widths)) - then carry_and_reduce i - else carry_simple limb_widths i. - - Definition carry_sequence is (us : digits) : digits := fold_right carry us is. - - Definition carry_full : digits -> digits := carry_sequence (full_carry_chain limb_widths). - - Definition modulus_digits := encodeZ limb_widths modulus. - - (* Constant-time comparison with modulus; only works if all digits of [us] - are less than 2 ^ their respective limb width. *) - Fixpoint ge_modulus' {A} (f : Z -> A) us (result : Z) i := - dlet r := result in - match i return A with - | O => - dlet x := (cmovl (modulus_digits [0]) (us [0]) r 0) in f x - | S i' => - ge_modulus' f us (cmovne (modulus_digits [i]) (us [i]) r 0) i' - end. - - Definition ge_modulus us := ge_modulus' id us 1 (length limb_widths - 1)%nat. - - Definition conditional_subtract_modulus int_width (us : digits) (cond : Z) := - (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. - Otherwise, it's all zeroes, and the subtractions do nothing. *) - map2 (fun x y => x - y) us (map (Z.land (neg int_width cond)) modulus_digits). - - Definition freeze int_width (us : digits) : digits := - let us' := carry_full (carry_full (carry_full us)) in - conditional_subtract_modulus int_width us' (ge_modulus us'). - - Context {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) - (bits_eq : sum_firstn limb_widths (length limb_widths) = - sum_firstn target_widths (length target_widths)). - - Definition pack := @convert limb_widths limb_widths_nonneg - target_widths target_widths_nonneg - (Z.eq_le_incl _ _ bits_eq). - - Definition unpack := @convert target_widths target_widths_nonneg - limb_widths limb_widths_nonneg - (Z.eq_le_incl _ _ (Z.eq_sym bits_eq)). - -End Defs. diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v deleted file mode 100644 index 8d749dfdd..000000000 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ /dev/null @@ -1,539 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Lists.List. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.Conversion. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Tactics.SpecializeBy. -Require Import Crypto.Util.ListUtil. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Notations. - -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. -Require Import Crypto.ModularArithmetic.ModularBaseSystemList. -Local Open Scope Z_scope. - -Section LengthProofs. - Context `{prm :PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - - Lemma length_encode {x} : length (encode x) = length limb_widths. - Proof using Type. - cbv [encode encodeZ]; intros. - rewrite encode'_spec; - auto using encode'_length, limb_widths_nonneg, Nat.eq_le_incl, base_from_limb_widths_length. - Qed. - - Lemma length_reduce : forall us, - (length limb_widths <= length us <= length (ext_base limb_widths))%nat -> - (length (reduce us) = length limb_widths)%nat. - Proof using Type. - rewrite extended_base_length. - unfold reduce; intros. - rewrite add_length_exact. - pose proof (@base_from_limb_widths_length limb_widths). - rewrite map_length, firstn_length, skipn_length, Min.min_l, Max.max_l; - omega. - Qed. - - Lemma length_mul {u v} : - length u = length limb_widths - -> length v = length limb_widths - -> length (mul u v) = length limb_widths. - Proof using Type. - cbv [mul]; intros. - apply length_reduce. - destruct u; try congruence. - + rewrite @nil_length0 in *; omega. - + rewrite mul_length_exact, extended_base_length, base_from_limb_widths_length; try omega; - repeat match goal with - | |- _ => progress intros - | |- nth_default _ (ext_base _) 0 = 1 => apply b0_1 - | x := nth_default _ (ext_base _) |- _ => apply ext_base_good - | x := nth_default _ base |- _ => apply base_good - | x := nth_default _ base |- _ => apply limb_widths_good - | |- 2 ^ _ <> 0 => apply Z.pow_nonzero - | |- _ => solve [apply BaseSystem.b0_1] - | |- _ => solve [auto using limb_widths_nonneg, sum_firstn_limb_widths_nonneg, limb_widths_match_modulus] - | |- _ => omega - | |- _ => congruence - end. - Qed. - - Section Sub. - Context {mm : list Z} (mm_length : length mm = length limb_widths). - - Lemma length_sub {u v} : - length u = length limb_widths - -> length v = length limb_widths - -> length (sub mm u v) = length limb_widths. - Proof using Type*. - cbv [sub]; intros. - rewrite sub_length, add_length_exact. - repeat rewrite Max.max_r; omega. - Qed. - End Sub. - - Lemma length_carry_and_reduce {us}: forall i, length (carry_and_reduce i us) = length us. - Proof using Type. intros; unfold carry_and_reduce; autorewrite with distr_length; reflexivity. Qed. - Hint Rewrite @length_carry_and_reduce : distr_length. - - Lemma length_carry {u i} : - length u = length limb_widths - -> length (carry i u) = length limb_widths. - Proof using Type. intros; unfold carry; break_if; autorewrite with distr_length; omega. Qed. - Hint Rewrite @length_carry : distr_length. - - Lemma length_carry_sequence {u i} : - length u = length limb_widths - -> length (carry_sequence i u) = length limb_widths. - Proof using Type. - induction i; intros; unfold carry_sequence; - simpl; autorewrite with distr_length; auto. Qed. - Hint Rewrite @length_carry_sequence : distr_length. - - Lemma length_carry_full {u} : - length u = length limb_widths - -> length (carry_full u) = length limb_widths. - Proof using Type. intros; unfold carry_full; autorewrite with distr_length; congruence. Qed. - Hint Rewrite @length_carry_full : distr_length. - - Lemma length_modulus_digits : length modulus_digits = length limb_widths. - Proof using Type. - intros; unfold modulus_digits, encodeZ. - rewrite encode'_spec, encode'_length; - auto using encode'_length, limb_widths_nonneg, Nat.eq_le_incl, base_from_limb_widths_length. - Qed. - Hint Rewrite @length_modulus_digits : distr_length. - - Lemma length_conditional_subtract_modulus {int_width u cond} : - length u = length limb_widths - -> length (conditional_subtract_modulus int_width u cond) = length limb_widths. - Proof using Type. - intros; unfold conditional_subtract_modulus. - rewrite map2_length, map_length, length_modulus_digits. - apply Min.min_case; omega. - Qed. - Hint Rewrite @length_conditional_subtract_modulus : distr_length. - - Lemma length_freeze {int_width u} : - length u = length limb_widths - -> length (freeze int_width u) = length limb_widths. - Proof using Type. - intros; unfold freeze; repeat autorewrite with distr_length; congruence. - Qed. - - Lemma length_pack : forall {target_widths} - {target_widths_nonneg : forall x, In x target_widths -> 0 <= x} - {pf us}, - length (pack target_widths_nonneg pf us) = length target_widths. - Proof using Type. - cbv [pack]; intros. - apply length_convert. - Qed. - - Lemma length_unpack : forall {target_widths} - {target_widths_nonneg : forall x, In x target_widths -> 0 <= x} - {pf us}, - length (unpack target_widths_nonneg pf us) = length limb_widths. - Proof using Type. - cbv [pack]; intros. - apply length_convert. - Qed. - -End LengthProofs. - -Section ModulusDigitsProofs. - Context `{prm :PseudoMersenneBaseParams} - (c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0). - Local Notation base := (base_from_limb_widths limb_widths). - Local Hint Resolve sum_firstn_limb_widths_nonneg. - Local Hint Resolve limb_widths_nonneg. - - Lemma decode_modulus_digits : decode' base modulus_digits = modulus. - Proof using Type. - cbv [modulus_digits]. - pose proof c_pos. pose proof modulus_pos. - rewrite encodeZ_spec by eauto using limb_widths_nonnil, limb_widths_good. - apply Z.mod_small. - cbv [upper_bound]. fold k. - assert (Z.pos modulus = 2 ^ k - c) by (cbv [c]; ring). - omega. - Qed. - - Lemma bounded_modulus_digits : bounded limb_widths modulus_digits. - Proof using Type. - apply bounded_encodeZ; auto using limb_widths_nonneg. - pose proof modulus_pos; omega. - Qed. - - Lemma modulus_digits_ones : forall i, (0 < i < length limb_widths)%nat -> - nth_default 0 modulus_digits i = Z.ones (nth_default 0 limb_widths i). - Proof using Type*. - repeat match goal with - | |- _ => progress (cbv [BaseSystem.decode]; intros) - | |- _ => progress autorewrite with Ztestbit - | |- _ => unique pose proof c_pos - | |- _ => unique pose proof modulus_pos - | |- _ => unique assert (Z.pos modulus = 2 ^ k - c) by (cbv [c]; ring) - | |- _ => break_if - | |- _ => rewrite decode_modulus_digits - | |- _ => rewrite Z.testbit_pow2_mod - by eauto using nth_default_limb_widths_nonneg - | |- _ => rewrite Z.ones_spec by eauto using nth_default_limb_widths_nonneg - | |- _ => erewrite digit_select - by (eauto; apply bounded_encodeZ; eauto; omega) - | |- Z.testbit (2 ^ k - c) _ = _ => - rewrite Z.testbit_sub_pow2 by (try omega; cbv [k]; - pose proof (sum_firstn_prefix_le limb_widths (S i) (length limb_widths)); - specialize_by (eauto || omega); - rewrite sum_firstn_succ_default in *; split; zero_bounds; eauto) - | |- Z.pow2_mod _ _ = Z.ones _ => apply Z.bits_inj' - | |- Z.testbit (Z.pos modulus) ?i = true => transitivity (Z.testbit (2 ^ k - c) i) - | |- _ => congruence - end. - - replace (c - 1) with ((c - 1) mod 2 ^ nth_default 0 limb_widths 0) by (apply Z.mod_small; omega). - rewrite Z.mod_pow2_bits_high; auto. - pose proof (sum_firstn_prefix_le limb_widths 1 i). - specialize_by (eauto || omega). - rewrite !sum_firstn_succ_default, !sum_firstn_0 in *. - split; zero_bounds; eauto using nth_default_limb_widths_nonneg. - Qed. - - Lemma bounded_le_modulus_digits : forall us i, length us = length limb_widths -> - bounded limb_widths us -> (0 < i < length limb_widths)%nat -> - nth_default 0 us i <= nth_default 0 modulus_digits i. - Proof using Type*. - intros until 0; rewrite bounded_iff; intros. - rewrite modulus_digits_ones by omega. - specialize (H0 i). - rewrite Z.ones_equiv. - omega. - Qed. - -End ModulusDigitsProofs. - -Section ModulusComparisonProofs. - Context `{prm :PseudoMersenneBaseParams} - (c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0). - Local Notation base := (base_from_limb_widths limb_widths). - Local Hint Resolve sum_firstn_limb_widths_nonneg. - Local Hint Resolve limb_widths_nonneg. - - Fixpoint compare' us vs i := - match i with - | O => Eq - | S i' => if Z_eq_dec (nth_default 0 us i') (nth_default 0 vs i') - then compare' us vs i' - else Z.compare (nth_default 0 us i') (nth_default 0 vs i') - end. - - (* Lexicographically compare two vectors of equal length, starting from the END of the list - (in our context, this is the most significant end). NOT constant time. *) - Definition compare us vs := compare' us vs (length us). - - Lemma decode_firstn_compare' : forall us vs i, - (i <= length limb_widths)%nat -> - length us = length limb_widths -> bounded limb_widths us -> - length vs = length limb_widths -> bounded limb_widths vs -> - (Z.compare (decode' base (firstn i us)) (decode' base (firstn i vs)) - = compare' us vs i). - Proof using Type. - induction i; - repeat match goal with - | |- _ => progress intros - | |- _ => progress (simpl compare') - | |- _ => progress specialize_by (assumption || omega) - | |- _ => rewrite sum_firstn_0 - | |- _ => rewrite set_higher - | |- _ => rewrite nth_default_base by eauto - | |- _ => rewrite firstn_length, Min.min_l by omega - | |- _ => rewrite firstn_O - | |- _ => rewrite firstn_succ with (d := 0) by omega - | |- _ => rewrite Z.compare_add_shiftl by - (eauto || (rewrite decode_firstn_pow2_mod, Z.pow2_mod_pow2_mod, Z.min_id by - (eauto || omega); reflexivity)) - | |- appcontext[2 ^ ?x * ?y] => replace (2 ^ x * y) with (y << x) by - (rewrite (Z.mul_comm (2 ^ x)); apply Z.shiftl_mul_pow2; eauto) - | |- _ => tauto - | |- _ => split - | |- _ => break_if - end. - Qed. - - Lemma decode_compare' : forall us vs, - length us = length limb_widths -> bounded limb_widths us -> - length vs = length limb_widths -> bounded limb_widths vs -> - (Z.compare (decode' base us) (decode' base vs) - = compare' us vs (length limb_widths)). - Proof using Type. - intros. - rewrite <-decode_firstn_compare' by (auto || omega). - rewrite !firstn_all by auto. - reflexivity. - Qed. - - Lemma ge_modulus'_0 : forall {A} f us i, - ge_modulus' (A := A) f us 0 i = f 0. - Proof using Type. - induction i; intros; simpl; cbv [cmovne cmovl]; break_if; auto. - Qed. - - Lemma ge_modulus'_01 : forall {A} f us i b, - (b = 0 \/ b = 1) -> - (ge_modulus' (A := A) f us b i = f 0 \/ ge_modulus' (A := A) f us b i = f 1). - Proof using Type. - induction i; intros; - try intuition (subst; cbv [ge_modulus' LetIn.Let_In cmovl cmovne]; break_if; tauto). - simpl; cbv [LetIn.Let_In cmovl cmovne]. - break_if; apply IHi; tauto. - Qed. - - Lemma ge_modulus_01 : forall us, - (ge_modulus us = 0 \/ ge_modulus us = 1). - Proof using Type. - cbv [ge_modulus]; intros; apply ge_modulus'_01; tauto. - Qed. - - Lemma ge_modulus'_true_digitwise : forall us, - length us = length limb_widths -> - forall i, (i < length us)%nat -> ge_modulus' id us 1 i = 1 -> - forall j, (j <= i)%nat -> - nth_default 0 modulus_digits j <= nth_default 0 us j. - Proof using Type. - induction i; - repeat match goal with - | |- _ => progress intros; simpl in * - | |- _ => progress cbv [LetIn.Let_In cmovne cmovl] in * - | |- _ =>erewrite (ge_modulus'_0 (@id Z)) in * - | H : (?x <= 0)%nat |- _ => progress replace x with 0%nat in * by omega - | |- _ => break_if - | |- _ => discriminate - | |- _ => solve [rewrite ?Z.leb_le, ?Z.eqb_eq in *; omega] - end. - destruct (le_dec j i). - + apply IHi; auto; omega. - + replace j with (S i) in * by omega; rewrite Z.eqb_eq in *; try omega. - Qed. - - Lemma ge_modulus'_compare' : forall us, length us = length limb_widths -> bounded limb_widths us -> - forall i, (i < length limb_widths)%nat -> - (ge_modulus' id us 1 i = 0 <-> compare' us modulus_digits (S i) = Lt). - Proof using Type*. - induction i; - repeat match goal with - | |- _ => progress (intros; cbv [LetIn.Let_In id cmovne cmovl]) - | |- _ => progress (simpl compare' in * ) - | |- _ => progress specialize_by omega - | |- _ => (progress rewrite ?Z.compare_eq_iff, - ?Z.compare_gt_iff, ?Z.compare_lt_iff in * ) - | |- appcontext[ge_modulus' _ _ _ 0] => - cbv [ge_modulus'] - | |- appcontext[ge_modulus' _ _ _ (S _)] => - unfold ge_modulus'; fold (ge_modulus' (@id Z)) - | |- _ => break_if - | |- _ => rewrite Nat.sub_0_r - | |- _ => rewrite (ge_modulus'_0 (@id Z)) - | |- _ => rewrite Bool.andb_true_r - | |- _ => rewrite Z.leb_compare; break_match - | |- _ => rewrite Z.eqb_compare; break_match - | |- _ => (rewrite Z.leb_le in * ) - | |- _ => (rewrite Z.leb_gt in * ) - | |- _ => (rewrite Z.eqb_eq in * ) - | |- _ => (rewrite Z.eqb_neq in * ) - | |- _ => split; (congruence || omega) - | |- _ => assumption - end; - pose proof (bounded_le_modulus_digits c_upper_bound us (S i)); - specialize_by (auto || omega); split; (congruence || omega). - Qed. - - Lemma ge_modulus_spec : forall u, length u = length limb_widths -> - bounded limb_widths u -> - (ge_modulus u = 0 <-> 0 <= BaseSystem.decode base u < modulus). - Proof using Type*. - cbv [ge_modulus]; intros. - assert (0 < length limb_widths)%nat - by (pose proof limb_widths_nonnil; destruct limb_widths; - distr_length; omega || congruence). - rewrite ge_modulus'_compare' by (auto || omega). - replace (S (length limb_widths - 1)) with (length limb_widths) by omega. - rewrite <-decode_compare' - by (try (apply length_modulus_digits || apply bounded_encodeZ); eauto; - pose proof modulus_pos; omega). - rewrite Z.compare_lt_iff. - rewrite decode_modulus_digits. - repeat (split; intros; eauto using decode_nonneg). - cbv [BaseSystem.decode] in *. omega. - Qed. - -End ModulusComparisonProofs. - -Section ConditionalSubtractModulusProofs. - Context `{prm :PseudoMersenneBaseParams} - (* B is machine integer width (e.g. 32, 64) *) - {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) - (c_upper_bound : c - 1 < 2 ^ nth_default 0 limb_widths 0) - (lt_1_length_limb_widths : (1 < length limb_widths)%nat). - Local Notation base := (base_from_limb_widths limb_widths). - Local Hint Resolve sum_firstn_limb_widths_nonneg. - Local Hint Resolve limb_widths_nonneg. - Local Hint Resolve length_modulus_digits. - - Lemma map2_sub_eq : forall us vs, length us = length vs -> - map2 (fun x y => x - y) us vs = BaseSystem.sub us vs. - Proof using lt_1_length_limb_widths. - induction us; destruct vs; boring; try omega. - Qed. - - (* TODO : ListUtil *) - Lemma map_id_strong : forall {A} f (xs : list A), - (forall x, In x xs -> f x = x) -> map f xs = xs. - Proof using Type. - induction xs; intros; auto. - simpl; f_equal; auto using in_eq, in_cons. - Qed. - - Lemma bounded_digit_fits : forall us, - length us = length limb_widths -> bounded limb_widths us -> - forall x, In x us -> 0 <= x < 2 ^ B. - Proof using B_compat c_upper_bound lt_1_length_limb_widths. - intros. - let i := fresh "i" in - match goal with H : In ?x ?us, Hb : bounded _ _ |- _ => - apply In_nth with (d := 0) in H; destruct H as [i [? ?] ]; - rewrite bounded_iff in Hb; specialize (Hb i); - assert (2 ^ nth i limb_widths 0 <= 2 ^ B) by - (apply Z.pow_le_mono_r; try apply B_compat, nth_In; omega) end. - rewrite !nth_default_eq in *. - omega. - Qed. - - Lemma map_land_max_ones : forall us, - length us = length limb_widths -> - bounded limb_widths us -> map (Z.land (Z.ones B)) us = us. - Proof using Type*. - repeat match goal with - | |- _ => progress intros - | |- _ => apply map_id_strong - | |- appcontext[Z.ones ?n &' ?x] => rewrite (Z.land_comm _ x); - rewrite Z.land_ones by omega - | |- _ => apply Z.mod_small - | |- _ => solve [eauto using bounded_digit_fits] - end. - Qed. - - Lemma map_land_zero : forall us, map (Z.land 0) us = zeros (length us). - Proof using Type. - induction us; boring. - Qed. - - Hint Rewrite @length_modulus_digits @length_zeros : distr_length. - Lemma conditional_subtract_modulus_spec : forall u cond - (cond_01 : cond = 0 \/ cond = 1), - length u = length limb_widths -> - BaseSystem.decode base (conditional_subtract_modulus B u cond) = - BaseSystem.decode base u - cond * modulus. - Proof using Type*. - repeat match goal with - | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros) - | |- _ => destruct cond_01; subst - | |- _ => break_if - | |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits - | |- _ => rewrite map_land_zero - | |- _ => rewrite map2_sub_eq by distr_length - | |- _ => rewrite sub_rep by auto - | |- _ => rewrite zeros_rep - | |- _ => rewrite decode_modulus_digits by auto - | |- _ => f_equal; ring - | |- _ => discriminate - end. - Qed. - - Lemma conditional_subtract_modulus_preserves_bounded : forall u, - length u = length limb_widths -> - bounded limb_widths u -> - bounded limb_widths (conditional_subtract_modulus B u (ge_modulus u)). - Proof using Type*. - repeat match goal with - | |- _ => progress (cbv [conditional_subtract_modulus neg]; intros) - | |- _ => unique pose proof bounded_modulus_digits - | |- _ => rewrite map_land_max_ones by auto using bounded_modulus_digits - | |- _ => rewrite map_land_zero - | |- _ => rewrite length_modulus_digits in * - | |- _ => rewrite length_zeros in * - | |- _ => rewrite Min.min_l in * by omega - | |- _ => rewrite nth_default_zeros - | |- _ => rewrite nth_default_map2 with (d1 := 0) (d2 := 0) - | |- _ => break_if - | |- bounded _ _ => apply bounded_iff - | |- 0 <= 0 < _ => split; zero_bounds; eauto using nth_default_limb_widths_nonneg - end; - repeat match goal with - | H : bounded _ ?x |- appcontext[nth_default 0 ?x ?i] => - rewrite bounded_iff in H; specialize (H i) - | |- _ => omega - end. - cbv [ge_modulus] in Heqb. - rewrite Z.eqb_eq in *. - apply ge_modulus'_true_digitwise with (j := i) in Heqb; auto; omega. - Qed. - - Lemma bounded_mul2_modulus : forall u, length u = length limb_widths -> - bounded limb_widths u -> ge_modulus u = 1 -> - modulus <= BaseSystem.decode base u < 2 * modulus. - Proof using c_upper_bound lt_1_length_limb_widths. - intros. - pose proof (@decode_upper_bound _ limb_widths_nonneg u). - specialize_by auto. - cbv [upper_bound] in *. - fold k in *. - assert (Z.pos modulus = 2 ^ k - c) by (cbv [c]; ring). - destruct (Z_le_dec modulus (BaseSystem.decode base u)). - + split; try omega. - apply Z.lt_le_trans with (m := 2 ^ k); try omega. - assert (2 * c <= 2 ^ k); try omega. - transitivity (2 ^ (nth_default 0 limb_widths 0 + 1)); - try (rewrite Z.pow_add_r, ?Z.pow_1_r; - eauto using nth_default_limb_widths_nonneg; omega). - apply Z.pow_le_mono_r; try omega. - unfold k. - pose proof (sum_firstn_prefix_le limb_widths 2 (length limb_widths)). - specialize_by (eauto || omega). - etransitivity; try eassumption. - rewrite !sum_firstn_succ_default, sum_firstn_0. - assert (0 < nth_default 0 limb_widths 1); try omega. - apply limb_widths_pos. - rewrite nth_default_eq. - apply nth_In. - omega. - + assert (0 <= BaseSystem.decode base u < modulus) as Hlt_modulus by omega. - apply ge_modulus_spec in Hlt_modulus; auto. - congruence. - Qed. - - Lemma conditional_subtract_lt_modulus : forall u, - length u = length limb_widths -> - bounded limb_widths u -> - ge_modulus (conditional_subtract_modulus B u (ge_modulus u)) = 0. - Proof using Type*. - intros. - rewrite ge_modulus_spec by auto using length_conditional_subtract_modulus, conditional_subtract_modulus_preserves_bounded. - pose proof (ge_modulus_01 u) as Hgm01. - rewrite conditional_subtract_modulus_spec by auto. - destruct Hgm01 as [Hgm0 | Hgm1]; rewrite ?Hgm0, ?Hgm1. - + apply ge_modulus_spec in Hgm0; auto. - omega. - + pose proof (bounded_mul2_modulus u); specialize_by auto. - omega. - Qed. -End ConditionalSubtractModulusProofs. diff --git a/src/ModularArithmetic/ModularBaseSystemListZOperations.v b/src/ModularArithmetic/ModularBaseSystemListZOperations.v deleted file mode 100644 index 5b39f1066..000000000 --- a/src/ModularArithmetic/ModularBaseSystemListZOperations.v +++ /dev/null @@ -1,60 +0,0 @@ -(** * Definitions of some basic operations on ℤ used in ModularBaseSystemList *) -(** We separate these out so that we can depend on them in other files - without waiting for ModularBaseSystemList to build. *) -Require Import Coq.ZArith.ZArith. -Require Import Bedrock.Word. -Require Import Crypto.Util.FixedWordSizes. -Require Import Crypto.Util.Tuple. - -Definition cmovl (x y r1 r2 : Z) := if Z.leb x y then r1 else r2. -Definition cmovne (x y r1 r2 : Z) := if Z.eqb x y then r1 else r2. - -(* analagous to NEG assembly instruction on an integer that is 0 or 1: - neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.) - neg 0 = 0 *) -Definition neg (int_width : Z) (b : Z) := if Z.eqb b 1 then Z.ones int_width else 0%Z. - -Definition wcmovl_gen {sz} x y r1 r2 - := @ZToWord_gen sz (cmovl (@wordToZ_gen sz x) (@wordToZ_gen sz y) (@wordToZ_gen sz r1) (@wordToZ_gen sz r2)). -Definition wcmovne_gen {sz} x y r1 r2 - := @ZToWord_gen sz (cmovne (@wordToZ_gen sz x) (@wordToZ_gen sz y) (@wordToZ_gen sz r1) (@wordToZ_gen sz r2)). -Definition wneg_gen {sz} (int_width : Z) b - := @ZToWord_gen sz (neg int_width (@wordToZ_gen sz b)). - -Definition wcmovl32 x y r1 r2 := ZToWord32 (cmovl (word32ToZ x) (word32ToZ y) (word32ToZ r1) (word32ToZ r2)). -Definition wcmovne32 x y r1 r2 := ZToWord32 (cmovne (word32ToZ x) (word32ToZ y) (word32ToZ r1) (word32ToZ r2)). -Definition wneg32 (int_width : Z) b := ZToWord32 (neg int_width (word32ToZ b)). - -Definition wcmovl64 x y r1 r2 := ZToWord64 (cmovl (word64ToZ x) (word64ToZ y) (word64ToZ r1) (word64ToZ r2)). -Definition wcmovne64 x y r1 r2 := ZToWord64 (cmovne (word64ToZ x) (word64ToZ y) (word64ToZ r1) (word64ToZ r2)). -Definition wneg64 (int_width : Z) b := ZToWord64 (neg int_width (word64ToZ b)). - -Definition wcmovl128 x y r1 r2 := ZToWord128 (cmovl (word128ToZ x) (word128ToZ y) (word128ToZ r1) (word128ToZ r2)). -Definition wcmovne128 x y r1 r2 := ZToWord128 (cmovne (word128ToZ x) (word128ToZ y) (word128ToZ r1) (word128ToZ r2)). -Definition wneg128 (int_width : Z) b := ZToWord128 (neg int_width (word128ToZ b)). - -Definition wcmovl {logsz} - := word_case_dep (T:=fun _ word => word -> word -> word -> word -> word) - logsz wcmovl32 wcmovl64 wcmovl128 (fun _ => @wcmovl_gen _). -Definition wcmovne {logsz} - := word_case_dep (T:=fun _ word => word -> word -> word -> word -> word) - logsz wcmovne32 wcmovne64 wcmovne128 (fun _ => @wcmovne_gen _). -Definition wneg {logsz} - := word_case_dep (T:=fun _ word => Z -> word -> word) - logsz wneg32 wneg64 wneg128 (fun _ => @wneg_gen _). - -Hint Unfold wcmovl wcmovne wneg : fixed_size_constants. - -(** After unfolding [wneg], [wcmovl], [wcmovne], this tactic adjusts - the unfolded form to allow processing by - [FixedWordSizesEquality.fixed_size_op_to_word] *) -Ltac adjust_mbs_wops := - change wcmovl32 with (@wcmovl_gen 32) in *; - change wcmovl64 with (@wcmovl_gen 64) in *; - change wcmovl128 with (@wcmovl_gen 128) in *; - change wcmovne32 with (@wcmovne_gen 32) in *; - change wcmovne64 with (@wcmovne_gen 64) in *; - change wcmovne128 with (@wcmovne_gen 128) in *; - change wneg32 with (@wneg_gen 32) in *; - change wneg64 with (@wneg_gen 64) in *; - change wneg128 with (@wneg_gen 128) in *. diff --git a/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v b/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v deleted file mode 100644 index eb310f0f8..000000000 --- a/src/ModularArithmetic/ModularBaseSystemListZOperationsProofs.v +++ /dev/null @@ -1,29 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Tactics.BreakMatch. - -Local Open Scope Z_scope. - -Lemma neg_nonneg : forall x y, 0 <= x -> 0 <= ModularBaseSystemListZOperations.neg x y. -Proof. - unfold neg; intros; break_match; auto with zarith. -Qed. -Hint Resolve neg_nonneg : zarith. - -Lemma neg_upperbound : forall x y, 0 <= x -> ModularBaseSystemListZOperations.neg x y <= Z.ones x. -Proof. - unfold neg; intros; break_match; auto with zarith. -Qed. -Hint Resolve neg_upperbound : zarith. - -Lemma neg_range : forall x y, 0 <= x -> - 0 <= neg x y < 2 ^ x. -Proof. - intros. - split; auto using neg_nonneg. - eapply Z.le_lt_trans; eauto using neg_upperbound. - rewrite Z.ones_equiv. - omega. -Qed. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v deleted file mode 100644 index 0a240568b..000000000 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ /dev/null @@ -1,1094 +0,0 @@ -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.Conversion. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. -Require Import Crypto.ModularArithmetic.ModularBaseSystemList. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystem. -Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. -Require Import Coq.Lists.List. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.AdditionChainExponentiation. -Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil. -Import ListNotations. -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.QArith.QArith Coq.QArith.Qround. -Require Import Crypto.Tactics.VerdiTactics. -Require Export Crypto.Util.FixCoqMistakes. -Local Open Scope Z. - -(* Computed versions of some functions. *) - -Definition plus_opt := Eval compute in plus. - -Definition Z_add_opt := Eval compute in Z.add. -Definition Z_sub_opt := Eval compute in Z.sub. -Definition Z_mul_opt := Eval compute in Z.mul. -Definition Z_div_opt := Eval compute in Z.div. -Definition Z_pow_opt := Eval compute in Z.pow. -Definition Z_opp_opt := Eval compute in Z.opp. -Definition Z_min_opt := Eval compute in Z.min. -Definition Z_ones_opt := Eval compute in Z.ones. -Definition Z_of_nat_opt := Eval compute in Z.of_nat. -Definition Z_le_dec_opt := Eval compute in Z_le_dec. -Definition Z_lt_dec_opt := Eval compute in Z_lt_dec. -Definition Z_shiftl_opt := Eval compute in Z.shiftl. -Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by. - -Definition nth_default_opt {A} := Eval compute in @nth_default A. -Definition set_nth_opt {A} := Eval compute in @set_nth A. -Definition update_nth_opt {A} := Eval compute in @update_nth A. -Definition map_opt {A B} := Eval compute in @List.map A B. -Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain. -Definition length_opt := Eval compute in length. -Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths. -Definition minus_opt := Eval compute in minus. -Definition from_list_default_opt {A} := Eval compute in (@from_list_default A). -Definition sum_firstn_opt {A} := Eval compute in (@sum_firstn A). -Definition zeros_opt := Eval compute in (@zeros). -Definition bit_index_opt := Eval compute in bit_index. -Definition digit_index_opt := Eval compute in digit_index. - -(* Some automation that comes in handy when constructing base parameters *) -Ltac opt_step := - match goal with - | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] - => refine (_ : match e with nil => _ | _ => _ end = _); - destruct e - end. - -Definition limb_widths_from_len_step loop len k := - (fun i prev => - match i with - | O => nil - | S i' => let x := (if (Z.eqb ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0) - then (k * Z.of_nat (len - i + 1)) / Z.of_nat len - else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in - x - prev:: (loop i' x) - end). -Definition limb_widths_from_len len k := - (fix loop i prev := limb_widths_from_len_step loop len k i prev) len 0. - -Definition brute_force_indices0 lw : bool - := List.fold_right - andb true - (List.map - (fun i - => List.fold_right - andb true - (List.map - (fun j - => sum_firstn lw (i + j) <=? sum_firstn lw i + sum_firstn lw j) - (seq 0 (length lw - i)))) - (seq 0 (length lw))). - -Lemma brute_force_indices_correct0 lw - : brute_force_indices0 lw = true -> forall i j : nat, - (i + j < length lw)%nat -> sum_firstn lw (i + j) <= sum_firstn lw i + sum_firstn lw j. -Proof. - unfold brute_force_indices0. - progress repeat setoid_rewrite fold_right_andb_true_map_iff. - setoid_rewrite in_seq. - setoid_rewrite Z.leb_le. - eauto with omega. -Qed. - -Definition brute_force_indices1 lw : bool - := List.fold_right - andb true - (List.map - (fun i - => List.fold_right - andb true - (List.map - (fun j - => let w_sum := sum_firstn lw in - sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <=? w_sum i + w_sum j) - (seq (length lw - i) (length lw - (length lw - i))))) - (seq 1 (length lw - 1))). - -Lemma brute_force_indices_correct1 lw - : brute_force_indices1 lw = true -> forall i j : nat, - (i < length lw)%nat -> - (j < length lw)%nat -> - (i + j >= length lw)%nat -> - let w_sum := sum_firstn lw in - sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <= w_sum i + w_sum j. -Proof. - unfold brute_force_indices1. - progress repeat setoid_rewrite fold_right_andb_true_map_iff. - setoid_rewrite in_seq. - setoid_rewrite Z.leb_le. - eauto with omega. -Qed. - -Ltac construct_params prime_modulus len k := - let lwv := (eval cbv in (limb_widths_from_len len k)) in - let lw := fresh "lw" in pose lwv as lw; - eapply Build_PseudoMersenneBaseParams with (limb_widths := lw); - [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto) - | abstract (cbv; congruence) - | abstract (refine (@brute_force_indices_correct0 lw _); vm_cast_no_check (eq_refl true)) - | abstract apply prime_modulus - | abstract (cbv; congruence) - | abstract (refine (@brute_force_indices_correct1 lw _); vm_cast_no_check (eq_refl true))]. - -Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := - match limb_widths with - | nil => nil - | x :: tail => - 2 ^ (x + 1) - (2 * c) :: List.map (fun w => 2 ^ (w + 1) - 2) tail - end. - -Ltac compute_preconditions := - cbv; intros; repeat match goal with H : _ \/ _ |- _ => - destruct H; subst; [ congruence | ] end; (congruence || omega). - -Ltac subst_precondition := match goal with - | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H -end. - -Ltac kill_precondition H := - forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; - subst_precondition. - -Section Carries. - Context `{prm : PseudoMersenneBaseParams} - (* allows caller to precompute k and c *) - (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). - Local Notation base := (Pow2Base.base_from_limb_widths limb_widths). - Local Notation digits := (tuple Z (length limb_widths)). - - Definition carry_gen_opt_sig fc fi i us - : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat -> - d = carry_gen limb_widths fc fi i us}. - Proof. - eexists; intros. - cbv beta iota delta [carry_gen carry_single Z.pow2_mod]. - rewrite add_to_nth_set_nth. - change @nth_default with @nth_default_opt in *. - change @set_nth with @set_nth_opt in *. - change Z.ones with Z_ones_opt. - rewrite set_nth_nth_default by assumption. - rewrite <- @beq_nat_eq_nat_dec. - reflexivity. - Defined. - - Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in - proj1_sig (carry_gen_opt_sig fc fi i us). - - Definition carry_gen_opt_correct fc fi i us - : (0 <= fi (S (fi i)) < length us)%nat -> - carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us - := proj2_sig (carry_gen_opt_sig fc fi i us). - - Definition carry_opt_sig - (i : nat) (b : list Z) - : { d : list Z | (length b = length limb_widths) - -> (i < length limb_widths)%nat - -> d = carry i b }. - Proof. - eexists ; intros. - cbv [carry]. - rewrite <-pull_app_if_sumbool. - cbv beta delta - [carry carry_and_reduce carry_simple]. - lazymatch goal with - | [ |- _ = (if ?br then ?c else ?d) ] - => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y - end. - Focus 2. { - cbv zeta. - break_if; rewrite <-carry_gen_opt_correct by (omega || - (replace (length b) with (length limb_widths) by congruence; - apply Nat.mod_bound_pos; omega)); reflexivity. - } Unfocus. - rewrite c_subst. - rewrite <- @beq_nat_eq_nat_dec. - cbv [carry_gen_opt]. - reflexivity. - Defined. - - Definition carry_opt is us := Eval cbv [proj1_sig carry_opt_sig] in - proj1_sig (carry_opt_sig is us). - - Definition carry_opt_correct i us - : length us = length limb_widths - -> (i < length limb_widths)%nat - -> carry_opt i us = carry i us - := proj2_sig (carry_opt_sig i us). - - Definition carry_sequence_opt_sig (is : list nat) (us : list Z) - : { b : list Z | (length us = length limb_widths) - -> (forall i, In i is -> i < length limb_widths)%nat - -> b = carry_sequence is us }. - Proof. - eexists. intros H. - cbv [carry_sequence]. - transitivity (fold_right carry_opt us is). - Focus 2. - { induction is; [ reflexivity | ]. - simpl; rewrite IHis, carry_opt_correct. - - reflexivity. - - fold (carry_sequence is us). auto using length_carry_sequence. - - auto using in_eq. - - intros. auto using in_cons. - } - Unfocus. - reflexivity. - Defined. - - Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in - proj1_sig (carry_sequence_opt_sig is us). - - Definition carry_sequence_opt_correct is us - : (length us = length limb_widths) - -> (forall i, In i is -> i < length limb_widths)%nat - -> carry_sequence_opt is us = carry_sequence is us - := proj2_sig (carry_sequence_opt_sig is us). - - Definition carry_gen_opt_cps_sig - {T} fc fi - (i : nat) - (f : list Z -> T) - (b : list Z) - : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }. - Proof. - eexists. intros H. - rewrite <-carry_gen_opt_correct by assumption. - cbv beta iota delta [carry_gen_opt]. - match goal with |- appcontext[?a &' Z_ones_opt _] => - let LHS := match goal with |- ?LHS = ?RHS => LHS end in - let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (a) RHSf) end. - reflexivity. - Defined. - - Definition carry_gen_opt_cps {T} fc fi i f b - := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in - proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b). - - Definition carry_gen_opt_cps_correct {T} fc fi i f b : - (0 <= fi (S (fi i)) < length b)%nat -> - @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b) - := proj2_sig (carry_gen_opt_cps_sig fc fi i f b). - - Definition carry_opt_cps_sig - {T} - (i : nat) - (f : list Z -> T) - (b : list Z) - : { d : T | (length b = length limb_widths) - -> (i < length limb_widths)%nat - -> d = f (carry i b) }. - Proof. - eexists. intros. - cbv beta delta - [carry carry_and_reduce carry_simple]. - rewrite <-pull_app_if_sumbool. - lazymatch goal with - | [ |- _ = ?f (if ?br then ?c else ?d) ] - => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y - end. - Focus 2. { - cbv zeta. - break_if; rewrite <-carry_gen_opt_cps_correct by (omega || - (replace (length b) with (length limb_widths) by congruence; - apply Nat.mod_bound_pos; omega)); reflexivity. - } Unfocus. - rewrite c_subst. - rewrite <- @beq_nat_eq_nat_dec. - reflexivity. - Defined. - - Definition carry_opt_cps {T} i f b - := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b). - - Definition carry_opt_cps_correct {T} i f b : - (length b = length limb_widths) - -> (i < length limb_widths)%nat - -> @carry_opt_cps T i f b = f (carry i b) - := proj2_sig (carry_opt_cps_sig i f b). - - Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z) - (f : list Z -> T) - : { b : T | (length us = length limb_widths) - -> (forall i, In i is -> i < length limb_widths)%nat - -> b = f (carry_sequence is us) }. - Proof. - eexists. - cbv [carry_sequence]. - transitivity (fold_right carry_opt_cps f (List.rev is) us). - Focus 2. - { - assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. { - subst. intros. rewrite <- in_rev in *. auto. } - remember (rev is) as ris eqn:Heq. - rewrite <- (rev_involutive is), <- Heq in H0 |- *. - clear H0 Heq is. - rewrite fold_left_rev_right. - revert H. revert us; induction ris; [ reflexivity | ]; intros. - { simpl. - rewrite <- IHris; clear IHris; - [|intros; apply Hr; right; assumption|auto using length_carry]. - rewrite carry_opt_cps_correct; [reflexivity|congruence|]. - apply Hr; left; reflexivity. - } } - Unfocus. - cbv [carry_opt_cps]. - reflexivity. - Defined. - - Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) := - Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in - proj1_sig (carry_sequence_opt_cps_sig is us f). - - Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T) - : (length us = length limb_widths) - -> (forall i, In i is -> i < length limb_widths)%nat - -> carry_sequence_opt_cps is us f = f (carry_sequence is us) - := proj2_sig (carry_sequence_opt_cps_sig is us f). - - Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> - (i < length limb_widths)%nat. - Proof. - unfold Pow2Base.full_carry_chain; intros. - apply Pow2BaseProofs.make_chain_lt; auto. - Qed. - - Definition carry_full_opt_sig (us : list Z) : - { b : list Z | (length us = length limb_widths) - -> b = carry_full us }. - Proof. - eexists; cbv [carry_full]; intros. - match goal with |- ?LHS = ?RHS => change (LHS = id RHS) end. - rewrite <-carry_sequence_opt_cps_correct with (f := id) by (auto; apply full_carry_chain_bounds). - change @Pow2Base.full_carry_chain with full_carry_chain_opt. - reflexivity. - Defined. - - Definition carry_full_opt (us : list Z) : list Z - := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us). - - Definition carry_full_opt_correct us - : length us = length limb_widths - -> carry_full_opt us = carry_full us - := proj2_sig (carry_full_opt_sig us). - - Definition carry_full_opt_cps_sig - {T} - (f : list Z -> T) - (us : list Z) - : { d : T | length us = length limb_widths - -> d = f (carry_full us) }. - Proof. - eexists; intros. - rewrite <- carry_full_opt_correct by auto. - cbv beta iota delta [carry_full_opt]. - rewrite carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). - match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) => - change (LHS = (fun x => f (g x)) (carry_sequence is us)) end. - rewrite <-carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds). - reflexivity. - Defined. - - Definition carry_full_opt_cps {T} (f : list Z -> T) (us : list Z) : T - := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us). - - Definition carry_full_opt_cps_correct {T} us (f : list Z -> T) - : length us = length limb_widths - -> carry_full_opt_cps f us = f (carry_full us) - := proj2_sig (carry_full_opt_cps_sig f us). - -End Carries. - -Section CarryChain. - Context `{prm : PseudoMersenneBaseParams} {cc : CarryChain limb_widths}. - Local Notation digits := (tuple Z (length limb_widths)). - - Definition carry__opt_sig {T} (f : digits -> T) (us : digits) - : { x | x = f (carry_ carry_chain us) }. - Proof. - eexists. - cbv [carry_]. - rewrite <- from_list_default_eq with (d := 0%Z). - change @from_list_default with @from_list_default_opt. - erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list. - cbv [carry_sequence_opt_cps]. - reflexivity. - Defined. - - Definition carry__opt_cps {T} (f:digits -> T) (us : digits) : T - := Eval cbv [proj1_sig carry__opt_sig] in proj1_sig (carry__opt_sig f us). - - Definition carry__opt_cps_correct {T} (f:digits -> T) (us : digits) - : carry__opt_cps f us = f (carry_ carry_chain us) - := proj2_sig (carry__opt_sig f us). -End CarryChain. - -Section Addition. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}. - Local Notation digits := (tuple Z (length limb_widths)). - - Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }. - Proof. - eexists. - reflexivity. - Defined. - - Definition add_opt (us vs : digits) : digits - := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs). - - Definition add_opt_correct us vs - : add_opt us vs = add us vs - := proj2_sig (add_opt_sig us vs). - - Definition carry_add_opt_sig {T} (f:digits -> T) - (us vs : digits) : { x | x = f (carry_add carry_chain us vs) }. - Proof. - eexists. - cbv [carry_add]. - rewrite <-carry__opt_cps_correct, <-add_opt_correct. - cbv [carry_sequence_opt_cps carry__opt_cps add_opt add]. - rewrite to_list_from_list. - reflexivity. - Defined. - - Definition carry_add_opt_cps {T} (f:digits -> T) (us vs : digits) : T - := Eval cbv [proj1_sig carry_add_opt_sig] in proj1_sig (carry_add_opt_sig f us vs). - - Definition carry_add_opt_cps_correct {T} (f:digits -> T) (us vs : digits) - : carry_add_opt_cps f us vs = f (carry_add carry_chain us vs) - := proj2_sig (carry_add_opt_sig f us vs). - - Definition carry_add_opt := carry_add_opt_cps id. - - Definition carry_add_opt_correct (us vs : digits) - : carry_add_opt us vs = carry_add carry_chain us vs := - carry_add_opt_cps_correct id us vs. -End Addition. - -Section Subtraction. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}. - Local Notation digits := (tuple Z (length limb_widths)). - - Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff coeff_mod us vs }. - Proof. - eexists. - cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub]. - reflexivity. - Defined. - - Definition sub_opt (us vs : digits) : digits - := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs). - - Definition sub_opt_correct us vs - : sub_opt us vs = sub coeff coeff_mod us vs - := proj2_sig (sub_opt_sig us vs). - - Definition carry_sub_opt_sig {T} (f:digits -> T) - (us vs : digits) : { x | x = f (carry_sub carry_chain coeff coeff_mod us vs) }. - Proof. - eexists. - cbv [carry_sub]. - rewrite <-carry__opt_cps_correct, <-sub_opt_correct. - cbv [carry_sequence_opt_cps carry__opt_cps sub_opt]. - rewrite to_list_from_list. - reflexivity. - Defined. - - Definition carry_sub_opt_cps {T} (f:digits -> T) (us vs : digits) : T - := Eval cbv [proj1_sig carry_sub_opt_sig] in proj1_sig (carry_sub_opt_sig f us vs). - - Definition carry_sub_opt_cps_correct {T} (f:digits -> T) (us vs : digits) - : carry_sub_opt_cps f us vs = f (carry_sub carry_chain coeff coeff_mod us vs) - := proj2_sig (carry_sub_opt_sig f us vs). - - Definition carry_sub_opt := carry_sub_opt_cps id. - - Definition carry_sub_opt_correct (us vs : digits) - : carry_sub_opt us vs = carry_sub carry_chain coeff coeff_mod us vs := - carry_sub_opt_cps_correct id us vs. - - Definition opp_opt_sig (us : digits) : { b : digits | b = opp coeff coeff_mod us }. - Proof. - eexists. - cbv [opp]. - rewrite <-sub_opt_correct. - reflexivity. - Defined. - - Definition opp_opt (us : digits) : digits - := Eval cbv [proj1_sig opp_opt_sig] in proj1_sig (opp_opt_sig us). - - Definition opp_opt_correct us - : opp_opt us = opp coeff coeff_mod us - := proj2_sig (opp_opt_sig us). - - Definition carry_opp_opt_sig (us : digits) : { b : digits | b = carry_opp carry_chain coeff coeff_mod us }. - Proof. - eexists. - cbv [carry_opp]. - rewrite <-carry_sub_opt_correct. - reflexivity. - Defined. - - Definition carry_opp_opt (us : digits) : digits - := Eval cbv [proj1_sig carry_opp_opt_sig] in proj1_sig (carry_opp_opt_sig us). - - Definition carry_opp_opt_correct us - : carry_opp_opt us = carry_opp carry_chain coeff coeff_mod us - := proj2_sig (carry_opp_opt_sig us). - -End Subtraction. - -Section Multiplication. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths} - (* allows caller to precompute k and c *) - (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_). - Local Notation digits := (tuple Z (length limb_widths)). - - Definition mul_bi'_step - (mul_bi' : nat -> list Z -> list Z -> list Z) - (i : nat) (vsr : list Z) (bs : list Z) - : list Z - := match vsr with - | [] => [] - | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs - end. - - Definition mul_bi'_opt_step_sig - (mul_bi' : nat -> list Z -> list Z -> list Z) - (i : nat) (vsr : list Z) (bs : list Z) - : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }. - Proof. - eexists. - cbv [mul_bi'_step]. - opt_step. - { reflexivity. } - { cbv [crosscoef]. - change Z.div with Z_div_opt. - change Z.mul with Z_mul_opt at 2. - change @nth_default with @nth_default_opt. - reflexivity. } - Defined. - - Definition mul_bi'_opt_step - (mul_bi' : nat -> list Z -> list Z -> list Z) - (i : nat) (vsr : list Z) (bs : list Z) - : list Z - := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in - proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs). - - Fixpoint mul_bi'_opt - (i : nat) (vsr : list Z) (bs : list Z) {struct vsr} - : list Z - := mul_bi'_opt_step mul_bi'_opt i vsr bs. - - Definition mul_bi'_opt_correct - (i : nat) (vsr : list Z) (bs : list Z) - : mul_bi'_opt i vsr bs = mul_bi' bs i vsr. - Proof using Type. - revert i; induction vsr as [|vsr vsrs IHvsr]; intros. - { reflexivity. } - { simpl mul_bi'. - rewrite <- IHvsr; clear IHvsr. - unfold mul_bi'_opt, mul_bi'_opt_step. - apply f_equal2; [ | reflexivity ]. - cbv [crosscoef]. - change Z.div with Z_div_opt. - change Z.mul with Z_mul_opt at 2. - change @nth_default with @nth_default_opt. - reflexivity. } - Qed. - - Definition mul'_step - (mul' : list Z -> list Z -> list Z -> list Z) - (usr vs : list Z) (bs : list Z) - : list Z - := match usr with - | [] => [] - | u :: usr' => BaseSystem.add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs) - end. - - Lemma map_zeros : forall a n l, - List.map (Z.mul a) (zeros n ++ l) = zeros n ++ List.map (Z.mul a) l. - Proof using prm. - induction n; simpl; [ reflexivity | intros; apply f_equal2; [ omega | congruence ] ]. - Qed. - - Definition mul'_opt_step_sig - (mul' : list Z -> list Z -> list Z -> list Z) - (usr vs : list Z) (bs : list Z) - : { d : list Z | d = mul'_step mul' usr vs bs }. - Proof. - eexists. - cbv [mul'_step]. - match goal with - | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ] - => refine (_ : match e with nil => _ | _ => _ end = _); - destruct e - end. - { reflexivity. } - { cbv [mul_each mul_bi]. - rewrite <- mul_bi'_opt_correct. - rewrite map_zeros. - change @List.map with @map_opt. - cbv [zeros]. - reflexivity. } - Defined. - - Definition mul'_opt_step - (mul' : list Z -> list Z -> list Z -> list Z) - (usr vs : list Z) (bs : list Z) - : list Z - := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs). - - Fixpoint mul'_opt - (usr vs : list Z) (bs : list Z) - : list Z - := mul'_opt_step mul'_opt usr vs bs. - - Definition mul'_opt_correct - (usr vs : list Z) (bs : list Z) - : mul'_opt usr vs bs = mul' bs usr vs. - Proof using prm. - revert vs; induction usr as [|usr usrs IHusr]; intros. - { reflexivity. } - { simpl. - rewrite <- IHusr; clear IHusr. - apply f_equal2; [ | reflexivity ]. - cbv [mul_each mul_bi]. - rewrite map_zeros. - rewrite <- mul_bi'_opt_correct. - cbv [zeros]. - reflexivity. } - Qed. - - Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }. - Proof. - eexists. - cbv [mul ModularBaseSystemList.mul BaseSystem.mul mul_each mul_bi mul_bi' zeros reduce]. - rewrite <- from_list_default_eq with (d := 0%Z). - change (@from_list_default Z) with (@from_list_default_opt Z). - apply f_equal. - rewrite ext_base_alt by auto using limb_widths_pos with zarith. - rewrite <- mul'_opt_correct. - change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. - rewrite Z.map_shiftl by apply k_nonneg. - rewrite c_subst. - fold k; rewrite k_subst. - change @List.map with @map_opt. - change @Z.shiftl_by with @Z_shiftl_by_opt. - reflexivity. - Defined. - - Definition mul_opt (us vs : digits) : digits - := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs). - - Definition mul_opt_correct us vs - : mul_opt us vs = mul us vs - := proj2_sig (mul_opt_sig us vs). - - Definition carry_mul_opt_sig {T} (f:digits -> T) - (us vs : digits) : { x | x = f (carry_mul carry_chain us vs) }. - Proof. - eexists. - cbv [carry_mul]. - rewrite <-carry__opt_cps_correct, <-mul_opt_correct. - cbv [carry_sequence_opt_cps carry__opt_cps mul_opt]. - erewrite from_list_default_eq. - rewrite to_list_from_list. - reflexivity. - Grab Existential Variables. - rewrite mul'_opt_correct. - distr_length. - assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega). - rewrite Min.min_l; break_match; try omega. - rewrite Max.max_l; omega. - Defined. - - Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T - := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig f us vs). - - Definition carry_mul_opt_cps_correct {T} (f:digits -> T) (us vs : digits) - : carry_mul_opt_cps f us vs = f (carry_mul carry_chain us vs) - := proj2_sig (carry_mul_opt_sig f us vs). - - Definition carry_mul_opt := carry_mul_opt_cps id. - - Definition carry_mul_opt_correct (us vs : digits) - : carry_mul_opt us vs = carry_mul carry_chain us vs := - carry_mul_opt_cps_correct id us vs. - -End Multiplication. - -Import Morphisms. -Global Instance Proper_fold_chain {T} {Teq} {Teq_Equivalence : Equivalence Teq} - : Proper (Logic.eq - ==> (fun f g => forall x1 x2 y1 y2 : T, Teq x1 x2 -> Teq y1 y2 -> Teq (f x1 y1) (g x2 y2)) - ==> Logic.eq - ==> SetoidList.eqlistA Teq - ==> Teq) fold_chain. -Proof. - do 9 intro. - subst; induction y1; repeat intro; - unfold fold_chain; fold @fold_chain. - + inversion H; assumption || reflexivity. - + destruct a. - apply IHy1. - econstructor; try assumption. - apply H0; eapply Proper_nth_default; eauto; reflexivity. -Qed. - -Section PowInv. - Context `{prm : PseudoMersenneBaseParams} - (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) - {cc : CarryChain limb_widths}. - Local Notation digits := (tuple Z (length limb_widths)). - Context (one_ : digits) (one_subst : one = one_). - - Fixpoint fold_chain_opt {T} (id : T) op chain acc := - match chain with - | [] => match acc with - | [] => id - | ret :: _ => ret - end - | (i, j) :: chain' => - Let_In (op (nth_default id acc i) (nth_default id acc j)) - (fun ijx => fold_chain_opt id op chain' (ijx :: acc)) - end. - - Lemma fold_chain_opt_correct : forall {T} (id : T) op chain acc, - fold_chain_opt id op chain acc = fold_chain id op chain acc. - Proof using Type. - reflexivity. - Qed. - - Definition pow_opt_sig x chain : - {y | eq y (ModularBaseSystem.pow x chain)}. - Proof. - eexists. - cbv beta iota delta [ModularBaseSystem.pow]. - transitivity (fold_chain one_ (carry_mul_opt k_ c_) chain [x]). - Focus 2. { - apply Proper_fold_chain; auto; try reflexivity. - cbv [eq]; intros. - rewrite carry_mul_opt_correct by assumption. - rewrite carry_mul_rep, mul_rep by reflexivity. - congruence. - } Unfocus. - rewrite <-fold_chain_opt_correct. - reflexivity. - Defined. - - Definition pow_opt x chain : digits - := Eval cbv [proj1_sig pow_opt_sig] in (proj1_sig (pow_opt_sig x chain)). - - Definition pow_opt_correct x chain - : eq (pow_opt x chain) (ModularBaseSystem.pow x chain) - := Eval cbv [proj2_sig pow_opt_sig] in (proj2_sig (pow_opt_sig x chain)). - - Context {ec : ExponentiationChain (modulus - 2)}. - - Definition inv_opt_sig x: - {y | eq y (inv chain chain_correct x)}. - Proof. - eexists. - cbv [inv]. - rewrite <-pow_opt_correct. - reflexivity. - Defined. - - Definition inv_opt x : digits - := Eval cbv [proj1_sig inv_opt_sig] in (proj1_sig (inv_opt_sig x)). - - Definition inv_opt_correct x - : eq (inv_opt x) (inv chain chain_correct x) - := Eval cbv [proj2_sig inv_opt_sig] in (proj2_sig (inv_opt_sig x)). -End PowInv. - -Section Conversion. - - Definition convert'_opt_sig {lwA lwB} - (nonnegA : forall x, In x lwA -> 0 <= x) - (nonnegB : forall x, In x lwB -> 0 <= x) - bits_fit inp i out : - { y | y = convert' nonnegA nonnegB bits_fit inp i out}. - Proof. - eexists. - rewrite convert'_equation. - change sum_firstn with @sum_firstn_opt. - change length with length_opt. - change Z_le_dec with Z_le_dec_opt. - change Z.of_nat with Z_of_nat_opt. - change digit_index with digit_index_opt. - change bit_index with bit_index_opt. - change Z.min with Z_min_opt. - change (nth_default 0 lwA) with (nth_default_opt 0 lwA). - change (nth_default 0 lwB) with (nth_default_opt 0 lwB). - cbv [update_by_concat_bits concat_bits Z.pow2_mod]. - change Z.ones with Z_ones_opt. - change @update_nth with @update_nth_opt. - change plus with plus_opt. - change Z.sub with Z_sub_opt. - reflexivity. - Defined. - - Definition convert'_opt {lwA lwB} - (nonnegA : forall x, In x lwA -> 0 <= x) - (nonnegB : forall x, In x lwB -> 0 <= x) - bits_fit inp i out := - Eval cbv [proj1_sig convert'_opt_sig] in - proj1_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out). - - Definition convert'_opt_correct {lwA lwB} - (nonnegA : forall x, In x lwA -> 0 <= x) - (nonnegB : forall x, In x lwB -> 0 <= x) - bits_fit inp i out : - convert'_opt nonnegA nonnegB bits_fit inp i out = convert' nonnegA nonnegB bits_fit inp i out := - Eval cbv [proj2_sig convert'_opt_sig] in - proj2_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out). - - Context {modulus} (prm : PseudoMersenneBaseParams modulus) - {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)). - Local Notation digits := (tuple Z (length limb_widths)). - Local Notation target_digits := (tuple Z (length target_widths)). - - Definition pack_opt_sig (x : digits) : { y | y = pack target_widths_nonneg bits_eq x}. - Proof. - eexists. - cbv [pack]. - rewrite <- from_list_default_eq with (d := 0%Z). - change @from_list_default with @from_list_default_opt. - cbv [ModularBaseSystemList.pack convert]. - change length with length_opt. - change sum_firstn with @sum_firstn_opt. - change zeros with zeros_opt. - reflexivity. - Defined. - - Definition pack_opt (x : digits) : target_digits := - Eval cbv [proj1_sig pack_opt_sig] in proj1_sig (pack_opt_sig x). - - Definition pack_correct (x : digits) : - pack_opt x = pack target_widths_nonneg bits_eq x - := Eval cbv [proj2_sig pack_opt_sig] in proj2_sig (pack_opt_sig x). - - Definition unpack_opt_sig (x : target_digits) : { y | y = unpack target_widths_nonneg bits_eq x}. - Proof. - eexists. - cbv [unpack]. - rewrite <- from_list_default_eq with (d := 0%Z). - change @from_list_default with @from_list_default_opt. - cbv [ModularBaseSystemList.unpack convert]. - change length with length_opt. - change sum_firstn with @sum_firstn_opt. - change zeros with zeros_opt. - reflexivity. - Defined. - - Definition unpack_opt (x : target_digits) : digits := - Eval cbv [proj1_sig unpack_opt_sig] in proj1_sig (unpack_opt_sig x). - - Definition unpack_correct (x : target_digits) : - unpack_opt x = unpack target_widths_nonneg bits_eq x - := Eval cbv [proj2_sig unpack_opt_sig] in proj2_sig (unpack_opt_sig x). - -End Conversion. - -Local Hint Resolve lt_1_length_limb_widths int_width_pos B_pos B_compat - c_reduce1 c_reduce2. - -Section Canonicalization. - Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} - (* allows caller to precompute k and c *) - (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) - {int_width freeze_input_bound} - (preconditions : FreezePreconditions freeze_input_bound int_width). - Local Notation digits := (tuple Z (length limb_widths)). - - Definition carry_full_3_opt_sig - (us : list Z) - : { d : list Z | length us = length limb_widths - -> d = carry_full (carry_full (carry_full us)) }. - Proof. - eexists. - transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt c_)) us). - Focus 2. { - rewrite !carry_full_opt_cps_correct; try rewrite carry_full_opt_correct; repeat (autorewrite with distr_length; rewrite ?length_carry_full; auto). - } - Unfocus. - reflexivity. - Defined. - - Definition carry_full_3_opt (us : list Z) : list Z - := Eval cbv [proj1_sig carry_full_3_opt_sig] in proj1_sig (carry_full_3_opt_sig us). - - Definition carry_full_3_opt_correct us - : length us = length limb_widths - -> carry_full_3_opt us = carry_full (carry_full (carry_full us)) - := proj2_sig (carry_full_3_opt_sig us). - - Lemma ge_modulus'_cps : forall {A} (f : Z -> A) (us : list Z) i b, - f (ge_modulus' id us b i) = ge_modulus' f us b i. - Proof using Type. - induction i; intros; simpl; cbv [Let_In cmovl cmovne]; break_if; try reflexivity; - apply IHi. - Qed. - - Definition ge_modulus_opt_sig (us : list Z) : - { a : Z | a = ge_modulus us}. - Proof. - eexists. - cbv [ge_modulus ge_modulus']. - change length with length_opt. - change nth_default with @nth_default_opt. - change minus with minus_opt. - reflexivity. - Defined. - - Definition ge_modulus_opt us : Z - := Eval cbv [proj1_sig ge_modulus_opt_sig] in proj1_sig (ge_modulus_opt_sig us). - - Definition ge_modulus_opt_correct us : - ge_modulus_opt us= ge_modulus us - := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (ge_modulus_opt_sig us). - - Definition conditional_subtract_modulus_opt_sig (f : list Z): - { g | g = conditional_subtract_modulus int_width f (ge_modulus f) }. - Proof. - eexists. - cbv [conditional_subtract_modulus]. - let LHS := match goal with |- ?LHS = ?RHS => LHS end in - let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (neg int_width (ge_modulus f)) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (neg int_width (ge_modulus f)) RHSf). - cbv [ge_modulus]. - rewrite ge_modulus'_cps. - cbv beta iota delta [ge_modulus ge_modulus']. - change length with length_opt. - change nth_default with @nth_default_opt. - change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt. - change minus with minus_opt. - reflexivity. - Defined. - - Definition conditional_subtract_modulus_opt f : list Z - := Eval cbv [proj1_sig conditional_subtract_modulus_opt_sig] in proj1_sig (conditional_subtract_modulus_opt_sig f). - - Definition conditional_subtract_modulus_opt_correct f - : conditional_subtract_modulus_opt f = conditional_subtract_modulus int_width f (ge_modulus f) - := Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f). - - - Definition freeze_opt_sig (us : list Z) : - { b : list Z | length us = length limb_widths - -> b = ModularBaseSystemList.freeze int_width us }. - Proof. - eexists. - cbv [ModularBaseSystemList.freeze]. - rewrite <-conditional_subtract_modulus_opt_correct. - intros. - rewrite <-carry_full_3_opt_correct by auto. - let LHS := match goal with |- ?LHS = ?RHS => LHS end in - let RHS := match goal with |- ?LHS = ?RHS => RHS end in - let RHSf := match (eval pattern (carry_full_3_opt us) in RHS) with ?RHSf _ => RHSf end in - change (LHS = Let_In (carry_full_3_opt us) RHSf). - reflexivity. - Defined. - - Definition freeze_opt (us : list Z) : list Z - := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us). - - Definition freeze_opt_correct us - : length us = length limb_widths - -> freeze_opt us = ModularBaseSystemList.freeze int_width us - := proj2_sig (freeze_opt_sig us). - -End Canonicalization. - -Section SquareRoots. - Context `{prm : PseudoMersenneBaseParams}. - Context {cc : CarryChain limb_widths}. - Local Notation digits := (tuple Z (length limb_widths)). - (* allows caller to precompute k and c *) - Context (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) - (one_ : digits) (one_subst : one = one_). - - (* TODO : where should this lemma go? Alternatively, is there a standard-library - tactic/lemma for this? *) - Lemma if_equiv : forall {A} (eqA : A -> A -> Prop) (x0 x1 : bool) y0 y1 z0 z1, - x0 = x1 -> eqA y0 y1 -> eqA z0 z1 -> - eqA (if x0 then y0 else z0) (if x1 then y1 else z1). - Proof using Type. - intros; repeat break_if; congruence. - Qed. - - Section SquareRoot3mod4. - Context {ec : ExponentiationChain (modulus / 4 + 1)}. - - Definition sqrt_3mod4_opt_sig (us : digits) : - { vs : digits | eq vs (sqrt_3mod4 chain chain_correct us)}. - Proof. - eexists; cbv [sqrt_3mod4]. - apply @pow_opt_correct; eassumption. - Defined. - - Definition sqrt_3mod4_opt us := Eval cbv [proj1_sig sqrt_3mod4_opt_sig] in - proj1_sig (sqrt_3mod4_opt_sig us). - - Definition sqrt_3mod4_opt_correct us - : eq (sqrt_3mod4_opt us) (sqrt_3mod4 chain chain_correct us) - := Eval cbv [proj2_sig sqrt_3mod4_opt_sig] in proj2_sig (sqrt_3mod4_opt_sig us). - - End SquareRoot3mod4. - - Section SquareRoot5mod8. - Context {ec : ExponentiationChain (modulus / 8 + 1)}. - Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)). - Context {int_width freeze_input_bound} - (preconditions : FreezePreconditions freeze_input_bound int_width). - - Definition sqrt_5mod8_opt_sig (powx powx_squared us : digits) : - { vs : digits | - eq vs (sqrt_5mod8 int_width powx powx_squared chain chain_correct sqrt_m1 us)}. - Proof. - cbv [sqrt_5mod8]. - match goal with - |- appcontext[(if ?P then ?t else mul ?a ?b)] => - assert (eq (carry_mul_opt k_ c_ a b) (mul a b)) - by (rewrite carry_mul_opt_correct by auto; - cbv [eq]; rewrite carry_mul_rep, mul_rep; reflexivity) - end. - let RHS := match goal with |- {vs | eq vs ?RHS} => RHS end in - let RHSf := match (eval pattern powx in RHS) with ?RHSf _ => RHSf end in - change ({vs | eq vs (Let_In powx RHSf)}). - match goal with - | H : eq (?g powx) (?f powx) - |- {vs | eq vs (Let_In powx (fun x => if ?P then x else ?f x))} => - exists (Let_In powx (fun x => if P then x else g x)) - end. - break_if; try reflexivity. - cbv [Let_In]. - auto. - Defined. - - Definition sqrt_5mod8_opt powx powx_squared us := Eval cbv [proj1_sig sqrt_5mod8_opt_sig] in - proj1_sig (sqrt_5mod8_opt_sig powx powx_squared us). - - Definition sqrt_5mod8_opt_correct powx powx_squared us - : eq (sqrt_5mod8_opt powx powx_squared us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us) - := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig powx powx_squared us). - - End SquareRoot5mod8. - -End SquareRoots. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v deleted file mode 100644 index 9b22187bd..000000000 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ /dev/null @@ -1,1145 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.Lists.List. -Require Import Crypto.Algebra. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemList. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystem. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. -Require Import Crypto.Util.AdditionChainExponentiation. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.Tactics.UniquePose. -Require Import Crypto.Util.Tactics.BreakMatch. -Require Import Crypto.Util.Tactics.SpecializeBy. -Require Import Crypto.Util.Notations. -Require Export Crypto.Util.FixCoqMistakes. -Local Open Scope Z_scope. - -Local Opaque add_to_nth carry_simple. - -Class CarryChain (limb_widths : list Z) := - { - carry_chain : list nat; - carry_chain_valid : forall i, In i carry_chain -> (i < length limb_widths)%nat - }. - - Class SubtractionCoefficient {m : positive} {prm : PseudoMersenneBaseParams m} := { - coeff : tuple Z (length limb_widths); - coeff_mod: decode coeff = 0%F - }. - - Class ExponentiationChain {m : positive} {prm : PseudoMersenneBaseParams m} (exp : Z) := { - chain : list (nat * nat); - chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N exp - }. - - -Section FieldOperationProofs. - Context `{prm :PseudoMersenneBaseParams}. - - Local Arguments to_list {_ _} _. - Local Arguments from_list {_ _} _ _. - - Local Hint Unfold decode. - Local Notation "u ~= x" := (rep u x). - Local Notation digits := (tuple Z (length limb_widths)). - Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. - - Local Hint Resolve log_cap_nonneg. - Local Hint Resolve base_from_limb_widths_length. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation log_cap i := (nth_default 0 limb_widths i). - - Local Hint Unfold rep decode ModularBaseSystemList.decode. - - Lemma rep_decode : forall us x, us ~= x -> decode us = x. - Proof using Type. - autounfold; intuition. - Qed. - - Lemma decode_rep : forall us, rep us (decode us). - Proof using Type. - cbv [rep]; auto. - Qed. - - Lemma encode_eq : forall x : F modulus, - ModularBaseSystemList.encode x = BaseSystem.encode base (F.to_Z x) (2 ^ k). - Proof using Type. - cbv [ModularBaseSystemList.encode BaseSystem.encode encodeZ]; intros. - rewrite base_from_limb_widths_length. - apply encode'_spec; auto using Nat.eq_le_incl. - Qed. - - Lemma encode_rep : forall x : F modulus, encode x ~= x. - Proof using Type. - autounfold; cbv [encode]; intros. - rewrite to_list_from_list; autounfold. - rewrite encode_eq, encode_rep. - + apply F.of_Z_to_Z. - + apply bv. - + rewrite <-F.mod_to_Z. - match goal with |- appcontext [?a mod (Z.pos modulus)] => - pose proof (Z.mod_pos_bound a modulus modulus_pos) end. - pose proof lt_modulus_2k. - omega. - + eauto using base_upper_bound_compatible, limb_widths_nonneg. - Qed. - - Lemma bounded_encode : forall x, bounded limb_widths (to_list (encode x)). - Proof using Type. - intros. - cbv [encode]; rewrite to_list_from_list. - cbv [ModularBaseSystemList.encode]. - apply bounded_encodeZ; auto. - apply F.to_Z_range. - pose proof prime_modulus; prime_bound. - Qed. - - Lemma encode_range : forall x, - 0 <= BaseSystem.decode base (to_list (encode x)) < modulus. - Proof. - cbv [encode]; intros. - rewrite to_list_from_list. - rewrite encode_eq. - rewrite BaseSystemProofs.encode_rep; auto using F.to_Z_range, modulus_pos, bv. - + pose proof (F.to_Z_range x modulus_pos). - replace (2 ^ k) with (modulus + c) by (cbv[c]; ring). - pose proof c_pos; omega. - + apply base_upper_bound_compatible; auto. - Qed. - - Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> - add u v ~= (x+y)%F. - Proof using Type. - autounfold; cbv [add]; intros. - rewrite to_list_from_list; autounfold. - rewrite add_rep, F.of_Z_add. - f_equal; assumption. - Qed. - - Lemma eq_rep_iff : forall u v, (eq u v <-> u ~= decode v). - Proof using Type. - reflexivity. - Qed. - - Lemma eq_dec : forall x y, Decidable.Decidable (eq x y). - Proof using Type. - intros. - destruct (F.eq_dec (decode x) (decode y)); [ left | right ]; congruence. - Qed. - - Lemma modular_base_system_add_monoid : @monoid digits eq add zero. - Proof using Type. - repeat match goal with - | |- _ => progress intro - | |- _ => cbv [zero]; rewrite encode_rep - | |- _ digits eq add => econstructor - | |- _ digits eq add _ => econstructor - | |- (_ + _)%F = decode (add ?a ?b) => rewrite (add_rep a b) by (try apply add_rep; reflexivity) - | |- eq _ _ => apply eq_rep_iff - | |- add _ _ ~= _ => apply add_rep - | |- decode (add _ _) = _ => apply add_rep - | |- add _ _ ~= decode _ => etransitivity - | x : digits |- ?x ~= _ => reflexivity - | |- _ => apply associative - | |- _ => apply left_identity - | |- _ => apply right_identity - | |- _ => solve [eauto using eq_Equivalence, eq_dec] - | |- _ => congruence - end. - Qed. - - Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus. - Local Hint Extern 1 => apply limb_widths_match_modulus. - - Lemma reduce_rep : forall us, - BaseSystem.decode base (reduce us) mod modulus = - BaseSystem.decode (ext_base limb_widths) us mod modulus. - Proof. - cbv [reduce]; intros. - rewrite extended_shiftadd, base_from_limb_widths_length, pseudomersenne_add, BaseSystemProofs.add_rep. - change (List.map (Z.mul c)) with (BaseSystem.mul_each c). - rewrite mul_each_rep; auto. - Qed. - - Lemma mul_rep : forall u v x y, u ~= x -> v ~= y -> mul u v ~= (x*y)%F. - Proof using Type. - autounfold in *; unfold ModularBaseSystem.mul in *. - intuition idtac; subst. - rewrite to_list_from_list. - cbv [ModularBaseSystemList.mul ModularBaseSystemList.decode]. - rewrite F.of_Z_mod, reduce_rep, <-F.of_Z_mod. - pose proof (@base_from_limb_widths_length limb_widths). - rewrite @mul_rep by (eauto using ExtBaseVector || rewrite extended_base_length, !length_to_list; omega). - rewrite 2decode_short by (rewrite ?base_from_limb_widths_length; - auto using Nat.eq_le_incl, length_to_list with omega). - apply F.of_Z_mul. - Qed. - - Lemma modular_base_system_mul_monoid : @monoid digits eq mul one. - Proof using Type. - repeat match goal with - | |- _ => progress intro - | |- _ => cbv [one]; rewrite encode_rep - | |- _ digits eq mul => econstructor - | |- _ digits eq mul _ => econstructor - | |- (_ * _)%F = decode (mul ?a ?b) => rewrite (mul_rep a b) by (try apply mul_rep; reflexivity) - | |- eq _ _ => apply eq_rep_iff - | |- mul _ _ ~= _ => apply mul_rep - | |- decode (mul _ _) = _ => apply mul_rep - | |- mul _ _ ~= decode _ => etransitivity - | x : digits |- ?x ~= _ => reflexivity - | |- _ => apply associative - | |- _ => apply left_identity - | |- _ => apply right_identity - | |- _ => solve [eauto using eq_Equivalence, eq_dec] - | |- _ => congruence - end. - Qed. - - Lemma Fdecode_decode_mod : forall us x, - decode us = x -> BaseSystem.decode base (to_list us) mod modulus = F.to_Z x. - Proof using Type. - autounfold; intros. - rewrite <-H. - apply F.to_Z_of_Z. - Qed. - - Lemma sub_rep : forall mm pf u v x y, u ~= x -> v ~= y -> - ModularBaseSystem.sub mm pf u v ~= (x-y)%F. - Proof. - autounfold; cbv [sub]; intros. - rewrite to_list_from_list; autounfold. - cbv [ModularBaseSystemList.sub]. - rewrite BaseSystemProofs.sub_rep, BaseSystemProofs.add_rep. - rewrite F.of_Z_sub, F.of_Z_add, F.of_Z_mod. - apply Fdecode_decode_mod in pf; cbv [BaseSystem.decode] in *. - rewrite pf. rewrite Algebra.left_identity. - f_equal; assumption. - Qed. - - Lemma opp_rep : forall mm pf u x, u ~= x -> opp mm pf u ~= F.opp x. - Proof using Type. - cbv [opp rep]; intros. - rewrite sub_rep by (apply encode_rep || eassumption). - apply F.eq_to_Z_iff. - rewrite F.to_Z_opp. - rewrite <-Z.sub_0_l. - pose proof @F.of_Z_sub. - transitivity (F.to_Z (F.of_Z modulus (0 - F.to_Z x))); - [ rewrite F.of_Z_sub, F.of_Z_to_Z; reflexivity | ]. - rewrite F.to_Z_of_Z. reflexivity. - Qed. - - Section PowInv. - Context (modulus_gt_2 : 2 < modulus). - - Lemma scalarmult_rep : forall u x n, u ~= x -> - (@ScalarMult.scalarmult_ref digits mul one n u) ~= (x ^ (N.of_nat n))%F. - Proof using Type. - induction n; intros. - + cbv [N.to_nat ScalarMult.scalarmult_ref]. rewrite F.pow_0_r. - apply encode_rep. - + unfold ScalarMult.scalarmult_ref. - fold (@ScalarMult.scalarmult_ref digits mul one). - rewrite Nnat.Nat2N.inj_succ, <-N.add_1_l, F.pow_add_r, F.pow_1_r. - apply mul_rep; auto. - Qed. - - Lemma pow_rep : forall chain u x, u ~= x -> - pow u chain ~= F.pow x (fold_chain 0%N N.add chain (1%N :: nil)). - Proof using Type. - cbv [pow rep]; intros. - erewrite (@fold_chain_exp _ _ _ _ modular_base_system_mul_monoid) - by (apply @ScalarMult.scalarmult_ref_is_scalarmult; apply modular_base_system_mul_monoid). - etransitivity; [ apply scalarmult_rep; eassumption | ]. - rewrite Nnat.N2Nat.id. - reflexivity. - Qed. - - Lemma inv_rep : forall chain pf u x, u ~= x -> - inv chain pf u ~= F.inv x. - Proof using modulus_gt_2. - cbv [inv]; intros. - rewrite (@F.Fq_inv_fermat _ prime_modulus modulus_gt_2). - etransitivity; [ apply pow_rep; eassumption | ]. - congruence. - Qed. - - End PowInv. - - - Import Morphisms. - - Global Instance encode_Proper : Proper (Logic.eq ==> eq) encode. - Proof using Type. - repeat intro; cbv [eq]. - rewrite !encode_rep. assumption. - Qed. - - Global Instance add_Proper : Proper (eq ==> eq ==> eq) add. - Proof using Type. - repeat intro. - cbv beta delta [eq] in *. - erewrite !add_rep; cbv [rep] in *; try reflexivity; assumption. - Qed. - - Global Instance sub_Proper mm mm_correct - : Proper (eq ==> eq ==> eq) (sub mm mm_correct). - Proof using Type. - repeat intro. - cbv beta delta [eq] in *. - erewrite !sub_rep; cbv [rep] in *; try reflexivity; assumption. - Qed. - - Global Instance opp_Proper mm mm_correct - : Proper (eq ==> eq) (opp mm mm_correct). - Proof using Type. - cbv [opp]; repeat intro. - apply sub_Proper; assumption || reflexivity. - Qed. - - Global Instance mul_Proper : Proper (eq ==> eq ==> eq) mul. - Proof using Type. - repeat intro. - cbv beta delta [eq] in *. - erewrite !mul_rep; cbv [rep] in *; try reflexivity; assumption. - Qed. - - Global Instance pow_Proper : Proper (eq ==> Logic.eq ==> eq) pow. - Proof using Type. - repeat intro. - cbv beta delta [eq] in *. - erewrite !pow_rep; cbv [rep] in *; subst; try reflexivity. - congruence. - Qed. - - Global Instance inv_Proper chain chain_correct : Proper (eq ==> eq) (inv chain chain_correct). - Proof using Type. - cbv [inv]; repeat intro. - apply pow_Proper; assumption || reflexivity. - Qed. - - Global Instance div_Proper : Proper (eq ==> eq ==> eq) div. - Proof using Type. - cbv [div]; repeat intro; congruence. - Qed. - - Section FieldProofs. - Context (modulus_gt_2 : 2 < modulus) - {sc : SubtractionCoefficient} - {ec : ExponentiationChain (modulus - 2)}. - - Lemma _zero_neq_one : not (eq zero one). - Proof using Type. - cbv [eq zero one]; erewrite !encode_rep. - pose proof (@F.field_modulo modulus prime_modulus). - apply zero_neq_one. - Qed. - - Lemma modular_base_system_field : - @field digits eq zero one (opp coeff coeff_mod) add (sub coeff coeff_mod) mul (inv chain chain_correct) div. - Proof using modulus_gt_2. - eapply (Field.isomorphism_to_subfield_field (phi := decode) (fieldR := @F.field_modulo modulus prime_modulus)). - Grab Existential Variables. - + intros; eapply encode_rep. - + intros; eapply encode_rep. - + intros; eapply encode_rep. - + intros; eapply inv_rep; auto. - + intros; eapply mul_rep; auto. - + intros; eapply sub_rep; auto using coeff_mod. - + intros; eapply add_rep; auto. - + intros; eapply opp_rep; auto using coeff_mod. - + eapply _zero_neq_one. - + trivial. - Qed. -End FieldProofs. - -End FieldOperationProofs. -Opaque encode add mul sub inv pow. - -Section CarryProofs. - Context `{prm : PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation log_cap i := (nth_default 0 limb_widths i). - Local Notation "u ~= x" := (rep u x). - Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. - Local Hint Resolve log_cap_nonneg. - - Lemma base_length_lt_pred : (pred (length base) < length base)%nat. - Proof using Type. - pose proof limb_widths_nonnil; rewrite base_from_limb_widths_length. - destruct limb_widths; congruence || distr_length. - Qed. - Hint Resolve base_length_lt_pred. - - Definition carry_done us := forall i, (i < length base)%nat -> - 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. - - Lemma carry_done_bounds : forall us, (length us = length base) -> - (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). - Proof using Type. - intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. - + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. - - specialize (Hcarry_done i i_lt). - split; [ intuition | ]. - destruct Hcarry_done as [Hnth_nonneg Hshiftr_0]. - apply Z.shiftr_eq_0_iff in Hshiftr_0. - destruct Hshiftr_0 as [nth_0 | [] ]; [ rewrite nth_0; zero_bounds | ]. - apply Z.log2_lt_pow2; auto. - - rewrite nth_default_out_of_bounds by omega. - split; zero_bounds. - + specialize (Hbounds i). - split; [ intuition | ]. - destruct Hbounds as [nth_nonneg nth_lt_pow2]. - apply Z.shiftr_eq_0_iff. - apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto]. - right; split; auto. - apply Z.log2_lt_pow2; auto. - Qed. - - Lemma carry_decode_eq_reduce : forall us, - (length us = length limb_widths) -> - BaseSystem.decode base (carry_and_reduce (pred (length limb_widths)) us) mod modulus - = BaseSystem.decode base us mod modulus. - Proof using Type. - cbv [carry_and_reduce]; intros. - rewrite carry_gen_decode_eq; auto. - distr_length. - assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; - destruct limb_widths; distr_length; congruence). - break_match; repeat rewrite ?pred_mod, ?Nat.succ_pred,?Nat.mod_same in * by omega; - try omega. - rewrite !nth_default_base by (auto || destruct (length limb_widths); auto). - rewrite sum_firstn_0. - autorewrite with zsimplify. - match goal with |- appcontext[2 ^ ?a * ?b * 2 ^ ?c] => - replace (2 ^ a * b * 2 ^ c) with (2 ^ (a + c) * b) end. - { rewrite <-sum_firstn_succ by (apply nth_error_Some_nth_default; destruct (length limb_widths); auto). - rewrite Nat.succ_pred by omega. - remember (pred (length limb_widths)) as pred_len. - fold k. - rewrite <-Z.mul_sub_distr_r. - replace (c - 2 ^ k) with (modulus * -1) by (cbv [c]; ring). - rewrite <-Z.mul_assoc. - apply Z.mod_add_l'. - pose proof prime_modulus. Z.prime_bound. } - { rewrite Z.pow_add_r; auto using log_cap_nonneg, sum_firstn_limb_widths_nonneg. - rewrite <-!Z.mul_assoc. - apply Z.mul_cancel_l; try ring. - apply Z.pow_nonzero; (omega || auto using log_cap_nonneg). } - Qed. - - Lemma carry_rep : forall i us x, - (length us = length limb_widths)%nat -> - (i < length limb_widths)%nat -> - forall pf1 pf2, - from_list _ us pf1 ~= x -> from_list _ (carry i us) pf2 ~= x. - Proof using Type. - cbv [carry rep decode]; intros. - rewrite to_list_from_list. - pose proof carry_decode_eq_reduce. pose proof (@carry_simple_decode_eq limb_widths). - - specialize_by eauto. - cbv [ModularBaseSystemList.carry]. - break_match; subst; eauto. - apply F.eq_of_Z_iff. - rewrite to_list_from_list. - apply carry_decode_eq_reduce. auto. - cbv [ModularBaseSystemList.decode]. - apply F.eq_of_Z_iff. - rewrite to_list_from_list, carry_simple_decode_eq; try omega; distr_length; auto. - Qed. - Hint Resolve carry_rep. - - Lemma decode_mod_Fdecode : forall u, length u = length limb_widths -> - BaseSystem.decode base u mod modulus= F.to_Z (decode (from_list_default 0 _ u)). - Proof using Type. - intros. - rewrite <-(to_list_from_list _ u) with (pf := H). - erewrite Fdecode_decode_mod by reflexivity. - rewrite to_list_from_list. - rewrite from_list_default_eq with (pf := H). - reflexivity. - Qed. - - Lemma carry_sequence_rep : forall is us x, - (forall i, In i is -> (i < length limb_widths)%nat) -> - us ~= x -> forall pf, from_list _ (carry_sequence is (to_list _ us)) pf ~= x. - Proof using Type. - induction is; intros. - + cbv [carry_sequence fold_right]. rewrite from_list_to_list. assumption. - + simpl. apply carry_rep with (pf1 := length_carry_sequence (length_to_list us)); - auto using length_carry_sequence, length_to_list, in_eq. - apply IHis; auto using in_cons. - Qed. - - Context `{cc : CarryChain limb_widths}. - Lemma carry_mul_rep : forall us vs x y, - rep us x -> rep vs y -> - rep (carry_mul carry_chain us vs) (x * y)%F. - Proof using Type. - cbv [carry_mul]; intros; apply carry_sequence_rep; - auto using carry_chain_valid, mul_rep. - Qed. - - Lemma carry_sub_rep : forall coeff coeff_mod a b, - eq - (carry_sub carry_chain coeff coeff_mod a b) - (sub coeff coeff_mod a b). - Proof using Type. - cbv [carry_sub carry_]; intros. - eapply carry_sequence_rep; auto using carry_chain_valid. - reflexivity. - Qed. - - Lemma carry_add_rep : forall a b, - eq (carry_add carry_chain a b) (add a b). - Proof using Type. - cbv [carry_add carry_]; intros. - eapply carry_sequence_rep; auto using carry_chain_valid. - reflexivity. - Qed. - - Lemma carry_opp_rep : forall coeff coeff_mod a, - eq - (carry_opp carry_chain coeff coeff_mod a) - (opp coeff coeff_mod a). - Proof using Type. - cbv [carry_opp opp]; intros. - apply carry_sub_rep. - Qed. - -End CarryProofs. - -Hint Rewrite @length_carry_and_reduce @length_carry : distr_length. - -Class FreezePreconditions `{prm : PseudoMersenneBaseParams} B int_width := - { - lt_1_length_limb_widths : (1 < length limb_widths)%nat; - int_width_pos : 0 < int_width; - B_le_int_width : B <= int_width; - B_pos : 0 < B; - B_compat : forall w, In w limb_widths -> w < B; - (* on the first reduce step, we add at most one bit of width to the first digit *) - c_reduce1 : c * ((2 ^ B) >> nth_default 0 limb_widths (pred (length limb_widths))) <= 2 ^ (nth_default 0 limb_widths 0); - (* on the second reduce step, we add at most one bit of width to the first digit, - and leave room to carry c one more time after the highest bit is carried *) - c_reduce2 : c <= 2 ^ (nth_default 0 limb_widths 0) - c - }. - -Section CanonicalizationProofs. - Context `{freeze_pre : FreezePreconditions}. - Local Notation base := (base_from_limb_widths limb_widths). - Local Notation digits := (tuple Z (length limb_widths)). - Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. - Local Hint Resolve log_cap_nonneg. - Local Notation "u [ i ]" := (nth_default 0 u i). - Local Notation "u {{ i }}" := (carry_sequence (make_chain i) u) (at level 30). (* Can't rely on [Reserved Notation]: https://coq.inria.fr/bugs/show_bug.cgi?id=4970 *) - - Lemma nth_default_carry_and_reduce_full : forall n i us, - (carry_and_reduce i us) [n] - = if lt_dec n (length us) - then - (if eq_nat_dec n (i mod length limb_widths) - then Z.pow2_mod (us [n]) (limb_widths [n]) - else us [n]) + - if eq_nat_dec n (S (i mod length limb_widths) mod length limb_widths) - then c * (us [i mod length limb_widths]) >> (limb_widths [i mod length limb_widths]) - else 0 - else 0. - Proof using Type. - cbv [carry_and_reduce]; intros. - autorewrite with push_nth_default. - reflexivity. - Qed. - Hint Rewrite @nth_default_carry_and_reduce_full : push_nth_default. - - Lemma nth_default_carry_full : forall n i us, - length us = length limb_widths -> - (carry i us) [n] - = if lt_dec n (length us) - then - if eq_nat_dec i (pred (length limb_widths)) - then (if eq_nat_dec n i - then Z.pow2_mod (us [n]) (limb_widths [n]) - else us [n]) + - if eq_nat_dec n 0 - then c * ((us [i]) >> (limb_widths [i])) - else 0 - else if eq_nat_dec n i - then Z.pow2_mod (us [n]) (limb_widths [n]) - else us [n] + - if eq_nat_dec n (S i) - then (us [i]) >> (limb_widths [i]) - else 0 - else 0. - Proof using Type*. - intros. - cbv [carry]. - break_innermost_match_step. - + subst i. - pose proof lt_1_length_limb_widths. - autorewrite with push_nth_default natsimplify. - destruct (eq_nat_dec (length limb_widths) (length us)); congruence. - + autorewrite with push_nth_default; reflexivity. - Qed. - Hint Rewrite @nth_default_carry_full : push_nth_default. - - Lemma nth_default_carry_sequence_make_chain_full : forall i n us, - length us = length limb_widths -> - (i <= length limb_widths)%nat -> - us {{ i }} [n] - = if lt_dec n (length limb_widths) - then - if eq_nat_dec i 0 - then nth_default 0 us n - else - if lt_dec i (length limb_widths) - then - if lt_dec n i - then - if eq_nat_dec n (pred i) - then Z.pow2_mod (us {{ pred i }} [n]) (limb_widths [n]) - else us{{ pred i }} [n] - else us{{ pred i}} [n] + - (if eq_nat_dec n i - then (us{{ pred i}} [pred i]) >> (limb_widths [pred i]) - else 0) - else - if lt_dec n (pred i) - then us {{ pred i }} [n] + - (if eq_nat_dec n 0 - then c * (us{{ pred i}} [pred i]) >> (limb_widths [pred i]) - else 0) - else Z.pow2_mod (us {{ pred i }} [n]) (limb_widths [n]) - else 0. - Proof using Type*. - induction i; intros; cbv [carry_sequence]. - + cbv [pred make_chain fold_right]. - break_match; subst; omega || reflexivity || auto using Z.add_0_r. - apply nth_default_out_of_bounds. omega. - + replace (make_chain (S i)) with (i :: make_chain i) by reflexivity. - rewrite fold_right_cons. - pose proof lt_1_length_limb_widths. - autorewrite with push_nth_default natsimplify; - rewrite ?Nat.pred_succ; fold (carry_sequence (make_chain i) us); - rewrite length_carry_sequence; auto. - repeat (break_innermost_match_step; try omega). - Qed. - - Lemma nth_default_carry : forall i us, - length us = length limb_widths -> - (i < length us)%nat -> - nth_default 0 (carry i us) i - = Z.pow2_mod (us [i]) (limb_widths [i]). - Proof using Type*. - intros; pose proof lt_1_length_limb_widths; autorewrite with push_nth_default natsimplify; break_match; omega. - Qed. - Hint Rewrite @nth_default_carry using (omega || distr_length; omega) : push_nth_default. - - Lemma pow_limb_widths_gt_1 : forall i, (i < length limb_widths)%nat -> - 1 < 2 ^ limb_widths [i]. - Proof using Type. - intros. - apply Z.pow_gt_1; try omega. - apply nth_default_preserves_properties_length_dep; intros; try omega. - auto using limb_widths_pos. - Qed. - - Lemma carry_sequence_nil_l : forall us, carry_sequence nil us = us. - Proof using Type. - reflexivity. - Qed. - - Ltac bound_during_loop := - repeat match goal with - | |- _ => progress (intros; subst) - | |- _ => unique pose proof lt_1_length_limb_widths - | |- _ => unique pose proof c_reduce2 - | |- _ => break_innermost_match_step; try omega - | |- _ => break_innermost_match_hyps_step; try omega - | |- _ => progress simpl pred in * - | |- _ => progress rewrite ?Z.add_0_r, ?Z.sub_0_r in * - | |- _ => rewrite nth_default_out_of_bounds by omega - | |- _ => rewrite nth_default_carry_sequence_make_chain_full by auto - | H : forall n, 0 <= _ [n] < _ |- appcontext [ _ [?n] ] => pose proof (H (pred n)); specialize (H n) - | H : forall n, (n < ?m)%nat -> 0 <= _ [n] < _ |- appcontext [ _ [?n] ] => pose proof (H (pred n)); specialize (H n); specialize_by omega - | |- appcontext [make_chain 0] => simpl make_chain; rewrite carry_sequence_nil_l - | |- 0 <= ?a + c * ?b < 2 * ?d => unique assert (c * b <= d); - [ | solve [pose proof c_pos; rewrite <-Z.add_diag; split; zero_bounds] ] - | |- c * (?e >> (limb_widths[?i])) <= ?b => - pose proof (Z.shiftr_le e (2 ^ B) (limb_widths [i])); specialize_by (auto || omega); - replace (limb_widths [i]) with (limb_widths [pred (length limb_widths)]) in * by (f_equal; omega); - etransitivity; [ | apply c_reduce1]; apply Z.mul_le_mono_pos_l; try apply c_pos; omega - | H : 0 <= _ < ?b - (?c >> ?d) |- 0 <= _ + (?e >> ?d) < ?b => - pose proof (Z.shiftr_le e c d); specialize_by (auto || omega); solve [split; zero_bounds] - | IH : forall n, _ -> 0 <= ?u {{ ?i }} [n] < _ - |- 0 <= ?u {{ ?i }} [?n] < _ => specialize (IH n) - | IH : forall n, _ -> 0 <= ?u {{ ?i }} [n] < _ - |- appcontext [(?u {{ ?i }} [?n]) >> _] => pose proof (IH 0%nat); pose proof (IH (S n)); specialize (IH n); specialize_by omega - | H : 0 <= ?a < 2 ^ ?n + ?x |- appcontext [?a >> ?n] => - assert (x < 2 ^ n) by (omega || auto using pow_limb_widths_gt_1); - unique assert (0 <= a < 2 * 2 ^ n) by omega - | H : 0 <= ?a < 2 * 2 ^ ?n |- appcontext [?a >> ?n] => - pose proof c_pos; - apply Z.lt_mul_2_pow_2_shiftr in H; (break_innermost_match_step || break_innermost_match_hyps_step); rewrite H; omega - | H : 0 <= ?a < 2 ^ ?n |- appcontext [?a >> ?n] => - pose proof c_pos; - apply Z.lt_pow_2_shiftr in H; rewrite H; omega - | |- 0 <= Z.pow2_mod _ _ < c => - rewrite Z.pow2_mod_spec, Z.lt_mul_2_mod_sub; auto; omega - | |- _ => apply Z.pow2_mod_pos_bound, limb_widths_pos, nth_default_preserves_properties_length_dep; [tauto | omega] - | |- 0 <= 0 < _ => solve[split; zero_bounds] - | |- _ => omega - end. - - Lemma bound_during_first_loop : forall us, - length us = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> (limb_widths [pred n]))) -> - forall i n, - (i <= length limb_widths)%nat -> - 0 <= us{{i}}[n] < if eq_nat_dec i 0 then us[n] + 1 else - if lt_dec i (length limb_widths) - then - if lt_dec n i - then 2 ^ (limb_widths [n]) - else if eq_nat_dec n i - then 2 ^ B - else us[n] + 1 - else - if eq_nat_dec n 0 - then 2 * 2 ^ limb_widths [n] - else 2 ^ limb_widths [n]. - Proof using Type*. - induction i; bound_during_loop. - Qed. - - Lemma bound_after_loop_length_preconditions : forall us (Hlength : length us = length limb_widths) - {bound bound' bound'' : list Z -> nat -> Z} - {X Y : list Z -> nat -> nat -> Z} f, - (forall us, length us = length limb_widths - -> (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < bound' us n) - -> forall i n, (i <= length limb_widths)%nat - -> 0 <= us{{i}}[n] < if eq_nat_dec i 0 then X us i n else - if lt_dec i (length limb_widths) - then Y us i n - else bound'' us n) -> - ((forall n, (n < length limb_widths)%nat -> 0 <= us [n] < bound us n) - -> forall n, (n < length limb_widths)%nat -> 0 <= (f us) [n] < bound' (f us) n) -> - length (f us) = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < bound us n) - -> forall n, (n < length limb_widths)%nat -> 0 <= (carry_full (f us)) [n] < bound'' (f us) n. - Proof using Type*. - pose proof lt_1_length_limb_widths. - cbv [carry_full full_carry_chain]; intros ? ? ? ? ? ? ? ? Hloop Hfbound Hflength Hbound n. - specialize (Hfbound Hbound). - specialize (Hloop (f us) Hflength Hfbound (length limb_widths) n). - specialize_by omega. - repeat (omega || break_innermost_match_step || break_innermost_match_hyps_step). - Qed. - - Lemma bound_after_loop : forall us (Hlength : length us = length limb_widths) - {bound bound' bound'' : list Z -> nat -> Z} - {X Y : list Z -> nat -> nat -> Z} f, - (forall us, length us = length limb_widths - -> (forall n, 0 <= us [n] < bound' us n) - -> forall i n, (i <= length limb_widths)%nat - -> 0 <= us{{i}}[n] < if eq_nat_dec i 0 then X us i n else - if lt_dec i (length limb_widths) - then Y us i n - else bound'' us n) -> - ((forall n, (n < length limb_widths)%nat -> 0 <= us [n] < bound us n) - -> forall n, 0 <= (f us) [n] < bound' (f us) n) - -> length (f us) = length limb_widths - -> (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < bound us n) - -> forall n, 0 <= (carry_full (f us)) [n] < bound'' (f us) n. - Proof using Type*. - pose proof lt_1_length_limb_widths. - cbv [carry_full full_carry_chain]; intros ? ? ? ? ? ? ? ? Hloop Hfbound Hflength Hbound n. - specialize (Hfbound Hbound). - specialize (Hloop (f us) Hflength Hfbound (length limb_widths) n). - specialize_by omega. - repeat (omega || break_innermost_match_step || break_innermost_match_hyps_step). - Qed. - - Lemma bound_after_first_loop_pre : forall us, - length us = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> (limb_widths [pred n]))) -> - forall n, (n < length limb_widths)%nat -> - 0 <= (carry_full us)[n] < - if eq_nat_dec n 0 - then 2 * 2 ^ limb_widths [n] - else 2 ^ limb_widths [n]. - Proof using Type*. - intros ? ?. - apply (bound_after_loop_length_preconditions us H id bound_during_first_loop); auto. - Qed. - - Lemma bound_after_first_loop : forall us, - length us = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> (limb_widths [pred n]))) -> - forall n, - 0 <= (carry_full us)[n] < - if eq_nat_dec n 0 - then 2 * 2 ^ limb_widths [n] - else 2 ^ limb_widths [n]. - Proof using Type*. - intros. - destruct (lt_dec n (length limb_widths)); - auto using bound_after_first_loop_pre. - rewrite !nth_default_out_of_bounds by (rewrite ?length_carry_full; omega). - autorewrite with zsimplify. - rewrite Z.pow_0_r. - break_innermost_match_step; omega. - Qed. - - Lemma bound_during_second_loop : forall us, - length us = length limb_widths -> - (forall n, 0 <= us [n] < if eq_nat_dec n 0 then 2 * 2 ^ limb_widths [n] else 2 ^ limb_widths [n]) -> - forall i n, - (i <= length limb_widths)%nat -> - 0 <= us{{i}}[n] < if eq_nat_dec i 0 then us[n] + 1 else - if lt_dec i (length limb_widths) - then - if lt_dec n i - then 2 ^ (limb_widths [n]) - else if eq_nat_dec n i - then 2 * 2 ^ limb_widths [n] - else us[n] + 1 - else - if eq_nat_dec n 0 - then 2 ^ limb_widths [n] + c - else 2 ^ limb_widths [n]. - Proof using Type*. - induction i; bound_during_loop. - Qed. - - Lemma bound_after_second_loop : forall us, - length us = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> (limb_widths [pred n]))) -> - forall n, - 0 <= (carry_full (carry_full us)) [n] < - if eq_nat_dec n 0 - then 2 ^ limb_widths [n] + c - else 2 ^ limb_widths [n]. - Proof using Type*. - intros ? ?; apply (bound_after_loop us H carry_full bound_during_second_loop); - auto using length_carry_full, bound_after_first_loop. - Qed. - - Lemma bound_during_third_loop : forall us, - length us = length limb_widths -> - (forall n, 0 <= us [n] < if eq_nat_dec n 0 then 2 ^ limb_widths [n] + c else 2 ^ limb_widths [n]) -> - forall i n, - (i <= length limb_widths)%nat -> - 0 <= us{{i}}[n] < if eq_nat_dec i 0 then us[n] + 1 else - if lt_dec i (length limb_widths) - then - if Z_lt_dec (us [0]) (2 ^ limb_widths [0]) - then - 2 ^ limb_widths [n] - else - if eq_nat_dec n 0 - then c - else - if lt_dec n i - then 2 ^ limb_widths [n] - else if eq_nat_dec n i - then 2 ^ limb_widths [n] + 1 - else us[n] + 1 - else - 2 ^ limb_widths [n]. - Proof using Type*. - induction i; bound_during_loop. - Qed. - - Lemma bound_after_third_loop : forall us, - length us = length limb_widths -> - (forall n, (n < length limb_widths)%nat -> 0 <= us [n] < 2 ^ B - if eq_nat_dec n 0 then 0 else ((2 ^ B) >> (limb_widths [pred n]))) -> - forall n, - 0 <= (carry_full (carry_full (carry_full us))) [n] < 2 ^ limb_widths [n]. - Proof using Type*. - intros ? ?. - apply (bound_after_loop us H (fun x => carry_full (carry_full x)) bound_during_third_loop); - auto using length_carry_full, bound_after_second_loop. - Qed. - - Local Notation initial_bounds u := - (forall n : nat, (n < length limb_widths)%nat -> - 0 <= to_list (length limb_widths) u [n] < - 2 ^ B - - (if eq_nat_dec n 0 - then 0 - else (2 ^ B) >> (limb_widths [pred n]))). - Local Notation minimal_rep u := ((bounded limb_widths (to_list (length limb_widths) u)) - /\ (ge_modulus (to_list _ u) = 0)). - - Lemma decode_bitwise_eq_iff : forall u v, minimal_rep u -> minimal_rep v -> - (fieldwise Logic.eq u v <-> - decode_bitwise limb_widths (to_list _ u) = decode_bitwise limb_widths (to_list _ v)). - Proof using Type. - intros. - rewrite !decode_bitwise_spec by (tauto || auto using length_to_list). - rewrite fieldwise_to_list_iff. - split; intros. - + apply decode_Proper; auto. - + apply Forall2_forall_iff with (d := 0); intros; repeat rewrite @length_to_list in *; auto. - erewrite digit_select with (us := to_list _ u) by intuition eauto. - erewrite digit_select with (us := to_list _ v) by intuition eauto. - rewrite H1; reflexivity. - Qed. - - Lemma c_upper_bound : c - 1 < 2 ^ limb_widths[0]. - Proof using Type*. - pose proof c_reduce2. pose proof c_pos. - omega. - Qed. - Hint Resolve c_upper_bound. - - Lemma minimal_rep_encode : forall x, minimal_rep (encode x). - Proof using Type*. - split; intros; auto using bounded_encode. - apply ge_modulus_spec; auto using bounded_encode, length_to_list. - apply encode_range. - Qed. - - Lemma encode_minimal_rep : forall u x, rep u x -> minimal_rep u -> - fieldwise Logic.eq u (encode x). - Proof using Type*. - intros. - apply decode_bitwise_eq_iff; auto using minimal_rep_encode. - rewrite !decode_bitwise_spec by (intuition auto; distr_length; try apply minimal_rep_encode). - apply Fdecode_decode_mod in H. - pose proof (Fdecode_decode_mod _ _ (encode_rep x)). - rewrite Z.mod_small in H by (apply ge_modulus_spec; distr_length; intuition auto). - rewrite Z.mod_small in H1 by (apply ge_modulus_spec; distr_length; auto using c_upper_bound; apply minimal_rep_encode). - congruence. - Qed. - - Lemma bounded_canonical : forall u v x y, rep u x -> rep v y -> - minimal_rep u -> minimal_rep v -> - (x = y <-> fieldwise Logic.eq u v). - Proof using Type*. - intros. - eapply encode_minimal_rep in H1; eauto. - eapply encode_minimal_rep in H2; eauto. - split; intros; subst. - + etransitivity; eauto; symmetry; eauto. - + assert (fieldwise Logic.eq (encode x) (encode y)) by - (transitivity u; [symmetry; eauto | ]; transitivity v; eauto). - apply decode_bitwise_eq_iff in H4; try apply minimal_rep_encode. - rewrite !decode_bitwise_spec in H4 by (auto; distr_length; apply minimal_rep_encode). - apply F.eq_to_Z_iff. - erewrite <-!Fdecode_decode_mod by eapply encode_rep. - congruence. - Qed. - - Lemma int_width_compat : forall x, In x limb_widths -> x < int_width. - Proof using Type*. - intros. apply B_compat in H. - eapply Z.lt_le_trans; eauto using B_le_int_width. - Qed. - - Lemma minimal_rep_freeze : forall u, initial_bounds u -> - minimal_rep (freeze int_width u). - Proof using Type*. - repeat match goal with - | |- _ => progress (cbv [freeze ModularBaseSystemList.freeze]) - | |- _ => progress intros - | |- minimal_rep _ => split - | |- _ => rewrite to_list_from_list - | |- _ => apply bound_after_third_loop - | |- _ => apply conditional_subtract_lt_modulus - | |- _ => apply conditional_subtract_modulus_preserves_bounded - | |- bounded _ (carry_full _) => apply bounded_iff - | |- _ => solve [auto using Z.lt_le_incl, int_width_pos, int_width_compat, lt_1_length_limb_widths, length_carry_full, length_to_list] - end. - Qed. - - Lemma freeze_decode : forall u, - BaseSystem.decode base (to_list _ (freeze int_width u)) mod modulus = - BaseSystem.decode base (to_list _ u) mod modulus. - Proof using Type*. - repeat match goal with - | |- _ => progress cbv [freeze ModularBaseSystemList.freeze] - | |- _ => progress intros - | |- _ => rewrite <-Z.add_opp_r, <-Z.mul_opp_l - | |- _ => rewrite Z.mod_add by (pose proof prime_modulus; prime_bound) - | |- _ => rewrite to_list_from_list - | |- _ => rewrite conditional_subtract_modulus_spec by - (auto using Z.lt_le_incl, int_width_pos, int_width_compat, lt_1_length_limb_widths, length_carry_full, length_to_list, ge_modulus_01) - end. - rewrite !decode_mod_Fdecode by auto using length_carry_full, length_to_list. - cbv [carry_full]. - apply F.eq_to_Z_iff. - rewrite <-@to_list_from_list with (pf := length_carry_sequence (length_carry_sequence (length_to_list _))). - rewrite from_list_default_eq with (pf := length_carry_sequence (length_to_list _)). - rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. - cbv [rep]. - rewrite <-from_list_default_eq with (d := 0). - erewrite <-to_list_from_list with (pf := length_carry_sequence (length_to_list _)). - rewrite from_list_default_eq with (pf := length_carry_sequence (length_to_list _)). - rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. - cbv [rep]. - rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. - rewrite from_list_default_eq with (pf := length_to_list _). - rewrite from_list_to_list; reflexivity. - Qed. - - Lemma freeze_rep : forall u x, rep u x -> rep (freeze int_width u) x. - Proof using Type*. - cbv [rep]; intros. - apply F.eq_to_Z_iff. - erewrite <-!Fdecode_decode_mod by eauto. - apply freeze_decode. - Qed. - - Lemma freeze_canonical : forall u v x y, rep u x -> rep v y -> - initial_bounds u -> - initial_bounds v -> - (x = y <-> fieldwise Logic.eq (freeze int_width u) (freeze int_width v)). - Proof using Type*. - intros; apply bounded_canonical; auto using freeze_rep, minimal_rep_freeze. - Qed. - -End CanonicalizationProofs. - -Section SquareRootProofs. - Context `{freeze_pre : FreezePreconditions}. - Local Notation "u ~= x" := (rep u x). - Local Notation digits := (tuple Z (length limb_widths)). - Local Notation base := (base_from_limb_widths limb_widths). - Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. - - Definition freeze_input_bounds n := - (2 ^ B - - (if eq_nat_dec n 0 - then 0 - else (2 ^ B) >> (nth_default 0 limb_widths (pred n)))). - Definition bounded_by u bounds := - (forall n : nat, (n < length limb_widths)%nat -> - 0 <= nth_default 0 (to_list (length limb_widths) u) n < bounds n). - - Lemma eqb_true_iff : forall u v x y, - bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x = y <-> eqb int_width u v = true). - Proof using Type*. - cbv [eqb freeze_input_bounds]. intros. - rewrite fieldwiseb_fieldwise by (apply Z.eqb_eq). - eauto using freeze_canonical. - Qed. - - Lemma eqb_false_iff : forall u v x y, - bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> - u ~= x -> v ~= y -> (x <> y <-> eqb int_width u v = false). - Proof using Type*. - intros. - case_eq (eqb int_width u v). - + rewrite <-eqb_true_iff by eassumption; split; intros; - congruence || contradiction. - + split; intros; auto. - intro Hfalse_eq; - rewrite (eqb_true_iff u v) in Hfalse_eq by eassumption. - congruence. - Qed. - - Section Sqrt3mod4. - Context (modulus_3mod4 : modulus mod 4 = 3). - Context {ec : ExponentiationChain (modulus / 4 + 1)}. - - Lemma sqrt_3mod4_correct : forall u x, u ~= x -> - (sqrt_3mod4 chain chain_correct u) ~= F.sqrt_3mod4 x. - Proof using Type. - repeat match goal with - | |- _ => progress (cbv [sqrt_3mod4 F.sqrt_3mod4]; intros) - | |- _ => rewrite @F.pow_2_r in * - | |- _ => rewrite eqb_correct in * by eassumption - | |- _ => rewrite <-chain_correct; apply pow_rep; eassumption - end. - Qed. - End Sqrt3mod4. - - Section Sqrt5mod8. - Context (modulus_5mod8 : modulus mod 8 = 5). - Context {ec : ExponentiationChain (modulus / 8 + 1)}. - Context (sqrt_m1 : digits) (sqrt_m1_correct : mul sqrt_m1 sqrt_m1 ~= F.opp 1%F). - - Lemma sqrt_5mod8_correct : forall u x powx powx_squared, u ~= x -> - bounded_by u freeze_input_bounds -> - bounded_by powx_squared freeze_input_bounds -> - ModularBaseSystem.eq powx (pow u chain) -> - ModularBaseSystem.eq powx_squared (mul powx powx) -> - (sqrt_5mod8 int_width powx powx_squared chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. - Proof using freeze_pre. - cbv [sqrt_5mod8 F.sqrt_5mod8]. - intros. - repeat match goal with - | |- _ => progress (cbv [sqrt_5mod8 F.sqrt_5mod8]; intros) - | |- _ => rewrite @F.pow_2_r in * - | |- _ => rewrite eqb_correct in * by eassumption - | |- (if eqb _ ?a ?b then _ else _) ~= - (if dec (?c = _) then _ else _) => - assert (a ~= c) by - (cbv [rep]; rewrite <-chain_correct, <-pow_rep, <-mul_rep; - eassumption); break_innermost_match - | |- _ => apply mul_rep; try reflexivity; - rewrite <-chain_correct, <-pow_rep; eassumption - | |- _ => rewrite <-chain_correct, <-pow_rep; eassumption - | H : eqb _ ?a ?b = true, H1 : ?b ~= ?y, H2 : ?a ~= ?x |- _ => - rewrite <-(eqb_true_iff a b x y) in H by eassumption - | H : eqb _ ?a ?b = false, H1 : ?b ~= ?y, H2 : ?a ~= ?x |- _ => - rewrite <-(eqb_false_iff a b x y) in H by eassumption - | |- _ => congruence - end. - Qed. - End Sqrt5mod8. - -End SquareRootProofs. - -Section ConversionProofs. - Context `{prm :PseudoMersenneBaseParams}. - Context {target_widths} - (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) - (bits_eq : sum_firstn limb_widths (length limb_widths) = - sum_firstn target_widths (length target_widths)). - Local Notation target_base := (base_from_limb_widths target_widths). - - Lemma pack_rep : forall w, - bounded limb_widths (to_list _ w) -> - bounded target_widths (to_list _ w) -> - rep w (F.of_Z modulus - (BaseSystem.decode - target_base - (to_list _ (pack target_widths_nonneg bits_eq w)))). - Proof using Type. - intros; cbv [pack ModularBaseSystemList.pack rep]. - rewrite Tuple.to_list_from_list. - apply F.eq_to_Z_iff. - rewrite F.to_Z_of_Z. - rewrite <-Conversion.convert_correct; auto using length_to_list. - Qed. - - Lemma unpack_rep : forall w, - bounded target_widths (to_list _ w) -> - rep (unpack target_widths_nonneg bits_eq w) - (F.of_Z modulus (BaseSystem.decode target_base (to_list _ w))). - Proof using Type. - intros; cbv [unpack ModularBaseSystemList.unpack rep]. - apply F.eq_to_Z_iff. - rewrite <-from_list_default_eq with (d := 0). - rewrite <-decode_mod_Fdecode by apply Conversion.length_convert. - rewrite F.to_Z_of_Z. - rewrite <-Conversion.convert_correct; auto using length_to_list. - Qed. - - -End ConversionProofs. diff --git a/src/ModularArithmetic/ModularBaseSystemWord.v b/src/ModularArithmetic/ModularBaseSystemWord.v deleted file mode 100644 index 9283bfb30..000000000 --- a/src/ModularArithmetic/ModularBaseSystemWord.v +++ /dev/null @@ -1,23 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.Notations. -Require Import Bedrock.Word. -Local Open Scope Z_scope. - -Section conditional_subtract_modulus. - Context {int_width num_limbs : nat}. - Local Notation limb := (word int_width). - Local Notation digits := (tuple limb num_limbs). - Local Notation zero := (natToWord int_width 0). - Local Notation one := (natToWord int_width 1). - Local Notation "u [ i ]" := (nth_default zero u i). - Context (modulus : digits). - Context (ge_modulusW : digits -> limb) (negW : limb -> limb). - - Definition conditional_subtract_modulusW (us : digits) := - (* [and_term] is all ones if us' is full, so the subtractions subtract q overall. - Otherwise, it's all zeroes, and the subtractions do nothing. *) - map2 (fun x y => wminus x y) us (map (wand (negW (ge_modulusW us))) modulus). - -End conditional_subtract_modulus.
\ No newline at end of file diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v deleted file mode 100644 index 85ed920a2..000000000 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ /dev/null @@ -1,99 +0,0 @@ -Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. -Require Import Crypto.ModularArithmetic.ExtendedBaseVector. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. -Require Crypto.BaseSystem. -Local Open Scope Z_scope. - -Section PseudoMersenneBaseParamProofs. - Context `{prm : PseudoMersenneBaseParams}. - Local Notation base := (base_from_limb_widths limb_widths). - - Lemma limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w. - Proof using Type. auto using Z.lt_le_incl, limb_widths_pos. Qed. - - Lemma k_nonneg : 0 <= k. - Proof using Type. apply sum_firstn_limb_widths_nonneg, limb_widths_nonneg. Qed. - - Lemma lt_modulus_2k : modulus < 2 ^ k. - Proof using Type. - replace (2 ^ k) with (modulus + c) by (unfold c; ring). - pose proof c_pos; omega. - Qed. Hint Resolve lt_modulus_2k. - - Lemma modulus_pos : 0 < modulus. - Proof using Type*. - pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega. - Qed. Hint Resolve modulus_pos. - - Lemma modulus_nonzero : Z.pos modulus <> 0. - Proof using Type*. - - pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. - Qed. - - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) - Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. - Proof using Type. - intros. - replace (2^k) with ((2^k - c) + c) by ring. - rewrite Z.mul_add_distr_r, Zplus_mod. - unfold c. - rewrite Z.sub_sub_distr, Z.sub_diag. - rewrite Z.mul_comm, Z.mod_add_l; auto using modulus_nonzero. - rewrite <- Zplus_mod; auto. - Qed. - - Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus. - Proof using Type. - intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega. - Qed. - - Lemma extended_shiftadd: forall (us : digits), - decode (ext_base limb_widths) us = - decode base (firstn (length base) us) - + (2 ^ k * decode base (skipn (length base) us)). - Proof using Type. - intros. - unfold decode; rewrite <- mul_each_rep. - rewrite ext_base_alt by apply limb_widths_nonneg. - fold k; fold (mul_each (2 ^ k) base). - rewrite base_mul_app. - rewrite <- mul_each_rep; auto. - Qed. - - Global Instance bv : BaseSystem.BaseVector base := { - base_positive := base_positive limb_widths_nonneg; - b0_1 := fun x => b0_1 x limb_widths_nonnil; - base_good := base_good limb_widths_nonneg limb_widths_good - }. - - Lemma nth_default_base_positive : forall i, (i < length base)%nat -> - nth_default 0 base i > 0. - Proof using Type. - intros. - pose proof (nth_error_length_exists_value _ _ H). - destruct H0. - pose proof (nth_error_value_In _ _ _ H0). - pose proof (BaseSystem.base_positive _ H1). - unfold nth_default. - rewrite H0; auto. - Qed. - - Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> - nth_default 0 base (S i) = nth_default 0 base i * - (nth_default 0 base (S i) / nth_default 0 base i). - Proof using Type. - intros. - apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply base_succ; distr_length; eauto using limb_widths_nonneg. - Qed. - -End PseudoMersenneBaseParamProofs. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v deleted file mode 100644 index 6f6fd6556..000000000 --- a/src/ModularArithmetic/PseudoMersenneBaseParams.v +++ /dev/null @@ -1,24 +0,0 @@ -Require Import Coq.ZArith.ZArith. -Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil. -Require Crypto.BaseSystem. -Local Open Scope Z_scope. - -Class PseudoMersenneBaseParams (modulus : positive) := { - limb_widths : list Z; - limb_widths_pos : forall w, In w limb_widths -> 0 < w; - limb_widths_nonnil : limb_widths <> nil; - limb_widths_good : forall i j, (i + j < length limb_widths)%nat -> - sum_firstn limb_widths (i + j) <= - sum_firstn limb_widths i + sum_firstn limb_widths j; - prime_modulus : Znumtheory.prime (Z.pos modulus); - k := sum_firstn limb_widths (length limb_widths); - c := 2 ^ k - (Z.pos modulus); - c_pos : 0 < c; - limb_widths_match_modulus : forall i j, - (i < length limb_widths)%nat -> - (j < length limb_widths)%nat -> - (i + j >= length limb_widths)%nat -> - let w_sum := sum_firstn limb_widths in - k + w_sum (i + j - length limb_widths)%nat <= w_sum i + w_sum j -}. diff --git a/src/Reflection/Z/ArithmeticSimplifier.v b/src/Reflection/Z/ArithmeticSimplifier.v index 4e5d6126e..821ef8459 100644 --- a/src/Reflection/Z/ArithmeticSimplifier.v +++ b/src/Reflection/Z/ArithmeticSimplifier.v @@ -175,9 +175,6 @@ Section language. | Lor _ _ _ as opc | OpConst _ _ as opc | Opp _ _ as opc - | Neg _ _ _ as opc - | Cmovne _ _ _ _ _ as opc - | Cmovle _ _ _ _ _ as opc => Op opc end. End with_var. diff --git a/src/Reflection/Z/Bounds/Interpretation.v b/src/Reflection/Z/Bounds/Interpretation.v index 3d6d65c98..69670bee0 100644 --- a/src/Reflection/Z/Bounds/Interpretation.v +++ b/src/Reflection/Z/Bounds/Interpretation.v @@ -140,9 +140,6 @@ Module Import Bounds. | Land _ _ T => fun xy => land (bit_width_of_base_type T) (fst xy) (snd xy) | Lor _ _ T => fun xy => lor (bit_width_of_base_type T) (fst xy) (snd xy) | Opp _ T => fun x => opp (bit_width_of_base_type T) x - | Neg _ T int_width => fun x => neg (bit_width_of_base_type T) int_width x - | Cmovne _ _ _ _ T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne (bit_width_of_base_type T) x y z w - | Cmovle _ _ _ _ T => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle (bit_width_of_base_type T) x y z w end%bounds. Definition of_Z (z : Z) : t := ZToZRange z. diff --git a/src/Reflection/Z/Bounds/InterpretationLemmas.v b/src/Reflection/Z/Bounds/InterpretationLemmas.v index 0c7791a2f..11a6ea91e 100644 --- a/src/Reflection/Z/Bounds/InterpretationLemmas.v +++ b/src/Reflection/Z/Bounds/InterpretationLemmas.v @@ -241,20 +241,6 @@ Proof. | word_arith_t ]. Qed. -Local Ltac t_special_case_op_step := - first [ fin_t - | progress intros - | progress subst - | progress simpl in * - | progress split_andb - | progress Zarith_t_step - | specializer_t_step - | rewriter_t - | progress break_t_step - | progress split_min_max - | progress cbv [Bounds.neg' Bounds.cmovne' Bounds.cmovle' ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovne ModularBaseSystemListZOperations.cmovl] ]. -Local Ltac t_special_case_op := repeat t_special_case_op_step. - Local Arguments Z.pow : simpl never. Local Arguments Z.add !_ !_. Local Existing Instances Z.add_le_Proper Z.log2_up_le_Proper Z.pow_Zpos_le_Proper Z.sub_le_eq_Proper. @@ -310,9 +296,6 @@ Proof. | progress simpl in * | progress split_min_max | omega ]. } - { t_special_case_op. } - { t_special_case_op. } - { t_special_case_op. } Admitted. Local Arguments lift_op : simpl never. diff --git a/src/Reflection/Z/Reify.v b/src/Reflection/Z/Reify.v index 0573501b7..439a1df8c 100644 --- a/src/Reflection/Z/Reify.v +++ b/src/Reflection/Z/Reify.v @@ -1,5 +1,4 @@ Require Import Coq.ZArith.ZArith. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. Require Import Crypto.Reflection.InputSyntax. Require Import Crypto.Reflection.Z.Syntax. Require Import Crypto.Reflection.Z.Syntax.Equality. @@ -23,14 +22,6 @@ Ltac base_reify_op op op_head extra ::= | @Z.land => constr:(reify_op op op_head 2 (Land TZ TZ TZ)) | @Z.lor => constr:(reify_op op op_head 2 (Lor TZ TZ TZ)) | @Z.opp => constr:(reify_op op op_head 1 (Opp TZ TZ)) - | @ModularBaseSystemListZOperations.cmovne => constr:(reify_op op op_head 4 (Cmovne TZ TZ TZ TZ TZ)) - | @ModularBaseSystemListZOperations.cmovl => constr:(reify_op op op_head 4 (Cmovle TZ TZ TZ TZ TZ)) - | @ModularBaseSystemListZOperations.neg - => lazymatch extra with - | @ModularBaseSystemListZOperations.neg ?int_width _ - => constr:(reify_op op op_head 1 (Neg TZ TZ int_width)) - | _ => fail 100 "Anomaly: In Reflection.Z.base_reify_op: head is neg but body is wrong:" extra - end end. Ltac base_reify_type T ::= lazymatch T with diff --git a/src/Reflection/Z/Syntax.v b/src/Reflection/Z/Syntax.v index 58c7de6e6..58c55bc06 100644 --- a/src/Reflection/Z/Syntax.v +++ b/src/Reflection/Z/Syntax.v @@ -4,7 +4,6 @@ Require Import Bedrock.Word. Require Import Crypto.Reflection.SmartMap. Require Import Crypto.Reflection.Syntax. Require Import Crypto.Reflection.TypeUtil. -Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations. Require Import Crypto.Util.FixedWordSizes. Require Import Crypto.Util.Option. Require Import Crypto.Util.NatUtil. (* for nat_beq for equality schemes *) @@ -27,9 +26,7 @@ Inductive op : flat_type base_type -> flat_type base_type -> Type := | Land T1 T2 Tout : op (Tbase T1 * Tbase T2) (Tbase Tout) | Lor T1 T2 Tout : op (Tbase T1 * Tbase T2) (Tbase Tout) | Opp T Tout : op (Tbase T) (Tbase Tout) -| Neg T Tout (int_width : Z) : op (Tbase T) (Tbase Tout) -| Cmovne T1 T2 T3 T4 Tout : op (Tbase T1 * Tbase T2 * Tbase T3 * Tbase T4) (Tbase Tout) -| Cmovle T1 T2 T3 T4 Tout : op (Tbase T1 * Tbase T2 * Tbase T3 * Tbase T4) (Tbase Tout). +. Definition interp_base_type (v : base_type) : Type := match v with @@ -81,9 +78,6 @@ Definition Zinterp_op src dst (f : op src dst) | Land _ _ _ => fun xy => Z.land (fst xy) (snd xy) | Lor _ _ _ => fun xy => Z.lor (fst xy) (snd xy) | Opp _ _ => fun x => Z.opp x - | Neg _ _ int_width => fun x => ModularBaseSystemListZOperations.neg int_width x - | Cmovne _ _ _ _ _ => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w - | Cmovle _ _ _ _ _ => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovl x y z w end%Z. Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst diff --git a/src/Reflection/Z/Syntax/Equality.v b/src/Reflection/Z/Syntax/Equality.v index 2862859b7..17822d7ec 100644 --- a/src/Reflection/Z/Syntax/Equality.v +++ b/src/Reflection/Z/Syntax/Equality.v @@ -41,11 +41,6 @@ Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : bool => base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq Tout Tout' | Opp Tin Tout, Opp Tin' Tout' => base_type_beq Tin Tin' && base_type_beq Tout Tout' - | Cmovne T1 T2 T3 T4 Tout, Cmovne T1' T2' T3' T4' Tout' - | Cmovle T1 T2 T3 T4 Tout, Cmovle T1' T2' T3' T4' Tout' - => base_type_beq T1 T1' && base_type_beq T2 T2' && base_type_beq T3 T3' && base_type_beq T4 T4' && base_type_beq Tout Tout' - | Neg Tin Tout n, Neg Tin' Tout' m - => base_type_beq Tin Tin' && base_type_beq Tout Tout' && Z.eqb n m | OpConst _ _, _ | Add _ _ _, _ | Sub _ _ _, _ @@ -55,9 +50,6 @@ Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : bool | Land _ _ _, _ | Lor _ _ _, _ | Opp _ _, _ - | Neg _ _ _, _ - | Cmovne _ _ _ _ _, _ - | Cmovle _ _ _ _ _, _ => false end%bool. diff --git a/src/Reflection/Z/Syntax/Util.v b/src/Reflection/Z/Syntax/Util.v index 72b08d6cf..b5862c72f 100644 --- a/src/Reflection/Z/Syntax/Util.v +++ b/src/Reflection/Z/Syntax/Util.v @@ -59,9 +59,6 @@ Definition genericize_op {var' src dst} (opc : op src dst) {f} | Land _ _ _ => fun _ _ => Land _ _ _ | Lor _ _ _ => fun _ _ => Lor _ _ _ | Opp _ _ => fun _ _ => Opp _ _ - | Neg _ _ int_width => fun _ _ => Neg _ _ int_width - | Cmovne _ _ _ _ _ => fun _ _ => Cmovne _ _ _ _ _ - | Cmovle _ _ _ _ _ => fun _ _ => Cmovle _ _ _ _ _ end. Lemma cast_const_id {t} v diff --git a/src/Specific/GF1305.v b/src/Specific/GF1305.v deleted file mode 100644 index 6ddc12ee5..000000000 --- a/src/Specific/GF1305.v +++ /dev/null @@ -1,404 +0,0 @@ -Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystem. -Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. -Require Import Crypto.Util.Tuple. -Require Import Coq.Lists.List Crypto.Util.ListUtil. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Algebra. -Import ListNotations. -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Local Open Scope Z. - -(* BEGIN precomputation. *) - -Definition modulus : positive := (2^130 - 5)%positive. -Lemma prime_modulus : prime modulus. Admitted. -Definition int_width := 32%Z. - -Instance params1305 : PseudoMersenneBaseParams modulus. - construct_params prime_modulus 5%nat 130. -Defined. - -Definition fe1305 := Eval compute in (tuple Z (length limb_widths)). - -Definition mul2modulus : fe1305 := - Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params1305)). - -Instance subCoeff : @SubtractionCoefficient modulus params1305. - apply Build_SubtractionCoefficient with (coeff := mul2modulus). - vm_decide. -Defined. - -Instance carryChain : CarryChain limb_widths. - apply Build_CarryChain with (carry_chain := ([0;1;2;3;4;0])%nat). - intros; - repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). - contradiction H. -Defined. - -Definition freezePreconditions1305 : FreezePreconditions int_width int_width. -Proof. - constructor; compute_preconditions. -Defined. -(* Wire format for [pack] and [unpack] *) -Definition wire_widths := Eval compute in (repeat 32 4 ++ 2 :: nil). - -Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). - -Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. -Proof. - intros. - repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). - contradiction H. -Qed. - -Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). -Proof. reflexivity. Qed. - -Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. - -(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending - finding the real, more optimal chains from previous work. *) -Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := - match p with - | xI p' => pow2Chain'' p' 1 0 - (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) - | xO p' => pow2Chain'' p' 0 (S acc_index) - (chain_acc ++ (pow2_index, pow2_index)::nil) - | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) - end. - -Fixpoint pow2Chain' p index := - match p with - | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) - | xO p' => pow2Chain' p' (S index) - | xH => repeat (0,0)%nat index - end. - -Definition pow2_chain p := - match p with - | xH => nil - | _ => pow2Chain' p 0 - end. - -Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). - -Instance inv_ec : ExponentiationChain (modulus - 2). - apply Build_ExponentiationChain with (chain := invChain). - reflexivity. -Defined. - -(* Note : use caution copying square root code to other primes. The (modulus / 4 + 1) chains are - for primes that are 3 mod 4; if the prime is 5 mod 8 then use (modulus / 8 + 1). *) -Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 4 + 1)). - -Instance sqrt_ec : ExponentiationChain (modulus / 4 + 1). - apply Build_ExponentiationChain with (chain := sqrtChain). - reflexivity. -Defined. - -Arguments chain {_ _ _} _. - -(* END precomputation *) - -(* Precompute k, c, zero, and one *) -Definition k_ := Eval compute in k. -Definition c_ := Eval compute in c. -Definition one_ := Eval compute in one. -Definition zero_ := Eval compute in zero. -Definition k_subst : k = k_ := eq_refl k_. -Definition c_subst : c = c_ := eq_refl c_. -Definition one_subst : one = one_ := eq_refl one_. -Definition zero_subst : zero = zero_ := eq_refl zero_. - -Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In. - -Definition app_5 {T} (f : fe1305) (P : fe1305 -> T) : T. -Proof. - cbv [fe1305] in *. - set (f0 := f). - repeat (let g := fresh "g" in destruct f as [f g]). - apply P. - apply f0. -Defined. - -Definition app_5_correct {T} f (P : fe1305 -> T) : app_5 f P = P f. -Proof. - intros. - cbv [fe1305] in *. - repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. - reflexivity. -Qed. - -Definition appify2 {T} (op : fe1305 -> fe1305 -> T) (f g : fe1305) := - app_5 f (fun f0 => (app_5 g (fun g0 => op f0 g0))). - -Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. -Proof. - intros. cbv [appify2]. - etransitivity; apply app_5_correct. -Qed. - -Definition add_sig (f g : fe1305) : - { fg : fe1305 | fg = add_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe1305). - cbv. - reflexivity. -Defined. - -Definition add (f g : fe1305) : fe1305 := - Eval cbv beta iota delta [proj1_sig add_sig] in - proj1_sig (add_sig f g). - -Definition add_correct (f g : fe1305) - : add f g = add_opt f g := - Eval cbv beta iota delta [proj1_sig add_sig] in - proj2_sig (add_sig f g). - -Definition sub_sig (f g : fe1305) : - { fg : fe1305 | fg = sub_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe1305). - cbv. - reflexivity. -Defined. - -Definition sub (f g : fe1305) : fe1305 := - Eval cbv beta iota delta [proj1_sig sub_sig] in - proj1_sig (sub_sig f g). - -Definition sub_correct (f g : fe1305) - : sub f g = sub_opt f g := - Eval cbv beta iota delta [proj1_sig sub_sig] in - proj2_sig (sub_sig f g). - -(* For multiplication, we add another layer of definition so that we can - rewrite under the [let] binders. *) -Definition mul_simpl_sig (f g : fe1305) : - { fg : fe1305 | fg = carry_mul_opt k_ c_ f g}. -Proof. - cbv [fe1305] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv. - autorewrite with zsimplify. - reflexivity. -Defined. - -Definition mul_simpl (f g : fe1305) : fe1305 := - Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in - proj1_sig (mul_simpl_sig f g). - -Definition mul_simpl_correct (f g : fe1305) - : mul_simpl f g = carry_mul_opt k_ c_ f g := - Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in - proj2_sig (mul_simpl_sig f g). - -Definition mul_sig (f g : fe1305) : - { fg : fe1305 | fg = carry_mul_opt k_ c_ f g}. -Proof. - eexists. - rewrite <-mul_simpl_correct. - rewrite <-(@appify2_correct fe1305). - cbv. - reflexivity. -Defined. - -Definition mul (f g : fe1305) : fe1305 := - Eval cbv beta iota delta [proj1_sig mul_sig] in - proj1_sig (mul_sig f g). - -Definition mul_correct (f g : fe1305) - : mul f g = carry_mul_opt k_ c_ f g := - Eval cbv beta iota delta [proj2_sig add_sig] in - proj2_sig (mul_sig f g). - -Definition opp_sig (f : fe1305) : - { g : fe1305 | g = opp_opt f }. -Proof. - eexists. - cbv [opp_opt]. - rewrite <-sub_correct. - rewrite zero_subst. - cbv [sub]. - reflexivity. -Defined. - -Definition opp (f : fe1305) : fe1305 - := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). - -Definition opp_correct (f : fe1305) - : opp f = opp_opt f - := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). - -Definition pow (f : fe1305) chain := fold_chain_opt one_ mul chain [f]. - -Lemma pow_correct (f : fe1305) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. -Proof. - cbv [pow pow_opt]; intros. - rewrite !fold_chain_opt_correct. - apply Proper_fold_chain; try reflexivity. - intros; subst; apply mul_correct. -Qed. - -Definition inv_sig (f : fe1305) : - { g : fe1305 | g = inv_opt k_ c_ one_ f }. -Proof. - eexists; cbv [inv_opt]. - rewrite <-pow_correct. - cbv - [mul]. - reflexivity. -Defined. - -Definition inv (f : fe1305) : fe1305 - := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). - -Definition inv_correct (f : fe1305) - : inv f = inv_opt k_ c_ one_ f - := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). - -Definition mbs_field := modular_base_system_field modulus_gt_2. - -Import Morphisms. - -Lemma field1305 : @field fe1305 eq zero one opp add sub mul inv div. -Proof. - pose proof (Equivalence_Reflexive : Reflexive eq). - eapply (Field.equivalent_operations_field (fieldR := mbs_field)). - Grab Existential Variables. - + reflexivity. - + reflexivity. - + reflexivity. - + intros; rewrite mul_correct. - rewrite carry_mul_opt_correct by auto using k_subst, c_subst. - cbv [eq]. - rewrite carry_mul_rep by reflexivity. - rewrite mul_rep; reflexivity. - + intros; rewrite sub_correct, sub_opt_correct; reflexivity. - + intros; rewrite add_correct, add_opt_correct; reflexivity. - + intros; rewrite inv_correct, inv_opt_correct; reflexivity. - + intros; rewrite opp_correct, opp_opt_correct; reflexivity. -Qed. - -Lemma homomorphism_F1305 : - @Ring.is_homomorphism - (F modulus) Logic.eq F.one F.add F.mul - fe1305 eq one add mul encode. -Proof. - econstructor. - + econstructor; [ | apply encode_Proper]. - intros; cbv [eq]. - rewrite add_correct, add_opt_correct, add_rep; apply encode_rep. - + intros; cbv [eq]. - rewrite mul_correct, carry_mul_opt_correct, carry_mul_rep - by auto using k_subst, c_subst, encode_rep. - apply encode_rep. - + reflexivity. -Qed. - -Definition pack_simpl_sig (f : fe1305) : - { f' | f' = pack_opt params1305 wire_widths_nonneg bits_eq f }. -Proof. - cbv [fe1305] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv [pack_opt]. - repeat ( - rewrite <-convert'_opt_correct; - cbv - [from_list_default_opt Conversion.convert']; - repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r). - cbv [from_list_default_opt]. - reflexivity. -Defined. - -Definition pack_simpl (f : fe1305) := - Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in - proj1_sig (pack_simpl_sig f). - -Definition pack_simpl_correct (f : fe1305) - : pack_simpl f = pack_opt params1305 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_simpl_sig] in proj2_sig (pack_simpl_sig f). - -Definition pack_sig (f : fe1305) : - { f' | f' = pack_opt params1305 wire_widths_nonneg bits_eq f }. -Proof. - eexists. - rewrite <-pack_simpl_correct. - rewrite <-(@app_5_correct wire_digits). - cbv. - reflexivity. -Defined. - -Definition pack (f : fe1305) : wire_digits := - Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). - -Definition pack_correct (f : fe1305) - : pack f = pack_opt params1305 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). - -Definition unpack_simpl_sig (f : wire_digits) : - { f' | f' = unpack_opt params1305 wire_widths_nonneg bits_eq f }. -Proof. - cbv [wire_digits] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv [unpack_opt]. - repeat ( - rewrite <-convert'_opt_correct; - cbv - [from_list_default_opt Conversion.convert']; - repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r). - cbv [from_list_default_opt]. - reflexivity. -Defined. - -Definition unpack_simpl (f : wire_digits) : fe1305 := - Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in - proj1_sig (unpack_simpl_sig f). - -Definition unpack_simpl_correct (f : wire_digits) - : unpack_simpl f = unpack_opt params1305 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig unpack_simpl_sig] in proj2_sig (unpack_simpl_sig f). - -Definition unpack_sig (f : wire_digits) : - { f' | f' = unpack_opt params1305 wire_widths_nonneg bits_eq f }. -Proof. - eexists. - rewrite <-unpack_simpl_correct. - rewrite <-(@app_5_correct fe1305). - cbv. - reflexivity. -Defined. - -Definition unpack (f : wire_digits) : fe1305 := - Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). - -Definition unpack_correct (f : wire_digits) - : unpack f = unpack_opt params1305 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). - -Definition sqrt_sig (f : fe1305) : - { g | g = sqrt_3mod4_opt k_ c_ one_ f}. -Proof. - eexists; cbv [sqrt_3mod4_opt]. - rewrite <-pow_correct. - cbv - [mul]. - reflexivity. -Defined. - -Definition sqrt (f : fe1305) : fe1305 := - Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). - -Definition sqrt_correct (f : fe1305) - : sqrt f = sqrt_3mod4_opt k_ c_ one_ f - := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). diff --git a/src/Specific/GF25519.v b/src/Specific/GF25519.v deleted file mode 100644 index 1a74de889..000000000 --- a/src/Specific/GF25519.v +++ /dev/null @@ -1,785 +0,0 @@ -Require Import Crypto.BaseSystem. -Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. -Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystem. -Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. -Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. -Require Import Crypto.Util.Tuple. -Require Import Coq.Lists.List Crypto.Util.ListUtil. -Require Import Crypto.Tactics.VerdiTactics. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Util.Tactics.SetEvars. -Require Import Crypto.Util.Tactics.SubstEvars. -Require Import Crypto.Util.Tactics.DestructHead. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Tower. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.Decidable. -Require Import Crypto.Algebra. -Require Crypto.Spec.Ed25519. -Import ListNotations. -Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. -Local Open Scope Z. - -(* BEGIN precomputation. *) - -Definition modulus : positive := Eval compute in (2^255 - 19)%positive. -Definition prime_modulus : prime modulus := Crypto.Spec.Ed25519.prime_q. -Definition int_width := 64%Z. -Definition freeze_input_bound := 32%Z. - -Instance params25519 : PseudoMersenneBaseParams modulus. - construct_params prime_modulus 10%nat 255. -Defined. - -Definition length_fe25519 := Eval compute in length limb_widths. -Definition fe25519 := Eval compute in (tuple Z length_fe25519). - -Definition mul2modulus : fe25519 := - Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params25519)). - -Instance subCoeff : SubtractionCoefficient. - apply Build_SubtractionCoefficient with (coeff := mul2modulus). - vm_decide. -Defined. - -Instance carryChain : CarryChain limb_widths. - apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;0;1])%nat). - intros. - repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). - contradiction H. -Defined. - -Definition freezePreconditions25519 : FreezePreconditions freeze_input_bound int_width. -Proof. - constructor; compute_preconditions. -Defined. - -(* Wire format for [pack] and [unpack] *) -Definition wire_widths := Eval compute in (repeat 32 7 ++ 31 :: nil). - -Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). - -Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. -Proof. - intros. - repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). - contradiction H. -Qed. - -Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). -Proof. - reflexivity. -Qed. - -Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. - -(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending - finding the real, more optimal chains from previous work. *) -Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := - match p with - | xI p' => pow2Chain'' p' 1 0 - (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) - | xO p' => pow2Chain'' p' 0 (S acc_index) - (chain_acc ++ (pow2_index, pow2_index)::nil) - | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) - end. - -Fixpoint pow2Chain' p index := - match p with - | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) - | xO p' => pow2Chain' p' (S index) - | xH => repeat (0,0)%nat index - end. - -Definition pow2_chain p := - match p with - | xH => nil - | _ => pow2Chain' p 0 - end. - -(* From Daniel Bernstein's "ref" implementation (Public Domain) *) -Definition invChain := [(0, 0); (0, 0); (0, 0); (0, 3); (0, 3); (0, 0); (0, 2); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 5); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 10); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 20); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 42); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 50); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 100); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 202); (0, 0); (0, 0); (0, 0); (0, 0); (0, 0); (0, 259)]%nat. - -Instance inv_ec : ExponentiationChain (modulus - 2). - apply Build_ExponentiationChain with (chain := invChain); vm_decide_no_check. -Defined. - -(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are - for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) -Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 8 + 1)). - -Instance sqrt_ec : ExponentiationChain (modulus / 8 + 1). - apply Build_ExponentiationChain with (chain := sqrtChain). - reflexivity. -Defined. - -Arguments chain {_ _ _} _. - -(* END precomputation *) - -(* Precompute constants *) -Definition k_ := Eval compute in k. -Definition k_subst : k = k_ := eq_refl k_. - -Definition c_ := Eval compute in c. -Definition c_subst : c = c_ := eq_refl c_. - -Definition one_ := Eval compute in one. -Definition one_subst : one = one_ := eq_refl one_. - -Definition zero_ := Eval compute in zero. -Definition zero_subst : zero = zero_ := eq_refl zero_. - -Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. -Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. - -Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. - -Definition app_7 {T} (f : wire_digits) (P : wire_digits -> T) : T. -Proof. - cbv [wire_digits] in *. - set (f0 := f). - repeat (let g := fresh "g" in destruct f as [f g]). - apply P. - apply f0. -Defined. - -Definition app_7_correct {T} f (P : wire_digits -> T) : app_7 f P = P f. -Proof. - intros. - cbv [wire_digits] in *. - repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. - reflexivity. -Qed. - -Definition app_10 {T} (f : fe25519) (P : fe25519 -> T) : T. -Proof. - cbv [fe25519] in *. - set (f0 := f). - repeat (let g := fresh "g" in destruct f as [f g]). - apply P. - apply f0. -Defined. - -Definition app_10_correct {T} f (P : fe25519 -> T) : app_10 f P = P f. -Proof. - intros. - cbv [fe25519] in *. - repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. - reflexivity. -Qed. - -Definition appify2 {T} (op : fe25519 -> fe25519 -> T) (f g : fe25519) := - app_10 f (fun f0 => (app_10 g (fun g0 => op f0 g0))). - -Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. -Proof. - intros. cbv [appify2]. - etransitivity; apply app_10_correct. -Qed. - -Definition appify9 {T} (op : fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> T) (x0 x1 x2 x3 x4 x5 x6 x7 x8 : fe25519) := - app_10 x0 (fun x0' => - app_10 x1 (fun x1' => - app_10 x2 (fun x2' => - app_10 x3 (fun x3' => - app_10 x4 (fun x4' => - app_10 x5 (fun x5' => - app_10 x6 (fun x6' => - app_10 x7 (fun x7' => - app_10 x8 (fun x8' => - op x0' x1' x2' x3' x4' x5' x6' x7' x8'))))))))). - -Lemma appify9_correct : forall {T} op x0 x1 x2 x3 x4 x5 x6 x7 x8, - @appify9 T op x0 x1 x2 x3 x4 x5 x6 x7 x8 = op x0 x1 x2 x3 x4 x5 x6 x7 x8. -Proof. - intros. cbv [appify9]. - repeat (etransitivity; [ apply app_10_correct | ]); reflexivity. -Qed. - -Definition uncurry_unop_fe25519 {T} (op : fe25519 -> T) - := Eval compute in Tuple.uncurry (n:=length_fe25519) op. -Definition curry_unop_fe25519 {T} op : fe25519 -> T - := Eval compute in fun f => app_10 f (Tuple.curry (n:=length_fe25519) op). - -Fixpoint uncurry_n_op_fe25519 {T} n - : forall (op : Tower.tower_nd fe25519 T n), - Tower.tower_nd Z T (n * length_fe25519) - := match n - return (forall (op : Tower.tower_nd fe25519 T n), - Tower.tower_nd Z T (n * length_fe25519)) - with - | O => fun x => x - | S n' => fun f => uncurry_unop_fe25519 (fun x => @uncurry_n_op_fe25519 _ n' (f x)) - end. - -Definition uncurry_binop_fe25519 {T} (op : fe25519 -> fe25519 -> T) - := Eval compute in uncurry_n_op_fe25519 2 op. -Definition curry_binop_fe25519 {T} op : fe25519 -> fe25519 -> T - := Eval compute in appify2 (fun f => curry_unop_fe25519 (curry_unop_fe25519 op f)). - -Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) - := Eval compute in Tuple.uncurry (n:=length wire_widths) op. -Definition curry_unop_wire_digits {T} op : wire_digits -> T - := Eval compute in fun f => app_7 f (Tuple.curry (n:=length wire_widths) op). - -Definition uncurry_9op_fe25519 {T} (op : fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> T) - := Eval compute in uncurry_n_op_fe25519 9 op. -Definition curry_9op_fe25519 {T} op : fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> fe25519 -> T - := Eval compute in - appify9 (fun x0 x1 x2 x3 x4 x5 x6 x7 x8 - => curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 (curry_unop_fe25519 op x0) x1) x2) x3) x4) x5) x6) x7) x8). - -Definition add_sig (f g : fe25519) : - { fg : fe25519 | fg = add_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe25519). - cbv. - reflexivity. -Defined. - -Definition add (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig add_sig] in - proj1_sig (add_sig f g). - -Definition add_correct (f g : fe25519) - : add f g = add_opt f g := - Eval cbv beta iota delta [proj1_sig add_sig] in - proj2_sig (add_sig f g). - -Definition carry_add_sig (f g : fe25519) : - { fg : fe25519 | fg = carry_add_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe25519). - cbv. - autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) - autorewrite with zsimplify_Z_to_pos; cbv. - reflexivity. -Defined. - -Definition carry_add (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig carry_add_sig] in - proj1_sig (carry_add_sig f g). - -Definition carry_add_correct (f g : fe25519) - : carry_add f g = carry_add_opt f g := - Eval cbv beta iota delta [proj1_sig carry_add_sig] in - proj2_sig (carry_add_sig f g). - -Definition sub_sig (f g : fe25519) : - { fg : fe25519 | fg = sub_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe25519). - cbv. - reflexivity. -Defined. - -Definition sub (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig sub_sig] in - proj1_sig (sub_sig f g). - -Definition sub_correct (f g : fe25519) - : sub f g = sub_opt f g := - Eval cbv beta iota delta [proj1_sig sub_sig] in - proj2_sig (sub_sig f g). - -Definition carry_sub_sig (f g : fe25519) : - { fg : fe25519 | fg = carry_sub_opt f g}. -Proof. - eexists. - rewrite <-(@appify2_correct fe25519). - cbv. - autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) - autorewrite with zsimplify_Z_to_pos; cbv. - reflexivity. -Defined. - -Definition carry_sub (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig carry_sub_sig] in - proj1_sig (carry_sub_sig f g). - -Definition carry_sub_correct (f g : fe25519) - : carry_sub f g = carry_sub_opt f g := - Eval cbv beta iota delta [proj1_sig carry_sub_sig] in - proj2_sig (carry_sub_sig f g). - -(* For multiplication, we add another layer of definition so that we can - rewrite under the [let] binders. *) -Definition mul_simpl_sig (f g : fe25519) : - { fg : fe25519 | fg = carry_mul_opt k_ c_ f g}. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. - It would be much faster if we could take advantage of - the form of [base_from_limb_widths] when doing - division, so we could do subtraction instead. *) - autorewrite with zsimplify_fast. - reflexivity. -Defined. - -Definition mul_simpl (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in - proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) - (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). - -Definition mul_simpl_correct (f g : fe25519) - : mul_simpl f g = carry_mul_opt k_ c_ f g. -Proof. - pose proof (proj2_sig (mul_simpl_sig f g)). - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Qed. - -Definition mul_sig (f g : fe25519) : - { fg : fe25519 | fg = carry_mul_opt k_ c_ f g}. -Proof. - eexists. - rewrite <-mul_simpl_correct. - rewrite <-(@appify2_correct fe25519). - cbv. - autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. (* FIXME: The speed of this rewrite depends on the fact that we have 10 limbs; there are some lemmas in [zsimplify_Z_to_pos] which are specific to 10. *) - autorewrite with zsimplify_Z_to_pos; cbv. - reflexivity. -Defined. - -Definition mul (f g : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig mul_sig] in - proj1_sig (mul_sig f g). - -Definition mul_correct (f g : fe25519) - : mul f g = carry_mul_opt k_ c_ f g := - Eval cbv beta iota delta [proj1_sig add_sig] in - proj2_sig (mul_sig f g). - -Definition opp_sig (f : fe25519) : - { g : fe25519 | g = opp_opt f }. -Proof. - eexists. - cbv [opp_opt]. - rewrite <-sub_correct. - rewrite zero_subst. - cbv [sub]. - reflexivity. -Defined. - -Definition opp (f : fe25519) : fe25519 - := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). - -Definition opp_correct (f : fe25519) - : opp f = opp_opt f - := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). - -Definition carry_opp_sig (f : fe25519) : - { g : fe25519 | g = carry_opp_opt f }. -Proof. - eexists. - cbv [carry_opp_opt]. - rewrite <-carry_sub_correct. - rewrite zero_subst. - cbv [carry_sub]. - reflexivity. -Defined. - -Definition carry_opp (f : fe25519) : fe25519 - := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). - -Definition carry_opp_correct (f : fe25519) - : carry_opp f = carry_opp_opt f - := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). - -Definition pow (f : fe25519) chain := fold_chain_opt one_ mul chain [f]. - -Lemma pow_correct (f : fe25519) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. -Proof. - cbv [pow pow_opt]; intros. - rewrite !fold_chain_opt_correct. - apply Proper_fold_chain; try reflexivity. - intros; subst; apply mul_correct. -Qed. - -(* Now that we have [pow], we can compute sqrt of -1 for use - in sqrt function (this is not needed unless the prime is - 5 mod 8) *) -Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. - -Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). - -Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). -Proof. - cbv [rep]. - apply F.eq_to_Z_iff. - vm_compute. - reflexivity. -Qed. - -Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. - -Definition inv_sig (f : fe25519) : - { g : fe25519 | g = inv_opt k_ c_ one_ f }. -Proof. - eexists; cbv [inv_opt]. - rewrite <-pow_correct. - cbv - [mul]. - reflexivity. -Defined. - -Definition inv (f : fe25519) : fe25519 - := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). - -Definition inv_correct (f : fe25519) - : inv f = inv_opt k_ c_ one_ f - := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). - -Definition mbs_field := modular_base_system_field modulus_gt_2. - -Import Morphisms. - -Local Existing Instance prime_modulus. - -Lemma field25519_and_homomorphisms - : @field fe25519 eq zero_ one_ opp add sub mul inv div - /\ @Ring.is_homomorphism - (F modulus) Logic.eq F.one F.add F.mul - fe25519 eq one_ add mul encode - /\ @Ring.is_homomorphism - fe25519 eq one_ add mul - (F modulus) Logic.eq F.one F.add F.mul - decode. -Proof. - eapply @Field.field_and_homomorphism_from_redundant_representation. - { exact (F.field_modulo _). } - { apply encode_rep. } - { reflexivity. } - { reflexivity. } - { reflexivity. } - { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } - { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } - { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } - { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } - { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } - { intros; apply encode_rep. } -Qed. - -Definition field25519 : @field fe25519 eq zero_ one_ opp add sub mul inv div := proj1 field25519_and_homomorphisms. - -Lemma carry_field25519_and_homomorphisms - : @field fe25519 eq zero_ one_ carry_opp carry_add carry_sub mul inv div - /\ @Ring.is_homomorphism - (F modulus) Logic.eq F.one F.add F.mul - fe25519 eq one_ carry_add mul encode - /\ @Ring.is_homomorphism - fe25519 eq one_ carry_add mul - (F modulus) Logic.eq F.one F.add F.mul - decode. -Proof. - eapply @Field.field_and_homomorphism_from_redundant_representation. - { exact (F.field_modulo _). } - { apply encode_rep. } - { reflexivity. } - { reflexivity. } - { reflexivity. } - { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } - { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } - { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } - { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } - { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } - { intros; apply encode_rep. } -Qed. - -Definition carry_field25519 : @field fe25519 eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field25519_and_homomorphisms. - -Lemma homomorphism_F25519_encode - : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe25519 eq one add mul encode. -Proof. apply field25519_and_homomorphisms. Qed. - -Lemma homomorphism_F25519_decode - : @Ring.is_homomorphism fe25519 eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. -Proof. apply field25519_and_homomorphisms. Qed. - - -Lemma homomorphism_carry_F25519_encode - : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe25519 eq one carry_add mul encode. -Proof. apply carry_field25519_and_homomorphisms. Qed. - -Lemma homomorphism_carry_F25519_decode - : @Ring.is_homomorphism fe25519 eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. -Proof. apply carry_field25519_and_homomorphisms. Qed. - -Definition ge_modulus_sig (f : fe25519) : - { b : Z | b = ge_modulus_opt (to_list 10 f) }. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists; cbv [ge_modulus_opt]. - rewrite !modulus_digits_subst. - cbv. - reflexivity. -Defined. - -Definition ge_modulus (f : fe25519) : Z := - Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). - -Definition ge_modulus_correct (f : fe25519) : - ge_modulus f = ge_modulus_opt (to_list 10 f). -Proof. - pose proof (proj2_sig (ge_modulus_sig f)). - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Defined. - -Definition prefreeze_sig (f : fe25519) : - { f' : fe25519 | f' = from_list_default 0 10 (carry_full_3_opt c_ (to_list 10 f)) }. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv - [from_list_default]. - (* TODO(jgross,jadep): use Reflective linearization here? *) - repeat ( - set_evars; rewrite app_Let_In_nd; subst_evars; - eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). - cbv [from_list_default from_list_default']. - reflexivity. -Defined. - -Definition prefreeze (f : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig prefreeze_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - proj1_sig (prefreeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). - -Definition prefreeze_correct (f : fe25519) - : prefreeze f = from_list_default 0 10 (carry_full_3_opt c_ (to_list 10 f)). -Proof. - pose proof (proj2_sig (prefreeze_sig f)). - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Defined. - -Definition postfreeze_sig (f : fe25519) : - { f' : fe25519 | f' = from_list_default 0 10 (conditional_subtract_modulus_opt (int_width := int_width) (to_list 10 f)) }. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists; cbv [freeze_opt int_width]. - cbv [to_list to_list']. - cbv [conditional_subtract_modulus_opt]. - rewrite !modulus_digits_subst. - cbv - [from_list_default]. - (* TODO(jgross,jadep): use Reflective linearization here? *) - repeat ( - set_evars; rewrite app_Let_In_nd; subst_evars; - eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). - cbv [from_list_default from_list_default']. - reflexivity. -Defined. - -Definition postfreeze (f : fe25519) : fe25519 := - Eval cbv beta iota delta [proj1_sig postfreeze_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - proj1_sig (postfreeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). - -Definition postfreeze_correct (f : fe25519) - : postfreeze f = from_list_default 0 10 (conditional_subtract_modulus_opt (int_width := int_width) (to_list 10 f)). -Proof. - pose proof (proj2_sig (postfreeze_sig f)). - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Defined. - -Definition freeze (f : fe25519) : fe25519 := - dlet x := prefreeze f in - postfreeze x. - -Local Transparent Let_In. -Definition freeze_correct (f : fe25519) - : freeze f = from_list_default 0 10 (freeze_opt (int_width := int_width) c_ (to_list 10 f)). -Proof. - cbv [freeze_opt freeze Let_In]. - rewrite prefreeze_correct. - rewrite postfreeze_correct. - match goal with - |- appcontext [to_list _ (from_list_default _ ?n ?xs)] => - assert (length xs = n) as pf; [ | rewrite from_list_default_eq with (pf0 := pf) ] end. - { rewrite carry_full_3_opt_correct; repeat rewrite ModularBaseSystemListProofs.length_carry_full; auto using length_to_list. } - rewrite to_list_from_list. - reflexivity. -Qed. -Local Opaque Let_In. - -Definition fieldwiseb_sig (f g : fe25519) : - { b | b = @fieldwiseb Z Z 10 Z.eqb f g }. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv. - reflexivity. -Defined. - -Definition fieldwiseb (f g : fe25519) : bool - := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in - proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) - (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). - -Lemma fieldwiseb_correct (f g : fe25519) - : fieldwiseb f g = @Tuple.fieldwiseb Z Z 10 Z.eqb f g. -Proof. - set (f' := f); set (g' := g). - hnf in f, g; destruct_head' prod. - exact (proj2_sig (fieldwiseb_sig f' g')). -Qed. - -Definition eqb_sig (f g : fe25519) : - { b | b = eqb int_width f g }. -Proof. - cbv [eqb]. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv [ModularBaseSystem.freeze int_width]. - rewrite <-!from_list_default_eq with (d := 0). - rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. - rewrite <-!freeze_correct. - rewrite <-fieldwiseb_correct. - reflexivity. -Defined. - -Definition eqb (f g : fe25519) : bool - := Eval cbv beta iota delta [proj1_sig eqb_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in - proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) - (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). - -Lemma eqb_correct (f g : fe25519) - : eqb f g = ModularBaseSystem.eqb int_width f g. -Proof. - set (f' := f); set (g' := g). - hnf in f, g; destruct_head' prod. - exact (proj2_sig (eqb_sig f' g')). -Qed. - -Definition sqrt_sig (powf powf_squared f : fe25519) : - { f' : fe25519 | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powf powf_squared f}. -Proof. - eexists. - cbv [sqrt_5mod8_opt int_width]. - apply Proper_Let_In_nd_changebody; [reflexivity|intro]. - set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. - reflexivity. -Defined. - -Definition sqrt (powf powf_squared f : fe25519) : fe25519 - := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powf powf_squared f). - -Definition sqrt_correct (powf powf_squared f : fe25519) - : sqrt powf powf_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powf powf_squared f - := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powf powf_squared f). - -Definition pack_simpl_sig (f : fe25519) : - { f' | f' = pack_opt params25519 wire_widths_nonneg bits_eq f }. -Proof. - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv [pack_opt]. - repeat (rewrite <-convert'_opt_correct; - cbv - [from_list_default_opt Conversion.convert']). - repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. - cbv [from_list_default_opt]. - reflexivity. -Defined. - -Definition pack_simpl (f : fe25519) := - Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in - proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). - -Definition pack_simpl_correct (f : fe25519) - : pack_simpl f = pack_opt params25519 wire_widths_nonneg bits_eq f. -Proof. - pose proof (proj2_sig (pack_simpl_sig f)). - cbv [fe25519] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Qed. - -Definition pack_sig (f : fe25519) : - { f' | f' = pack_opt params25519 wire_widths_nonneg bits_eq f }. -Proof. - eexists. - rewrite <-pack_simpl_correct. - rewrite <-(@app_10_correct wire_digits). - cbv. - reflexivity. -Defined. - -Definition pack (f : fe25519) : wire_digits := - Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). - -Definition pack_correct (f : fe25519) - : pack f = pack_opt params25519 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). - -Definition unpack_simpl_sig (f : wire_digits) : - { f' | f' = unpack_opt params25519 wire_widths_nonneg bits_eq f }. -Proof. - cbv [wire_digits] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - eexists. - cbv [unpack_opt]. - repeat ( - rewrite <-convert'_opt_correct; - cbv - [from_list_default_opt Conversion.convert']). - repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. - cbv [from_list_default_opt]. - reflexivity. -Defined. - -Definition unpack_simpl (f : wire_digits) : fe25519 := - Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in - let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in - proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7)). - -Definition unpack_simpl_correct (f : wire_digits) - : unpack_simpl f = unpack_opt params25519 wire_widths_nonneg bits_eq f. -Proof. - pose proof (proj2_sig (unpack_simpl_sig f)). - cbv [wire_digits] in *. - repeat match goal with p : (_ * Z)%type |- _ => destruct p end. - assumption. -Qed. - -Definition unpack_sig (f : wire_digits) : - { f' | f' = unpack_opt params25519 wire_widths_nonneg bits_eq f }. -Proof. - eexists. - rewrite <-unpack_simpl_correct. - rewrite <-(@app_7_correct fe25519). - cbv. - reflexivity. -Defined. - -Definition unpack (f : wire_digits) : fe25519 := - Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). - -Definition unpack_correct (f : wire_digits) - : unpack f = unpack_opt params25519 wire_widths_nonneg bits_eq f - := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/Testbit.v b/src/Testbit.v deleted file mode 100644 index 1da2c33e0..000000000 --- a/src/Testbit.v +++ /dev/null @@ -1,81 +0,0 @@ -Require Import Coq.Lists.List. -Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. -Require Import Crypto.BaseSystem Crypto.BaseSystemProofs. -Require Import Crypto.ModularArithmetic.Pow2Base Crypto.ModularArithmetic.Pow2BaseProofs. -Require Import Coq.ZArith.ZArith Coq.ZArith.Zdiv. -Require Import Coq.omega.Omega Coq.Numbers.Natural.Peano.NPeano. -Require Import Coq.micromega.Psatz. -Require Import Crypto.Util.Tactics.UniquePose. -Require Coq.Arith.Arith. -Import Nat. -Local Open Scope Z. - -Section Testbit. - Context {width : Z} (limb_width_pos : 0 < width). - Context (limb_widths : list Z) (limb_widths_nonnil : limb_widths <> nil) - (limb_widths_uniform : forall w, In w limb_widths -> w = width). - Local Notation base := (base_from_limb_widths limb_widths). - - Definition testbit (us : list Z) (n : nat) := - Z.testbit (nth_default 0 us (n / (Z.to_nat width))) (Z.of_nat (n mod Z.to_nat width)%nat). - - Ltac zify_nat_hyp := - repeat match goal with - | H : ~ (_ < _)%nat |- _ => rewrite nlt_ge in H - | H : ~ (_ <= _)%nat |- _ => rewrite nle_gt in H - | H : ~ (_ > _)%nat |- _ => apply not_gt in H - | H : ~ (_ >= _)%nat |- _ => apply not_ge in H - | H : (_ < _)%nat |- _ => unique pose proof (proj1 (Nat2Z.inj_lt _ _) H) - | H : (_ <= _)%nat |- _ => unique pose proof (proj1 (Nat2Z.inj_le _ _) H) - | H : (_ > _)%nat |- _ => unique pose proof (proj1 (Nat2Z.inj_gt _ _) H) - | H : (_ >= _)%nat |- _ => unique pose proof (proj1 (Nat2Z.inj_ge _ _) H) - | H : ~ (_ = _ :> nat) |- _ => unique pose proof (fun x => H (Nat2Z.inj _ _ x)) - | H : (_ = _ :> nat) |- _ => unique pose proof (proj2 (Nat2Z.inj_iff _ _) H) - end. - - Lemma testbit_spec' : forall a b us, (0 <= b < width) -> - bounded limb_widths us -> (length us = length limb_widths)%nat -> - Z.testbit (nth_default 0 us a) b = Z.testbit (decode base us) (Z.of_nat a * width + b). - Proof using limb_width_pos limb_widths_uniform. - repeat match goal with - | |- _ => progress intros - | |- _ => progress autorewrite with push_nth_default Ztestbit zsimplify in * - | |- _ => progress change (Z.of_nat 0) with 0 in * - | [ H : In ?x ?ls, H' : forall x', In x' ?ls -> x' = _ |- _ ] - => is_var x; apply H' in H - | |- _ => rewrite Nat2Z.inj_succ, Z.mul_succ_l - | |- _ => rewrite nth_default_out_of_bounds by omega - | |- _ => rewrite nth_default_uniform_base by omega - | |- false = Z.testbit (decode _ _) _ => rewrite testbit_decode_high - | |- _ => rewrite (@sum_firstn_uniform_base width) by (eassumption || omega) - | |- _ => rewrite sum_firstn_succ_default - | |- Z.testbit (nth_default _ _ ?x) _ = Z.testbit (decode _ _) _ => - destruct (lt_dec x (length limb_widths)); - [ erewrite testbit_decode_digit_select with (i := x); eauto | ] - | |- _ => reflexivity - | |- _ => assumption - | |- _ => zify_nat_hyp; omega - | |- ?a * ?b <= ?c * ?b + ?d => transitivity (c * b); [ | omega ] - | |- ?a * ?b <= ?c * ?b => apply Z.mul_le_mono_pos_r - | |- _ => solve [auto] - | |- _ => solve [eapply uniform_limb_widths_nonneg; eauto] - end. - Qed. - - Hint Rewrite div_add_l' mod_add_l mod_add_l' mod_div_eq0 add_0_r mod_mod : nat_mod_div. - - Lemma testbit_spec : forall n us, (length us = length limb_widths)%nat -> - bounded limb_widths us -> - testbit us n = Z.testbit (BaseSystem.decode base us) (Z.of_nat n). - Proof using limb_width_pos limb_widths_uniform. - cbv [testbit]; intros. - pose proof limb_width_pos as limb_width_pos_nat. - rewrite Z2Nat.inj_lt in limb_width_pos_nat by omega. - rewrite (Nat.div_mod n (Z.to_nat width)) by omega. - autorewrite with nat_mod_div; try omega. - rewrite testbit_spec' by (rewrite ?mod_Zmod, ?Z2Nat.id; try apply Z.mod_pos_bound; omega || auto). - f_equal. - rewrite Nat2Z.inj_add, Nat2Z.inj_mul, Z2Nat.id; ring || omega. - Qed. - -End Testbit. |