From b747ee7379b1529e8d356b3b6a2526ef7bfefff2 Mon Sep 17 00:00:00 2001 From: jadep Date: Tue, 23 Aug 2016 13:44:06 -0400 Subject: Defined real versions of [pow] and [inv] in ModularBaseSystem, replacing placeholders, and proved their correctness. In the process, reorganized early parts of ModularBaseSystemProofs.v by moving some lemmas to PseudoMersenneBaseParamProofs.v and introducing lemmas about the algebraic properties of ModularBaseSystem operations. --- src/ModularArithmetic/ModularBaseSystem.v | 9 +- src/ModularArithmetic/ModularBaseSystemProofs.v | 223 +++++++++++---------- .../PseudoMersenneBaseParamProofs.v | 70 +++++++ 3 files changed, 194 insertions(+), 108 deletions(-) (limited to 'src/ModularArithmetic') diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index e771e7eb4..2f264fa6c 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -11,6 +11,7 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystemList. Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AdditionChainExponentiation. Require Import Crypto.Util.Notations. Require Import Crypto.Tactics.VerdiTactics. Local Open Scope Z_scope. @@ -45,8 +46,12 @@ Section ModularBaseSystem. (* Placeholder *) Definition opp (x : digits) : digits := encode (F.opp (decode x)). - (* Placeholder *) - Definition inv (x : digits) : digits := encode (F.inv (decode x)). + Definition pow (x : digits) (chain : list (nat * nat)) : digits := + fold_chain one mul chain (x :: nil). + + Definition inv (chain : list (nat * nat)) + (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus - 2)) + (x : digits) : digits := pow x chain. (* Placeholder *) Definition div (x y : digits) : digits := encode (F.div (decode x) (decode y)). diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 115f04c92..4543cde2e 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -2,6 +2,7 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith. Require Import Coq.Numbers.Natural.Peano.NPeano. Require Import Coq.Lists.List. Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Algebra. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. @@ -14,6 +15,7 @@ Require Import Crypto.ModularArithmetic.ModularBaseSystemList. Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs. Require Import Crypto.ModularArithmetic.ModularBaseSystem. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.Util.AdditionChainExponentiation. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Notations. @@ -28,7 +30,7 @@ Class CarryChain (limb_widths : list Z) := carry_chain_valid : forall i, In i carry_chain -> (i < length limb_widths)%nat }. -Section PseudoMersenneProofs. +Section FieldOperationProofs. Context `{prm :PseudoMersenneBaseParams}. Local Arguments to_list {_ _} _. @@ -40,6 +42,7 @@ Section PseudoMersenneProofs. Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. Local Hint Resolve log_cap_nonneg. + Local Hint Resolve base_from_limb_widths_length. Local Notation base := (base_from_limb_widths limb_widths). Local Notation log_cap i := (nth_default 0 limb_widths i). @@ -55,39 +58,12 @@ Section PseudoMersenneProofs. cbv [rep]; auto. Qed. - Lemma lt_modulus_2k : modulus < 2 ^ k. - Proof. - replace (2 ^ k) with (modulus + c) by (unfold c; ring). - pose proof c_pos; omega. - Qed. Hint Resolve lt_modulus_2k. - - Lemma modulus_pos : 0 < modulus. - Proof. - pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega. - Qed. Hint Resolve modulus_pos. - - (** TODO(jadep, from jgross): The abstraction barrier of - [base]/[limb_widths] is repeatedly broken in the following - proofs. This lemma should almost never be needed, but removing - it breaks everything. (And using [distr_length] is too much of - a sledgehammer, and demolishes the abstraction barrier that's - currently merely in pieces.) *) - Lemma base_length : length base = length limb_widths. - Proof. distr_length. Qed. - - Lemma base_length_nonzero : length base <> 0%nat. - Proof. - distr_length. - pose proof limb_widths_nonnil. - destruct limb_widths; simpl in *; congruence. - Qed. - Lemma encode_eq : forall x : F modulus, ModularBaseSystemList.encode x = BaseSystem.encode base (F.to_Z x) (2 ^ k). Proof. cbv [ModularBaseSystemList.encode BaseSystem.encode encodeZ]; intros. - rewrite base_length. - apply encode'_spec; auto using Nat.eq_le_incl, base_length. + rewrite base_from_limb_widths_length. + apply encode'_spec; auto using Nat.eq_le_incl. Qed. Lemma encode_rep : forall x : F modulus, encode x ~= x. @@ -114,51 +90,47 @@ Section PseudoMersenneProofs. f_equal; assumption. Qed. - Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus. - Local Hint Extern 1 => apply limb_widths_match_modulus. - - Lemma modulus_nonzero : modulus <> 0. - pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. + Lemma eq_rep_iff : forall u v, (eq u v <-> u ~= decode v). + Proof. + reflexivity. Qed. - (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) - Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. + Lemma eq_dec : forall x y, Decidable.Decidable (eq x y). Proof. intros. - replace (2^k) with ((2^k - c) + c) by ring. - rewrite Z.mul_add_distr_r, Zplus_mod. - unfold c. - rewrite Z.sub_sub_distr, Z.sub_diag. - simpl. - rewrite Z.mul_comm, Z.mod_add_l; auto using modulus_nonzero. - rewrite <- Zplus_mod; auto. + destruct (F.eq_dec (decode x) (decode y)); [ left | right ]; congruence. Qed. - Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus. + Lemma modular_base_system_add_monoid : @monoid digits eq add zero. Proof. - intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega. + repeat match goal with + | |- _ => progress intro + | |- _ => cbv [zero]; rewrite encode_rep + | |- _ digits eq add => econstructor + | |- _ digits eq add _ => econstructor + | |- (_ + _)%F = decode (add ?a ?b) => rewrite (add_rep a b) by (try apply add_rep; reflexivity) + | |- eq _ _ => apply eq_rep_iff + | |- add _ _ ~= _ => apply add_rep + | |- decode (add _ _) = _ => apply add_rep + | |- add _ _ ~= decode _ => etransitivity + | x : digits |- ?x ~= _ => reflexivity + | |- _ => apply associative + | |- _ => apply left_identity + | |- _ => apply right_identity + | |- _ => solve [eauto using eq_Equivalence, eq_dec] + | |- _ => congruence + end. Qed. - Lemma extended_shiftadd: forall (us : BaseSystem.digits), - BaseSystem.decode (ext_base limb_widths) us = - BaseSystem.decode base (firstn (length base) us) - + (2^k * BaseSystem.decode base (skipn (length base) us)). - Proof. - intros. - unfold BaseSystem.decode; rewrite <- mul_each_rep. - rewrite ext_base_alt by auto. - fold k. - replace (map (Z.mul (2 ^ k)) base) with (BaseSystem.mul_each (2 ^ k) base) by auto. - rewrite base_mul_app. - rewrite <- mul_each_rep; auto. - Qed. + Local Hint Resolve firstn_us_base_ext_base bv ExtBaseVector limb_widths_match_modulus. + Local Hint Extern 1 => apply limb_widths_match_modulus. Lemma reduce_rep : forall us, BaseSystem.decode base (reduce us) mod modulus = BaseSystem.decode (ext_base limb_widths) us mod modulus. Proof. cbv [reduce]; intros. - rewrite extended_shiftadd, base_length, pseudomersenne_add, BaseSystemProofs.add_rep. + rewrite extended_shiftadd, base_from_limb_widths_length, pseudomersenne_add, BaseSystemProofs.add_rep. change (map (Z.mul c)) with (BaseSystem.mul_each c). rewrite mul_each_rep; auto. Qed. @@ -177,25 +149,25 @@ Section PseudoMersenneProofs. apply F.of_Z_mul. Qed. - Lemma nth_default_base_positive : forall i, (i < length base)%nat -> - nth_default 0 base i > 0. + Lemma modular_base_system_mul_monoid : @monoid digits eq mul one. Proof. - intros. - pose proof (nth_error_length_exists_value _ _ H). - destruct H0. - pose proof (nth_error_value_In _ _ _ H0). - pose proof (BaseSystem.base_positive _ H1). - unfold nth_default. - rewrite H0; auto. - Qed. - - Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> - nth_default 0 base (S i) = nth_default 0 base i * - (nth_default 0 base (S i) / nth_default 0 base i). - Proof. - intros. - apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply base_succ; distr_length; eauto. + repeat match goal with + | |- _ => progress intro + | |- _ => cbv [one]; rewrite encode_rep + | |- _ digits eq mul => econstructor + | |- _ digits eq mul _ => econstructor + | |- (_ * _)%F = decode (mul ?a ?b) => rewrite (mul_rep a b) by (try apply mul_rep; reflexivity) + | |- eq _ _ => apply eq_rep_iff + | |- mul _ _ ~= _ => apply mul_rep + | |- decode (mul _ _) = _ => apply mul_rep + | |- mul _ _ ~= decode _ => etransitivity + | x : digits |- ?x ~= _ => reflexivity + | |- _ => apply associative + | |- _ => apply left_identity + | |- _ => apply right_identity + | |- _ => solve [eauto using eq_Equivalence, eq_dec] + | |- _ => congruence + end. Qed. Lemma Fdecode_decode_mod : forall us x, @@ -206,31 +178,7 @@ Section PseudoMersenneProofs. apply F.to_Z_of_Z. Qed. - Definition carry_done us := forall i, (i < length base)%nat -> - 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. - - Lemma carry_done_bounds : forall us, (length us = length base) -> - (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). - Proof. - intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. - + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. - - specialize (Hcarry_done i i_lt). - split; [ intuition | ]. - destruct Hcarry_done as [Hnth_nonneg Hshiftr_0]. - apply Z.shiftr_eq_0_iff in Hshiftr_0. - destruct Hshiftr_0 as [nth_0 | [] ]; [ rewrite nth_0; zero_bounds | ]. - apply Z.log2_lt_pow2; auto. - - rewrite nth_default_out_of_bounds by omega. - split; zero_bounds. - + specialize (Hbounds i). - split; [ intuition | ]. - destruct Hbounds as [nth_nonneg nth_lt_pow2]. - apply Z.shiftr_eq_0_iff. - apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto]. - right; split; auto. - apply Z.log2_lt_pow2; auto. - Qed. - + Section Subtraction. Context (mm : digits) (mm_spec : decode mm = 0%F). Lemma sub_rep : forall u v x y, u ~= x -> v ~= y -> @@ -245,9 +193,45 @@ Section PseudoMersenneProofs. rewrite mm_spec. rewrite Algebra.left_identity. f_equal; assumption. Qed. + End Subtraction. + + Section PowInv. + Context (modulus_gt_2 : 2 < modulus). -End PseudoMersenneProofs. -Opaque encode add mul sub. + Lemma scalarmult_rep : forall u x n, u ~= x -> + (@ScalarMult.scalarmult_ref digits mul one n u) ~= (x ^ (N.of_nat n))%F. + Proof. + induction n; intros. + + cbv [N.to_nat ScalarMult.scalarmult_ref]. rewrite F.pow_0_r. + apply encode_rep. + + unfold ScalarMult.scalarmult_ref. + fold (@ScalarMult.scalarmult_ref digits mul one). + rewrite Nnat.Nat2N.inj_succ, <-N.add_1_l, F.pow_add_r, F.pow_1_r. + apply mul_rep; auto. + Qed. + + Lemma pow_rep : forall chain u x, u ~= x -> + pow u chain ~= F.pow x (fold_chain 0%N N.add chain (1%N :: nil)). + Proof. + cbv [pow rep]; intros. + erewrite (@fold_chain_exp _ _ _ _ modular_base_system_mul_monoid) + by (apply @ScalarMult.scalarmult_ref_is_scalarmult; apply modular_base_system_mul_monoid). + etransitivity; [ apply scalarmult_rep; eassumption | ]. + rewrite Nnat.N2Nat.id. + reflexivity. + Qed. + + Lemma inv_rep : forall chain pf u x, u ~= x -> + inv chain pf u ~= F.inv x. + Proof. + cbv [inv]; intros. + rewrite (@F.Fq_inv_fermat _ prime_modulus modulus_gt_2). + etransitivity; [ apply pow_rep; eassumption | ]. + congruence. + Qed. + End PowInv. +End FieldOperationProofs. +Opaque encode add mul sub inv pow. Section CarryProofs. Context `{prm : PseudoMersenneBaseParams}. @@ -255,13 +239,40 @@ Section CarryProofs. Local Notation log_cap i := (nth_default 0 limb_widths i). Local Notation "u ~= x" := (rep u x). Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. + Local Hint Resolve log_cap_nonneg. Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. - pose proof base_length_nonzero; omega. + pose proof limb_widths_nonnil; rewrite base_from_limb_widths_length. + destruct limb_widths; congruence || distr_length. Qed. Hint Resolve base_length_lt_pred. + Definition carry_done us := forall i, (i < length base)%nat -> + 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. + + Lemma carry_done_bounds : forall us, (length us = length base) -> + (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). + Proof. + intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. + + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. + - specialize (Hcarry_done i i_lt). + split; [ intuition | ]. + destruct Hcarry_done as [Hnth_nonneg Hshiftr_0]. + apply Z.shiftr_eq_0_iff in Hshiftr_0. + destruct Hshiftr_0 as [nth_0 | [] ]; [ rewrite nth_0; zero_bounds | ]. + apply Z.log2_lt_pow2; auto. + - rewrite nth_default_out_of_bounds by omega. + split; zero_bounds. + + specialize (Hbounds i). + split; [ intuition | ]. + destruct Hbounds as [nth_nonneg nth_lt_pow2]. + apply Z.shiftr_eq_0_iff. + apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto]. + right; split; auto. + apply Z.log2_lt_pow2; auto. + Qed. + Lemma carry_decode_eq_reduce : forall us, (length us = length limb_widths) -> BaseSystem.decode base (carry_and_reduce (pred (length limb_widths)) us) mod modulus diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v index 14482fe5e..4b3af84e1 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -1,9 +1,12 @@ Require Import Zpower ZArith. Require Import List. Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import VerdiTactics. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Crypto.BaseSystem. @@ -19,10 +22,77 @@ Section PseudoMersenneBaseParamProofs. Lemma k_nonneg : 0 <= k. Proof. apply sum_firstn_limb_widths_nonneg, limb_widths_nonneg. Qed. + Lemma lt_modulus_2k : modulus < 2 ^ k. + Proof. + replace (2 ^ k) with (modulus + c) by (unfold c; ring). + pose proof c_pos; omega. + Qed. Hint Resolve lt_modulus_2k. + + Lemma modulus_pos : 0 < modulus. + Proof. + pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega. + Qed. Hint Resolve modulus_pos. + + Lemma modulus_nonzero : modulus <> 0. + pose proof (Znumtheory.prime_ge_2 _ prime_modulus); omega. + Qed. + + (* a = r + s(2^k) = r + s(2^k - c + c) = r + s(2^k - c) + cs = r + cs *) + Lemma pseudomersenne_add: forall x y, (x + ((2^k) * y)) mod modulus = (x + (c * y)) mod modulus. + Proof. + intros. + replace (2^k) with ((2^k - c) + c) by ring. + rewrite Z.mul_add_distr_r, Zplus_mod. + unfold c. + rewrite Z.sub_sub_distr, Z.sub_diag. + simpl. + rewrite Z.mul_comm, Z.mod_add_l; auto using modulus_nonzero. + rewrite <- Zplus_mod; auto. + Qed. + + Lemma pseudomersenne_add': forall x y0 y1 z, (z - x + ((2^k) * y0 * y1)) mod modulus = (c * y0 * y1 - x + z) mod modulus. + Proof. + intros; rewrite <- !Z.add_opp_r, <- !Z.mul_assoc, pseudomersenne_add; apply f_equal2; omega. + Qed. + + Lemma extended_shiftadd: forall (us : digits), + decode (ext_base limb_widths) us = + decode base (firstn (length base) us) + + (2 ^ k * decode base (skipn (length base) us)). + Proof. + intros. + unfold decode; rewrite <- mul_each_rep. + rewrite ext_base_alt by apply limb_widths_nonneg. + fold k; fold (mul_each (2 ^ k) base). + rewrite base_mul_app. + rewrite <- mul_each_rep; auto. + Qed. + Global Instance bv : BaseSystem.BaseVector base := { base_positive := base_positive limb_widths_nonneg; b0_1 := fun x => b0_1 x limb_widths_nonnil; base_good := base_good limb_widths_nonneg limb_widths_good }. + Lemma nth_default_base_positive : forall i, (i < length base)%nat -> + nth_default 0 base i > 0. + Proof. + intros. + pose proof (nth_error_length_exists_value _ _ H). + destruct H0. + pose proof (nth_error_value_In _ _ _ H0). + pose proof (BaseSystem.base_positive _ H1). + unfold nth_default. + rewrite H0; auto. + Qed. + + Lemma base_succ_div_mult : forall i, ((S i) < length base)%nat -> + nth_default 0 base (S i) = nth_default 0 base i * + (nth_default 0 base (S i) / nth_default 0 base i). + Proof. + intros. + apply Z_div_exact_2; try (apply nth_default_base_positive; omega). + apply base_succ; distr_length; eauto using limb_widths_nonneg. + Qed. + End PseudoMersenneBaseParamProofs. -- cgit v1.2.3