diff options
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 274 |
1 files changed, 241 insertions, 33 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 981680b4a..116fe10e5 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -27,7 +27,12 @@ Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by. Definition nth_default_opt {A} := Eval compute in @nth_default A. Definition set_nth_opt {A} := Eval compute in @set_nth A. Definition map_opt {A B} := Eval compute in @map A B. -Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths. +Definition full_carry_chain_opt := Eval compute in @full_carry_chain. +Definition length_opt := Eval compute in length. +Definition base_opt := Eval compute in @base. +Definition max_ones_opt := Eval compute in @max_ones. +Definition max_bound_opt := Eval compute in @max_bound. +Definition minus_opt := Eval compute in minus. Definition Let_In {A P} (x : A) (f : forall y : A, P y) := let y := x in f y. @@ -71,18 +76,22 @@ Ltac construct_params prime_modulus len k := | abstract apply prime_modulus | abstract brute_force_indices lw]. -Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := +Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := match limb_widths with | nil => nil | x :: tail => 2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail end. +Ltac compute_preconditions := + cbv; intros; repeat match goal with H : _ \/ _ |- _ => + destruct H; subst; [ congruence | ] end; (congruence || omega). + Ltac subst_precondition := match goal with | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H end. -Ltac kill_precondition H := +Ltac kill_precondition H := forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|]; subst_precondition. @@ -95,8 +104,7 @@ Ltac compute_formula := let p := fresh "p" in set (p := P) in H at 1; change P with p at 1; let r := fresh "r" in set (r := result) in H |- *; cbv -[m p r PseudoMersenneBaseRep.rep] in H; - repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H; - exact H + repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H end. Section Carries. @@ -113,8 +121,9 @@ Section Carries. rewrite <- pull_app_if_sumbool. cbv beta delta [carry carry_and_reduce carry_simple add_to_nth log_cap - pow2_mod Z.ones Z.pred base + pow2_mod Z.ones Z.pred PseudoMersenneBaseParams.limb_widths]. + change @base with @base_opt. change @nth_default with @nth_default_opt in *. change @set_nth with @set_nth_opt in *. lazymatch goal with @@ -129,7 +138,6 @@ Section Carries. change @set_nth with @set_nth_opt. change @map with @map_opt. rewrite <- @beq_nat_eq_nat_dec. - change base_from_limb_widths with base_from_limb_widths_opt. reflexivity. Defined. @@ -179,7 +187,7 @@ Section Carries. change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf). change Z.shiftl with Z_shiftl_opt. change (-1) with (Z_opp_opt 1). - change Z.add with Z_add_opt at 8 12 20 24. + change Z.add with Z_add_opt at 5 9 17 21. reflexivity. Defined. @@ -191,6 +199,39 @@ Section Carries. @carry_opt_cps T i f b = f (carry i b) := proj2_sig (carry_opt_cps_sig i f b). + Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits) + (f : digits -> T) + : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }. + Proof. + eexists. + cbv [carry_sequence]. + transitivity (fold_right carry_opt_cps f (List.rev is) us). + Focus 2. + { + assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { + subst. intros. rewrite <- in_rev in *. auto. } + remember (rev is) as ris eqn:Heq. + rewrite <- (rev_involutive is), <- Heq. + clear H Heq is. + rewrite fold_left_rev_right. + revert us; induction ris; [ reflexivity | ]; intros. + { simpl. + rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption]. + rewrite carry_opt_cps_correct; [reflexivity|]. + apply Hr; left; reflexivity. + } } + Unfocus. + reflexivity. + Defined. + + Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) := + Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in + proj1_sig (carry_sequence_opt_cps2_sig is us f). + + Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T) + : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us) + := proj2_sig (carry_sequence_opt_cps2_sig is us f). + Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits) : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }. Proof. @@ -198,7 +239,7 @@ Section Carries. cbv [carry_sequence]. transitivity (fold_right carry_opt_cps id (List.rev is) us). Focus 2. - { + { assert (forall i, In i (rev is) -> i < length base)%nat as Hr. { subst. intros. rewrite <- in_rev in *. auto. } remember (rev is) as ris eqn:Heq. @@ -226,14 +267,55 @@ Section Carries. Lemma carry_sequence_opt_cps_rep : forall (is : list nat) (us : list Z) (x : F modulus), (forall i : nat, In i is -> i < length base)%nat -> - length us = length base -> rep us x -> rep (carry_sequence_opt_cps is us) x. Proof. intros. rewrite carry_sequence_opt_cps_correct by assumption. - apply carry_sequence_rep; assumption. + apply carry_sequence_rep; eauto using rep_length. Qed. + Lemma full_carry_chain_bounds : forall i, In i full_carry_chain -> (i < length base)%nat. + Proof. + unfold full_carry_chain; rewrite <-base_length; intros. + apply make_chain_lt; auto. + Qed. + + Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }. + Proof. + eexists. + cbv [carry_full]. + change @full_carry_chain with full_carry_chain_opt. + rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds). + reflexivity. + Defined. + + Definition carry_full_opt (us : digits) : digits + := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us). + + Definition carry_full_opt_correct us : carry_full_opt us = carry_full us := + proj2_sig (carry_full_opt_sig us). + + Definition carry_full_opt_cps_sig + {T} + (f : digits -> T) + (us : digits) + : { d : T | d = f (carry_full us) }. + Proof. + eexists. + rewrite <- carry_full_opt_correct. + cbv beta iota delta [carry_full_opt]. + rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds. + rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds. + reflexivity. + Defined. + + Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us). + + Definition carry_full_opt_cps_correct {T} us (f : digits -> T) : + carry_full_opt_cps f us = f (carry_full us) := + proj2_sig (carry_full_opt_cps_sig f us). + End Carries. Section Addition. @@ -416,12 +498,11 @@ Section Multiplication. eexists. cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce]. rewrite <- mul'_opt_correct. - cbv [base PseudoMersenneBaseParams.limb_widths]. + change @base with base_opt. rewrite map_shiftl by apply k_nonneg. rewrite c_subst. rewrite k_subst. change @map with @map_opt. - change base_from_limb_widths with base_from_limb_widths_opt. change @Z_shiftl_by with @Z_shiftl_by_opt. reflexivity. Defined. @@ -433,31 +514,158 @@ Section Multiplication. : mul_opt us vs = mul us vs := proj2_sig (mul_opt_sig us vs). - Lemma mul_opt_rep: + Definition carry_mul_opt_sig (us vs : T) : { b : digits | b = carry_mul us vs }. + Proof. + eexists. + cbv [carry_mul]. + erewrite <-carry_full_opt_correct by eauto. + erewrite <-mul_opt_correct. + reflexivity. + Defined. + + Definition carry_mul_opt (us vs : T) : digits + := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig us vs). + + Definition carry_mul_opt_correct us vs + : carry_mul_opt us vs = carry_mul us vs + := proj2_sig (carry_mul_opt_sig us vs). + + Lemma carry_mul_opt_rep: forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y -> - PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F. + PseudoMersenneBaseRep.rep (carry_mul_opt u v) (x * y)%F. Proof. intros. - rewrite mul_opt_correct. - change mul with PseudoMersenneBaseRep.mul. + rewrite carry_mul_opt_correct. + change carry_mul with PseudoMersenneBaseRep.mul. auto using PseudoMersenneBaseRep.mul_rep. Qed. - Definition carry_mul_opt - (is : list nat) - (us vs : list Z) - : list Z - := carry_sequence_opt_cps c_ is (mul_opt us vs). - - Lemma carry_mul_opt_correct - : forall (is : list nat) (us vs : list Z) (x y: F modulus), - PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y -> - (forall i : nat, In i is -> i < length base)%nat -> - length (mul_opt us vs) = length base -> - PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F. +End Multiplication. + +Record freezePreconditions {modulus} (prm : PseudoMersenneBaseParams modulus) int_width := +mkFreezePreconditions { + lt_1_length_base : (1 < length base)%nat; + int_width_pos : 0 < int_width; + int_width_compat : forall w, In w limb_widths -> w <= int_width; + c_pos : 0 < c; + c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < max_bound 0 + 1; + c_reduce2 : c <= max_bound 0 - c; + two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus +}. +Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos + c_reduce1 c_reduce2 two_pow_k_le_2modulus. + +Section Canonicalization. + Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm} + (* allows caller to precompute k and c *) + (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_) + {int_width} (preconditions : freezePreconditions prm int_width). + + Definition modulus_digits_opt_sig : + { b : digits | b = modulus_digits }. + Proof. + eexists. + cbv beta iota delta [modulus_digits modulus_digits' app]. + change @max_bound with max_bound_opt. + rewrite c_subst. + change length with length_opt. + change minus with minus_opt. + change Z.add with Z_add_opt. + change Z.sub with Z_sub_opt. + change @base with base_opt. + reflexivity. + Defined. + + Definition modulus_digits_opt : digits + := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig). + + Definition modulus_digits_opt_correct + : modulus_digits_opt = modulus_digits + := proj2_sig (modulus_digits_opt_sig). + + + Definition carry_full_3_opt_cps_sig + {T} (f : digits -> T) + (us : digits) + : { d : T | d = f (carry_full (carry_full (carry_full us))) }. + Proof. + eexists. + transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us). + Focus 2. { + rewrite !carry_full_opt_cps_correct by assumption; reflexivity. + } + Unfocus. + reflexivity. + Defined. + + Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T + := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us). + + Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us : + carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) := + proj2_sig (carry_full_3_opt_cps_sig f us). + + Definition freeze_opt_sig (us : T) : + { b : digits | b = freeze us }. Proof. - intros is us vs x y; intros. - change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)). - apply carry_sequence_opt_cps_rep, mul_opt_rep; auto. + eexists. + cbv [freeze]. + cbv [and_term]. + let LHS := match goal with |- ?LHS = ?RHS => LHS end in + let RHS := match goal with |- ?LHS = ?RHS => RHS end in + let RHSf := match (eval pattern (isFull (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in + change (LHS = Let_In (isFull(carry_full (carry_full (carry_full us)))) RHSf). + let LHS := match goal with |- ?LHS = ?RHS => LHS end in + let RHS := match goal with |- ?LHS = ?RHS => RHS end in + let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in + rewrite <-carry_full_3_opt_cps_correct with (f := RHSf). + cbv beta iota delta [and_term isFull isFull']. + change length with length_opt. + change @max_bound with max_bound_opt. + rewrite c_subst. + change @max_ones with max_ones_opt. + change @base with base_opt. + change minus with minus_opt. + change @map with @map_opt. + change Z.sub with Z_sub_opt at 1. + rewrite <-modulus_digits_opt_correct. + reflexivity. + Defined. + + Definition freeze_opt (us : T) : digits + := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us). + + Definition freeze_opt_correct us + : freeze_opt us = freeze us + := proj2_sig (freeze_opt_sig us). + + Lemma freeze_opt_canonical: forall us vs x, + @pre_carry_bounds _ _ int_width us -> PseudoMersenneBaseRep.rep us x -> + @pre_carry_bounds _ _ int_width vs -> PseudoMersenneBaseRep.rep vs x -> + freeze_opt us = freeze_opt vs. + Proof. + intros. + rewrite !freeze_opt_correct. + change PseudoMersenneBaseRep.rep with rep in *. + eapply freeze_canonical with (B := int_width); eauto. Qed. -End Multiplication.
\ No newline at end of file + + Lemma freeze_opt_preserves_rep : forall us x, PseudoMersenneBaseRep.rep us x -> + PseudoMersenneBaseRep.rep (freeze_opt us) x. + Proof. + intros. + rewrite freeze_opt_correct. + change PseudoMersenneBaseRep.rep with rep in *. + eapply freeze_preserves_rep; eauto. + Qed. + + Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x -> + @pre_carry_bounds _ _ int_width us -> + @pre_carry_bounds _ _ int_width vs -> + (PseudoMersenneBaseRep.rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs). + Proof. + split; eauto using freeze_opt_canonical. + auto using freeze_opt_preserves_rep. + Qed. + +End Canonicalization.
\ No newline at end of file |