From 3959bc9986391882b3b73acd25e0fba04cdebbd9 Mon Sep 17 00:00:00 2001 From: jadep Date: Sat, 17 Sep 2016 12:33:44 -0400 Subject: Partially flesh out [freeze] proofs; also parameterize [sqrt_5mod8] over implementations of [mul] and [pow] so bounds can be threaded through --- src/ModularArithmetic/ModularBaseSystem.v | 12 +- src/ModularArithmetic/ModularBaseSystemOpt.v | 29 +--- src/ModularArithmetic/ModularBaseSystemProofs.v | 214 ++++++++++++++++++++---- 3 files changed, 194 insertions(+), 61 deletions(-) (limited to 'src/ModularArithmetic') diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 71851ac5d..1769f86c4 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -77,15 +77,17 @@ Section ModularBaseSystem. (* Note : both of the following square root definitions will produce garbage output if the input is not square mod [modulus]. The caller should either provably only call them with square input, or test that the output squared is in fact equal to the input and case split. *) - Definition sqrt_5mod8 (chain : list (nat * nat)) - (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1)) - (sqrt_minus1 x : digits) : digits := - let b := pow x chain in if eqb (mul b b) x then b else mul sqrt_minus1 b. - Definition sqrt_3mod4 (chain : list (nat * nat)) (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 4 + 1)) (x : digits) : digits := pow x chain. + (* sqrt_5mod8 is parameterized over implementation of [mul] and [pow] because it relies on bounds-checking + for these two functions, which is much easier for simplified implementations than the more generalized + ones defined here. *) + Definition sqrt_5mod8 mul_ pow_ (chain : list (nat * nat)) + (chain_correct : fold_chain 0%N N.add chain (1%N :: nil) = Z.to_N (modulus / 8 + 1)) + (sqrt_minus1 x : digits) : digits := + let b := pow_ x chain in if eqb (mul_ b b) x then b else mul_ sqrt_minus1 b. Import Morphisms. Global Instance eq_Equivalence : Equivalence eq. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index eda2a584d..6a3a4f7c2 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -965,33 +965,10 @@ Section SquareRoots. Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)). Definition sqrt_5mod8_opt_sig (us : digits) : - { vs : digits | eq vs (sqrt_5mod8 chain chain_correct sqrt_m1 us)}. + { vs : digits | + eq vs (sqrt_5mod8 (carry_mul_opt k_ c_) (pow_opt k_ c_ one_) chain chain_correct sqrt_m1 us)}. Proof. eexists; cbv [sqrt_5mod8]. - etransitivity. - Focus 2. { - apply if_equiv. { - apply eqb_Proper; [ | reflexivity ]. - transitivity (carry_mul_opt k_ c_ (pow_opt k_ c_ one_ us chain) (pow_opt k_ c_ one_ us chain)); [ reflexivity | ]. - cbv [eq]. - rewrite carry_mul_opt_correct by eassumption. - rewrite carry_mul_rep by reflexivity. - rewrite mul_rep by reflexivity. - f_equal; apply pow_opt_correct; auto. - } { - apply pow_opt_correct; auto. - } { - match goal with |- eq _ (mul ?a (ModularBaseSystem.pow ?d ?e)) => - transitivity (carry_mul_opt k_ c_ a (pow_opt k_ c_ one_ us chain)); [ reflexivity | ] end. - cbv [eq]. - rewrite !mul_rep by reflexivity. - erewrite <-pow_opt_correct by eassumption. - rewrite <-carry_mul_rep by reflexivity. - erewrite <-carry_mul_opt_correct by eassumption. - reflexivity. - } - } Unfocus. - rewrite k_subst, c_subst, one_subst. let LHS := match goal with |- eq ?LHS ?RHS => LHS end in let RHS := match goal with |- eq ?LHS ?RHS => RHS end in let RHSf := match (eval pattern (pow_opt k_ c_ one_ us chain) in RHS) with ?RHSf _ => RHSf end in @@ -1003,7 +980,7 @@ Section SquareRoots. proj1_sig (sqrt_5mod8_opt_sig us). Definition sqrt_5mod8_opt_correct us - : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 chain chain_correct sqrt_m1 us) + : eq (sqrt_5mod8_opt us) (ModularBaseSystem.sqrt_5mod8 _ _ chain chain_correct sqrt_m1 us) := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig us). End SquareRoot5mod8. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 008a3bc6d..ae5db0fc2 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -177,6 +177,19 @@ Section FieldOperationProofs. pose proof prime_modulus; prime_bound. Qed. + Lemma encode_range : forall x, + 0 <= BaseSystem.decode base (to_list (encode x)) < modulus. + Proof. + cbv [encode]; intros. + rewrite to_list_from_list. + rewrite encode_eq. + rewrite BaseSystemProofs.encode_rep; auto using F.to_Z_range, modulus_pos, bv. + + pose proof (F.to_Z_range x modulus_pos). + replace (2 ^ k) with (modulus + c) by (cbv[c]; ring). + pose proof c_pos; omega. + + apply base_upper_bound_compatible; auto. + Qed. + Lemma add_rep : forall u v x y, u ~= x -> v ~= y -> add u v ~= (x+y)%F. Proof. @@ -521,6 +534,17 @@ Section CarryProofs. Qed. Hint Resolve carry_rep. + Lemma decode_mod_Fdecode : forall u, length u = length limb_widths -> + BaseSystem.decode base u mod modulus= F.to_Z (decode (from_list_default 0 _ u)). + Proof. + intros. + rewrite <-(to_list_from_list _ u) with (pf := H). + erewrite Fdecode_decode_mod by reflexivity. + rewrite to_list_from_list. + rewrite from_list_default_eq with (pf := H). + reflexivity. + Qed. + Lemma carry_sequence_rep : forall is us x, (forall i, In i is -> (i < length limb_widths)%nat) -> us ~= x -> forall pf, from_list _ (carry_sequence is (to_list _ us)) pf ~= x. @@ -849,26 +873,62 @@ Section CanonicalizationProofs. auto using length_carry_full, bound_after_second_loop. Qed. - (* TODO(jadep): - - Proof of correctness for [ge_modulus] and [cond_subtract_modulus] - - Proof of correctness for [freeze] - * freeze us = encode (decode us) - * decode us = x -> - canonicalized_BSToWord (freeze us)) = FToWord x + Lemma ge_modulus_spec : forall u, length u = length limb_widths -> + (ge_modulus u = false <-> 0 <= BaseSystem.decode base u < modulus). + Proof. + Admitted. - (where [canonicalized_BSToWord] uses bitwise operations to concatenate digits - in BaseSystem in canonical form, splitting along word capacities) - *) + Lemma conditional_subtract_modulus_spec : forall u cond, + length u = length limb_widths -> + BaseSystem.decode base (conditional_subtract_modulus u cond) = + BaseSystem.decode base u - (if cond then 1 else 0) * modulus. + Proof. + Admitted. + + Lemma conditional_subtract_modulus_preserves_bounded : forall u, + bounded limb_widths u -> + bounded limb_widths (conditional_subtract_modulus u (ge_modulus u)). + Proof. + Admitted. + + Lemma conditional_subtract_lt_modulus : forall u, + bounded limb_widths u -> + ge_modulus (conditional_subtract_modulus u (ge_modulus u)) = false. + Proof. + Admitted. (* bounded canonical -> freeze bounded -> freeze canonical *) - Import SetoidList. + Import SetoidList. (* TODO : move to Tuple *) Lemma fieldwise_to_list_iff : forall {T n} R (s t : tuple T n), - (fieldwise R s t <-> Forall2 R (to_list _ s) (to_list _ t)). - Admitted. + (fieldwise R s t <-> Forall2 R (to_list _ s) (to_list _ t)). + Proof. + induction n; split; intros. + + constructor. + + cbv [fieldwise]. auto. + + destruct n; cbv [tuple to_list fieldwise] in *. + - cbv [to_list']; auto. + - simpl in *. destruct s,t; cbv [fst snd] in *. + constructor; intuition auto. + apply IHn; auto. + + destruct n; cbv [tuple to_list fieldwise] in *. + - cbv [fieldwise']; auto. + cbv [to_list'] in *; inversion H; auto. + - simpl in *. destruct s,t; cbv [fst snd] in *. + inversion H; subst. + split; try assumption. + apply IHn; auto. + Qed. - (* convenience notation -- [bounded] for digits rather than lists *) + + Local Notation initial_bounds u := + (forall n : nat, + 0 <= to_list (length limb_widths) u [n] < + 2 ^ B - + (if PeanoNat.Nat.eq_dec n 0 + then 0 + else (2 ^ B) >> (limb_widths [Init.Nat.pred n]))). Local Notation minimal_rep u := ((bounded limb_widths (to_list (length limb_widths) u)) /\ (ge_modulus (to_list _ u) = false)). Import Morphisms. @@ -920,12 +980,11 @@ Section CanonicalizationProofs. Qed. Lemma minimal_rep_encode : forall x, minimal_rep (encode x). - Admitted. - - Lemma ge_modulus_spec : forall u, length u = length limb_widths -> - ge_modulus u = false -> - 0 <= BaseSystem.decode base u < modulus. - Admitted. + Proof. + split; intros; auto using bounded_encode. + apply ge_modulus_spec; auto using length_to_list. + apply encode_range. + Qed. Lemma encode_minimal_rep : forall u x, rep u x -> minimal_rep u -> fieldwise Logic.eq u (encode x). @@ -958,14 +1017,64 @@ Section CanonicalizationProofs. congruence. Qed. - Lemma minimal_rep_freeze : forall u, minimal_rep (freeze u). - Admitted. + Lemma minimal_rep_freeze : forall u, initial_bounds u -> + minimal_rep (freeze u). + Proof. + repeat match goal with + | |- _ => progress (cbv [freeze ModularBaseSystemList.freeze]) + | |- _ => progress intros + | |- minimal_rep _ => split + | |- _ => rewrite to_list_from_list + | |- _ => apply bound_after_third_loop + | |- _ => apply conditional_subtract_lt_modulus + | |- _ => apply conditional_subtract_modulus_preserves_bounded + | |- bounded _ (carry_full _) => apply bounded_iff + | |- _ => solve [auto using length_to_list] + end. + Qed. + + Lemma freeze_decode : forall u, + BaseSystem.decode base (to_list _ (freeze u)) mod modulus = + BaseSystem.decode base (to_list _ u) mod modulus. + Proof. + repeat match goal with + | |- _ => progress cbv [freeze ModularBaseSystemList.freeze] + | |- _ => progress intros + | |- _ => rewrite <-Z.add_opp_r, <-Z.mul_opp_l + | |- _ => rewrite Z.mod_add by (pose proof prime_modulus; prime_bound) + | |- _ => rewrite to_list_from_list + | |- _ => rewrite conditional_subtract_modulus_spec by + auto using length_carry_full, length_to_list + end. + rewrite !decode_mod_Fdecode by auto using length_carry_full, length_to_list. + cbv [carry_full]. + apply F.eq_to_Z_iff. + rewrite <-@to_list_from_list with (pf := length_carry_sequence (length_carry_sequence (length_to_list _))). + rewrite from_list_default_eq with (pf := length_carry_sequence (length_to_list _)). + rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. + cbv [rep]. + rewrite <-from_list_default_eq with (d := 0). + erewrite <-to_list_from_list with (pf := length_carry_sequence (length_to_list _)). + rewrite from_list_default_eq with (pf := length_carry_sequence (length_to_list _)). + rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. + cbv [rep]. + rewrite carry_sequence_rep; try reflexivity; try apply make_chain_lt. + rewrite from_list_default_eq with (pf := length_to_list _). + rewrite from_list_to_list; reflexivity. + Qed. Lemma freeze_rep : forall u x, rep u x -> rep (freeze u) x. - Admitted. + Proof. + cbv [rep]; intros. + apply F.eq_to_Z_iff. + erewrite <-!Fdecode_decode_mod by eauto. + apply freeze_decode. + Qed. Lemma freeze_canonical : forall u v x y, rep u x -> rep v y -> - (x = y <-> fieldwise Logic.eq (freeze u) (freeze v)). + initial_bounds u -> + initial_bounds v -> + (x = y <-> fieldwise Logic.eq (freeze u) (freeze v)). Proof. intros; apply bounded_canonical; auto using freeze_rep, minimal_rep_freeze. Qed. @@ -979,14 +1088,37 @@ Section SquareRootProofs. Local Notation base := (base_from_limb_widths limb_widths). Local Hint Resolve (@limb_widths_nonneg _ prm) sum_firstn_limb_widths_nonneg. - Lemma eqb_correct : forall u v x y, u ~= x -> v ~= y -> - (x = y <-> eqb u v = true). + Definition freeze_input_bounds n := + (2 ^ B - + (if PeanoNat.Nat.eq_dec n 0 + then 0 + else (2 ^ B) >> (nth_default 0 limb_widths (Init.Nat.pred n)))). + Definition bounded_by u bounds := + (forall n : nat, + 0 <= nth_default 0 (to_list (length limb_widths) u) n < bounds n). + + Lemma eqb_true_iff : forall u v x y, + bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> + u ~= x -> v ~= y -> (x = y <-> eqb u v = true). Proof. - cbv [eqb]. intros. + cbv [eqb freeze_input_bounds]. intros. rewrite fieldwiseb_fieldwise by (apply Z.eqb_eq). eauto using freeze_canonical. Qed. + Lemma eqb_false_iff : forall u v x y, + bounded_by u freeze_input_bounds -> bounded_by v freeze_input_bounds -> + u ~= x -> v ~= y -> (x <> y <-> eqb u v = false). + Proof. + intros. + case_eq (eqb u v). + + rewrite <-eqb_true_iff by eassumption; intros; split; congruence. + + split; intros; auto. + intro Hfalse_eq; + rewrite (eqb_true_iff u v) in Hfalse_eq by eassumption. + congruence. + Qed. + Section Sqrt3mod4. Context (modulus_3mod4 : modulus mod 4 = 3). Context {ec : ExponentiationChain (modulus / 4 + 1)}. @@ -1007,9 +1139,21 @@ Section SquareRootProofs. Context (modulus_5mod8 : modulus mod 8 = 5). Context {ec : ExponentiationChain (modulus / 8 + 1)}. Context (sqrt_m1 : digits) (sqrt_m1_correct : mul sqrt_m1 sqrt_m1 ~= F.opp 1%F). + Context (mul_ : digits -> digits -> digits) + (mul_equiv : forall x y, mul_ x y = mul x y) + {mul_input_bounds : nat -> Z} + (mul_bounded : forall x y, bounded_by x mul_input_bounds -> + bounded_by y mul_input_bounds -> + bounded_by (mul_ x y) freeze_input_bounds). + Context (pow_ : digits -> list (nat * nat) -> digits) + (pow_equiv : forall x is, pow_ x is = pow x is) + {pow_input_bounds : nat -> Z} + (pow_bounded : forall x is, bounded_by x pow_input_bounds -> + bounded_by (pow_ x is) mul_input_bounds). Lemma sqrt_5mod8_correct : forall u x, u ~= x -> - (sqrt_5mod8 chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. + bounded_by u pow_input_bounds -> bounded_by u freeze_input_bounds -> + (sqrt_5mod8 mul_ pow_ chain chain_correct sqrt_m1 u) ~= F.sqrt_5mod8 (decode sqrt_m1) x. Proof. repeat match goal with | |- _ => progress (cbv [sqrt_5mod8 F.sqrt_5mod8]; intros) @@ -1017,9 +1161,19 @@ Section SquareRootProofs. | |- _ => rewrite eqb_correct in * by eassumption | |- (if eqb ?a ?b then _ else _) ~= (if dec (?c = _) then _ else _) => - assert (a ~= c); repeat break_if; try apply mul_rep; - try solve [rewrite <-chain_correct; apply pow_rep; eassumption] - | |- _ => congruence + assert (a ~= c); rewrite !mul_equiv, pow_equiv in *; + repeat break_if + | |- _ => apply mul_rep; try reflexivity; + rewrite <-chain_correct; apply pow_rep; eassumption + | |- _ => rewrite <-chain_correct; apply pow_rep; eassumption + | H : eqb ?a ?b = true |- _ => + rewrite <-(eqb_true_iff a b) in Heqb + by (eassumption || rewrite <-mul_equiv, <-pow_equiv; + apply mul_bounded, pow_bounded; auto); congruence + | H : eqb ?a ?b = false |- _ => + rewrite <-(eqb_false_iff a b) in Heqb + by (eassumption || rewrite <-mul_equiv, <-pow_equiv; + apply mul_bounded, pow_bounded; auto); congruence end. Qed. End Sqrt5mod8. -- cgit v1.2.3