aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v9
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v223
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v70
3 files changed, 194 insertions, 108 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index e771e7eb4..2f264fa6c 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -11,6 +11,7 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
Require Import Crypto.Util.Tuple.
+Require Import Crypto.Util.AdditionChainExponentiation.
Require Import Crypto.Util.Notations.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z_scope.
@@ -45,8 +46,12 @@ Section ModularBaseSystem.
(* Placeholder *)
Definition opp (x : digits) : digits := encode (F.opp (decode x)).
- (* Placeholder *)
- Definition inv (x : digits) : digits := encode (F.inv (decode x)).
+ Definition pow (x : digits) (chain : list (nat * nat)) : digits :=
+ fold_chain one mul chain (x :: nil).
+
+ Definition inv (chain : list (nat * nat))
+ (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus - 2))
+ (x : digits) : digits := pow x chain.
(* Placeholder *)
Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)).
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index 115f04c92..4543cde2e 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -2,6 +2,7 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Lists.List.
Require Import Crypto.Tactics.VerdiTactics.
+Require Import Crypto.Algebra.
Require Import Crypto.BaseSystem.
Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
@@ -14,6 +15,7 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil.
+Require Import Crypto.Util.AdditionChainExponentiation.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.Tactics.
Require Import Crypto.Util.Notations.
@@ -28,7 +30,7 @@ Class CarryChain (limb_widths : list Z) :=
carry_chain_valid : forall i, In i carry_chain -> (i < length limb_widths)%nat
}.
-Section PseudoMersenneProofs.
+Section FieldOperationProofs.
Context `{prm :PseudoMersenneBaseParams}.
Local Arguments to_list {_ _} _.
@@ -40,6 +42,7 @@ Section PseudoMersenneProofs.
Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg.
Local Hint Resolve log_cap_nonneg.
+ Local Hint Resolve base_from_limb_widths_length.
Local Notation base := (base_from_limb_widths limb_widths).
Local Notation log_cap i := (nth_default 0 limb_widths i).
@@ -55,39 +58,12 @@ Section PseudoMersenneProofs.
cbv [rep]; auto.
Qed.
- Lemma lt_modulus_2k : modulus < 2 ^ k.
- Proof.
- replace (2 ^ k) with (modulus + c) by (unfold c; ring).
- pose proof c_pos; omega.
- Qed. Hint Resolve lt_modulus_2k.
-
- Lemma modulus_pos : 0 < modulus.
- Proof.
- pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega.
- Qed. Hint Resolve modulus_pos.
-
- (** TODO(jadep, from jgross): The abstraction barrier of
- [base]/[limb_widths] is repeatedly broken in the following
- proofs. This lemma should almost never be needed, but removing
- it breaks everything. (And using [distr_length] is too much of
- a sledgehammer, and demolishes the abstraction barrier that's
- currently merely in pieces.) *)
- Lemma base_length : length base = length limb_widths.
- Proof. distr_length. Qed.
-
- Lemma base_length_nonzero : length base <> 0%nat.
- Proof.
- distr_length.
- pose proof limb_widths_nonnil.
- destruct limb_widths; simpl in *; congruence.
- Qed.
-
Lemma encode_eq : forall x : F modulus,
ModularBaseSystemList.encode x = BaseSystem.encode base (F.to_Z x) (2 ^ k).
Proof.
cbv [ModularBaseSystemList.encode BaseSystem.encode encodeZ]; intros.
- rewrite base_length.
- apply encode'_spec; auto using Nat.eq_le_incl, base_length.
+ rewrite base_from_limb_widths_length.
+ apply encode'_spec; auto using Nat.eq_le_incl.
Qed.
Lemma encode_rep : forall x : F modulus, encode x ~= x.
@@ -114,51 +90,47 @@ Section PseudoMersenneProofs.
f_equal; assumption.
Qed.
- Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus.
- Local Hint Extern 1 => apply limb_widths_match_modulus.
-
- Lemma modulus_nonzero : modulus <> 0.
- pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
+ Lemma eq_rep_iff : forall u v, (eq u v <-> u ~= decode v).
+ Proof.
+ reflexivity.
Qed.
- (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
- Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
+ Lemma eq_dec : forall x y, Decidable.Decidable (eq x y).
Proof.
intros.
- replace (2^k) with ((2^k - c) + c) by ring.
- rewrite Z.mul_add_distr_r, Zplus_mod.
- unfold c.
- rewrite Z.sub_sub_distr, Z.sub_diag.
- simpl.
- rewrite Z.mul_comm, Z.mod_add_l; auto using modulus_nonzero.
- rewrite <- Zplus_mod; auto.
+ destruct (F.eq_dec (decode x) (decode y)); [ left | right ]; congruence.
Qed.
- Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus.
+ Lemma modular_base_system_add_monoid : @monoid digits eq add zero.
Proof.
- intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega.
+ repeat match goal with
+ | |- _ => progress intro
+ | |- _ => cbv [zero]; rewrite encode_rep
+ | |- _ digits eq add => econstructor
+ | |- _ digits eq add _ => econstructor
+ | |- (_ + _)%F = decode (add ?a ?b) => rewrite (add_rep a b) by (try apply add_rep; reflexivity)
+ | |- eq _ _ => apply eq_rep_iff
+ | |- add _ _ ~= _ => apply add_rep
+ | |- decode (add _ _) = _ => apply add_rep
+ | |- add _ _ ~= decode _ => etransitivity
+ | x : digits |- ?x ~= _ => reflexivity
+ | |- _ => apply associative
+ | |- _ => apply left_identity
+ | |- _ => apply right_identity
+ | |- _ => solve [eauto using eq_Equivalence, eq_dec]
+ | |- _ => congruence
+ end.
Qed.
- Lemma extended_shiftadd: forall (us : BaseSystem.digits),
- BaseSystem.decode (ext_base limb_widths) us =
- BaseSystem.decode base (firstn (length base) us)
- + (2^k * BaseSystem.decode base (skipn (length base) us)).
- Proof.
- intros.
- unfold BaseSystem.decode; rewrite <- mul_each_rep.
- rewrite ext_base_alt by auto.
- fold k.
- replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto.
- rewrite base_mul_app.
- rewrite <- mul_each_rep; auto.
- Qed.
+ Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus.
+ Local Hint Extern 1 => apply limb_widths_match_modulus.
Lemma reduce_rep : forall us,
BaseSystem.decode base (reduce us) mod modulus =
BaseSystem.decode (ext_base limb_widths) us mod modulus.
Proof.
cbv [reduce]; intros.
- rewrite extended_shiftadd, base_length, pseudomersenne_add, BaseSystemProofs.add_rep.
+ rewrite extended_shiftadd, base_from_limb_widths_length, pseudomersenne_add, BaseSystemProofs.add_rep.
change (map (Z.mul c)) with (BaseSystem.mul_each c).
rewrite mul_each_rep; auto.
Qed.
@@ -177,25 +149,25 @@ Section PseudoMersenneProofs.
apply F.of_Z_mul.
Qed.
- Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
- nth_default 0 base i > 0.
+ Lemma modular_base_system_mul_monoid : @monoid digits eq mul one.
Proof.
- intros.
- pose proof (nth_error_length_exists_value _ _ H).
- destruct H0.
- pose proof (nth_error_value_In _ _ _ H0).
- pose proof (BaseSystem.base_positive _ H1).
- unfold nth_default.
- rewrite H0; auto.
- Qed.
-
- Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
- nth_default 0 base (S i) = nth_default 0 base i *
- (nth_default 0 base (S i) / nth_default 0 base i).
- Proof.
- intros.
- apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
- apply base_succ; distr_length; eauto.
+ repeat match goal with
+ | |- _ => progress intro
+ | |- _ => cbv [one]; rewrite encode_rep
+ | |- _ digits eq mul => econstructor
+ | |- _ digits eq mul _ => econstructor
+ | |- (_ * _)%F = decode (mul ?a ?b) => rewrite (mul_rep a b) by (try apply mul_rep; reflexivity)
+ | |- eq _ _ => apply eq_rep_iff
+ | |- mul _ _ ~= _ => apply mul_rep
+ | |- decode (mul _ _) = _ => apply mul_rep
+ | |- mul _ _ ~= decode _ => etransitivity
+ | x : digits |- ?x ~= _ => reflexivity
+ | |- _ => apply associative
+ | |- _ => apply left_identity
+ | |- _ => apply right_identity
+ | |- _ => solve [eauto using eq_Equivalence, eq_dec]
+ | |- _ => congruence
+ end.
Qed.
Lemma Fdecode_decode_mod : forall us x,
@@ -206,31 +178,7 @@ Section PseudoMersenneProofs.
apply F.to_Z_of_Z.
Qed.
- Definition carry_done us := forall i, (i < length base)%nat ->
- 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0.
-
- Lemma carry_done_bounds : forall us, (length us = length base) ->
- (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i).
- Proof.
- intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ].
- + destruct (lt_dec i (length base)) as [i_lt | i_nlt].
- - specialize (Hcarry_done i i_lt).
- split; [ intuition | ].
- destruct Hcarry_done as [Hnth_nonneg Hshiftr_0].
- apply Z.shiftr_eq_0_iff in Hshiftr_0.
- destruct Hshiftr_0 as [nth_0 | [] ]; [ rewrite nth_0; zero_bounds | ].
- apply Z.log2_lt_pow2; auto.
- - rewrite nth_default_out_of_bounds by omega.
- split; zero_bounds.
- + specialize (Hbounds i).
- split; [ intuition | ].
- destruct Hbounds as [nth_nonneg nth_lt_pow2].
- apply Z.shiftr_eq_0_iff.
- apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto].
- right; split; auto.
- apply Z.log2_lt_pow2; auto.
- Qed.
-
+ Section Subtraction.
Context (mm : digits) (mm_spec : decode mm = 0%F).
Lemma sub_rep : forall u v x y, u ~= x -> v ~= y ->
@@ -245,9 +193,45 @@ Section PseudoMersenneProofs.
rewrite mm_spec. rewrite Algebra.left_identity.
f_equal; assumption.
Qed.
+ End Subtraction.
+
+ Section PowInv.
+ Context (modulus_gt_2 : 2 < modulus).
-End PseudoMersenneProofs.
-Opaque encode add mul sub.
+ Lemma scalarmult_rep : forall u x n, u ~= x ->
+ (@ScalarMult.scalarmult_ref digits mul one n u) ~= (x ^ (N.of_nat n))%F.
+ Proof.
+ induction n; intros.
+ + cbv [N.to_nat ScalarMult.scalarmult_ref]. rewrite F.pow_0_r.
+ apply encode_rep.
+ + unfold ScalarMult.scalarmult_ref.
+ fold (@ScalarMult.scalarmult_ref digits mul one).
+ rewrite Nnat.Nat2N.inj_succ, <-N.add_1_l, F.pow_add_r, F.pow_1_r.
+ apply mul_rep; auto.
+ Qed.
+
+ Lemma pow_rep : forall chain u x, u ~= x ->
+ pow u chain ~= F.pow x (fold_chain 0%N N.add chain (1%N :: nil)).
+ Proof.
+ cbv [pow rep]; intros.
+ erewrite (@fold_chain_exp _ _ _ _ modular_base_system_mul_monoid)
+ by (apply @ScalarMult.scalarmult_ref_is_scalarmult; apply modular_base_system_mul_monoid).
+ etransitivity; [ apply scalarmult_rep; eassumption | ].
+ rewrite Nnat.N2Nat.id.
+ reflexivity.
+ Qed.
+
+ Lemma inv_rep : forall chain pf u x, u ~= x ->
+ inv chain pf u ~= F.inv x.
+ Proof.
+ cbv [inv]; intros.
+ rewrite (@F.Fq_inv_fermat _ prime_modulus modulus_gt_2).
+ etransitivity; [ apply pow_rep; eassumption | ].
+ congruence.
+ Qed.
+ End PowInv.
+End FieldOperationProofs.
+Opaque encode add mul sub inv pow.
Section CarryProofs.
Context `{prm : PseudoMersenneBaseParams}.
@@ -255,13 +239,40 @@ Section CarryProofs.
Local Notation log_cap i := (nth_default 0 limb_widths i).
Local Notation "u ~= x" := (rep u x).
Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg.
+ Local Hint Resolve log_cap_nonneg.
Lemma base_length_lt_pred : (pred (length base) < length base)%nat.
Proof.
- pose proof base_length_nonzero; omega.
+ pose proof limb_widths_nonnil; rewrite base_from_limb_widths_length.
+ destruct limb_widths; congruence || distr_length.
Qed.
Hint Resolve base_length_lt_pred.
+ Definition carry_done us := forall i, (i < length base)%nat ->
+ 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0.
+
+ Lemma carry_done_bounds : forall us, (length us = length base) ->
+ (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i).
+ Proof.
+ intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ].
+ + destruct (lt_dec i (length base)) as [i_lt | i_nlt].
+ - specialize (Hcarry_done i i_lt).
+ split; [ intuition | ].
+ destruct Hcarry_done as [Hnth_nonneg Hshiftr_0].
+ apply Z.shiftr_eq_0_iff in Hshiftr_0.
+ destruct Hshiftr_0 as [nth_0 | [] ]; [ rewrite nth_0; zero_bounds | ].
+ apply Z.log2_lt_pow2; auto.
+ - rewrite nth_default_out_of_bounds by omega.
+ split; zero_bounds.
+ + specialize (Hbounds i).
+ split; [ intuition | ].
+ destruct Hbounds as [nth_nonneg nth_lt_pow2].
+ apply Z.shiftr_eq_0_iff.
+ apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto].
+ right; split; auto.
+ apply Z.log2_lt_pow2; auto.
+ Qed.
+
Lemma carry_decode_eq_reduce : forall us,
(length us = length limb_widths) ->
BaseSystem.decode base (carry_and_reduce (pred (length limb_widths)) us) mod modulus
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
index 14482fe5e..4b3af84e1 100644
--- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -1,9 +1,12 @@
Require Import Zpower ZArith.
Require Import List.
Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil.
+Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import VerdiTactics.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
+Require Import Crypto.BaseSystem.
+Require Import Crypto.BaseSystemProofs.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Crypto.BaseSystem.
@@ -19,10 +22,77 @@ Section PseudoMersenneBaseParamProofs.
Lemma k_nonneg : 0 <= k.
Proof. apply sum_firstn_limb_widths_nonneg, limb_widths_nonneg. Qed.
+ Lemma lt_modulus_2k : modulus < 2 ^ k.
+ Proof.
+ replace (2 ^ k) with (modulus + c) by (unfold c; ring).
+ pose proof c_pos; omega.
+ Qed. Hint Resolve lt_modulus_2k.
+
+ Lemma modulus_pos : 0 < modulus.
+ Proof.
+ pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega.
+ Qed. Hint Resolve modulus_pos.
+
+ Lemma modulus_nonzero : modulus <> 0.
+ pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega.
+ Qed.
+
+ (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *)
+ Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus.
+ Proof.
+ intros.
+ replace (2^k) with ((2^k - c) + c) by ring.
+ rewrite Z.mul_add_distr_r, Zplus_mod.
+ unfold c.
+ rewrite Z.sub_sub_distr, Z.sub_diag.
+ simpl.
+ rewrite Z.mul_comm, Z.mod_add_l; auto using modulus_nonzero.
+ rewrite <- Zplus_mod; auto.
+ Qed.
+
+ Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus.
+ Proof.
+ intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega.
+ Qed.
+
+ Lemma extended_shiftadd: forall (us : digits),
+ decode (ext_base limb_widths) us =
+ decode base (firstn (length base) us)
+ + (2 ^ k * decode base (skipn (length base) us)).
+ Proof.
+ intros.
+ unfold decode; rewrite <- mul_each_rep.
+ rewrite ext_base_alt by apply limb_widths_nonneg.
+ fold k; fold (mul_each (2 ^ k) base).
+ rewrite base_mul_app.
+ rewrite <- mul_each_rep; auto.
+ Qed.
+
Global Instance bv : BaseSystem.BaseVector base := {
base_positive := base_positive limb_widths_nonneg;
b0_1 := fun x => b0_1 x limb_widths_nonnil;
base_good := base_good limb_widths_nonneg limb_widths_good
}.
+ Lemma nth_default_base_positive : forall i, (i < length base)%nat ->
+ nth_default 0 base i > 0.
+ Proof.
+ intros.
+ pose proof (nth_error_length_exists_value _ _ H).
+ destruct H0.
+ pose proof (nth_error_value_In _ _ _ H0).
+ pose proof (BaseSystem.base_positive _ H1).
+ unfold nth_default.
+ rewrite H0; auto.
+ Qed.
+
+ Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat ->
+ nth_default 0 base (S i) = nth_default 0 base i *
+ (nth_default 0 base (S i) / nth_default 0 base i).
+ Proof.
+ intros.
+ apply Z_div_exact_2; try (apply nth_default_base_positive; omega).
+ apply base_succ; distr_length; eauto using limb_widths_nonneg.
+ Qed.
+
End PseudoMersenneBaseParamProofs.