aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v274
1 files changed, 241 insertions, 33 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 981680b4a..116fe10e5 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -27,7 +27,12 @@ Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by.
Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
Definition map_opt {A B} := Eval compute in @map A B.
-Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths.
+Definition full_carry_chain_opt := Eval compute in @full_carry_chain.
+Definition length_opt := Eval compute in length.
+Definition base_opt := Eval compute in @base.
+Definition max_ones_opt := Eval compute in @max_ones.
+Definition max_bound_opt := Eval compute in @max_bound.
+Definition minus_opt := Eval compute in minus.
Definition Let_In {A P} (x : A) (f : forall y : A, P y)
:= let y := x in f y.
@@ -71,18 +76,22 @@ Ltac construct_params prime_modulus len k :=
| abstract apply prime_modulus
| abstract brute_force_indices lw].
-Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
+Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
match limb_widths with
| nil => nil
| x :: tail =>
2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail
end.
+Ltac compute_preconditions :=
+ cbv; intros; repeat match goal with H : _ \/ _ |- _ =>
+ destruct H; subst; [ congruence | ] end; (congruence || omega).
+
Ltac subst_precondition := match goal with
| [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
end.
-Ltac kill_precondition H :=
+Ltac kill_precondition H :=
forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
subst_precondition.
@@ -95,8 +104,7 @@ Ltac compute_formula :=
let p := fresh "p" in set (p := P) in H at 1; change P with p at 1;
let r := fresh "r" in set (r := result) in H |- *;
cbv -[m p r PseudoMersenneBaseRep.rep] in H;
- repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H;
- exact H
+ repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H
end.
Section Carries.
@@ -113,8 +121,9 @@ Section Carries.
rewrite <- pull_app_if_sumbool.
cbv beta delta
[carry carry_and_reduce carry_simple add_to_nth log_cap
- pow2_mod Z.ones Z.pred base
+ pow2_mod Z.ones Z.pred
PseudoMersenneBaseParams.limb_widths].
+ change @base with @base_opt.
change @nth_default with @nth_default_opt in *.
change @set_nth with @set_nth_opt in *.
lazymatch goal with
@@ -129,7 +138,6 @@ Section Carries.
change @set_nth with @set_nth_opt.
change @map with @map_opt.
rewrite <- @beq_nat_eq_nat_dec.
- change base_from_limb_widths with base_from_limb_widths_opt.
reflexivity.
Defined.
@@ -179,7 +187,7 @@ Section Carries.
change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf).
change Z.shiftl with Z_shiftl_opt.
change (-1) with (Z_opp_opt 1).
- change Z.add with Z_add_opt at 8 12 20 24.
+ change Z.add with Z_add_opt at 5 9 17 21.
reflexivity.
Defined.
@@ -191,6 +199,39 @@ Section Carries.
@carry_opt_cps T i f b = f (carry i b)
:= proj2_sig (carry_opt_cps_sig i f b).
+ Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits)
+ (f : digits -> T)
+ : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }.
+ Proof.
+ eexists.
+ cbv [carry_sequence].
+ transitivity (fold_right carry_opt_cps f (List.rev is) us).
+ Focus 2.
+ {
+ assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
+ subst. intros. rewrite <- in_rev in *. auto. }
+ remember (rev is) as ris eqn:Heq.
+ rewrite <- (rev_involutive is), <- Heq.
+ clear H Heq is.
+ rewrite fold_left_rev_right.
+ revert us; induction ris; [ reflexivity | ]; intros.
+ { simpl.
+ rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
+ rewrite carry_opt_cps_correct; [reflexivity|].
+ apply Hr; left; reflexivity.
+ } }
+ Unfocus.
+ reflexivity.
+ Defined.
+
+ Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) :=
+ Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in
+ proj1_sig (carry_sequence_opt_cps2_sig is us f).
+
+ Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T)
+ : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us)
+ := proj2_sig (carry_sequence_opt_cps2_sig is us f).
+
Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
: { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
Proof.
@@ -198,7 +239,7 @@ Section Carries.
cbv [carry_sequence].
transitivity (fold_right carry_opt_cps id (List.rev is) us).
Focus 2.
- {
+ {
assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
subst. intros. rewrite <- in_rev in *. auto. }
remember (rev is) as ris eqn:Heq.
@@ -226,14 +267,55 @@ Section Carries.
Lemma carry_sequence_opt_cps_rep
: forall (is : list nat) (us : list Z) (x : F modulus),
(forall i : nat, In i is -> i < length base)%nat ->
- length us = length base ->
rep us x -> rep (carry_sequence_opt_cps is us) x.
Proof.
intros.
rewrite carry_sequence_opt_cps_correct by assumption.
- apply carry_sequence_rep; assumption.
+ apply carry_sequence_rep; eauto using rep_length.
Qed.
+ Lemma full_carry_chain_bounds : forall i, In i full_carry_chain -> (i < length base)%nat.
+ Proof.
+ unfold full_carry_chain; rewrite <-base_length; intros.
+ apply make_chain_lt; auto.
+ Qed.
+
+ Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
+ Proof.
+ eexists.
+ cbv [carry_full].
+ change @full_carry_chain with full_carry_chain_opt.
+ rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds).
+ reflexivity.
+ Defined.
+
+ Definition carry_full_opt (us : digits) : digits
+ := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us).
+
+ Definition carry_full_opt_correct us : carry_full_opt us = carry_full us :=
+ proj2_sig (carry_full_opt_sig us).
+
+ Definition carry_full_opt_cps_sig
+ {T}
+ (f : digits -> T)
+ (us : digits)
+ : { d : T | d = f (carry_full us) }.
+ Proof.
+ eexists.
+ rewrite <- carry_full_opt_correct.
+ cbv beta iota delta [carry_full_opt].
+ rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds.
+ rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds.
+ reflexivity.
+ Defined.
+
+ Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T
+ := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us).
+
+ Definition carry_full_opt_cps_correct {T} us (f : digits -> T) :
+ carry_full_opt_cps f us = f (carry_full us) :=
+ proj2_sig (carry_full_opt_cps_sig f us).
+
End Carries.
Section Addition.
@@ -416,12 +498,11 @@ Section Multiplication.
eexists.
cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce].
rewrite <- mul'_opt_correct.
- cbv [base PseudoMersenneBaseParams.limb_widths].
+ change @base with base_opt.
rewrite map_shiftl by apply k_nonneg.
rewrite c_subst.
rewrite k_subst.
change @map with @map_opt.
- change base_from_limb_widths with base_from_limb_widths_opt.
change @Z_shiftl_by with @Z_shiftl_by_opt.
reflexivity.
Defined.
@@ -433,31 +514,158 @@ Section Multiplication.
: mul_opt us vs = mul us vs
:= proj2_sig (mul_opt_sig us vs).
- Lemma mul_opt_rep:
+ Definition carry_mul_opt_sig (us vs : T) : { b : digits | b = carry_mul us vs }.
+ Proof.
+ eexists.
+ cbv [carry_mul].
+ erewrite <-carry_full_opt_correct by eauto.
+ erewrite <-mul_opt_correct.
+ reflexivity.
+ Defined.
+
+ Definition carry_mul_opt (us vs : T) : digits
+ := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig us vs).
+
+ Definition carry_mul_opt_correct us vs
+ : carry_mul_opt us vs = carry_mul us vs
+ := proj2_sig (carry_mul_opt_sig us vs).
+
+ Lemma carry_mul_opt_rep:
forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
- PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F.
+ PseudoMersenneBaseRep.rep (carry_mul_opt u v) (x * y)%F.
Proof.
intros.
- rewrite mul_opt_correct.
- change mul with PseudoMersenneBaseRep.mul.
+ rewrite carry_mul_opt_correct.
+ change carry_mul with PseudoMersenneBaseRep.mul.
auto using PseudoMersenneBaseRep.mul_rep.
Qed.
- Definition carry_mul_opt
- (is : list nat)
- (us vs : list Z)
- : list Z
- := carry_sequence_opt_cps c_ is (mul_opt us vs).
-
- Lemma carry_mul_opt_correct
- : forall (is : list nat) (us vs : list Z) (x y: F modulus),
- PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y ->
- (forall i : nat, In i is -> i < length base)%nat ->
- length (mul_opt us vs) = length base ->
- PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F.
+End Multiplication.
+
+Record freezePreconditions {modulus} (prm : PseudoMersenneBaseParams modulus) int_width :=
+mkFreezePreconditions {
+ lt_1_length_base : (1 < length base)%nat;
+ int_width_pos : 0 < int_width;
+ int_width_compat : forall w, In w limb_widths -> w <= int_width;
+ c_pos : 0 < c;
+ c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < max_bound 0 + 1;
+ c_reduce2 : c <= max_bound 0 - c;
+ two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus
+}.
+Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos
+ c_reduce1 c_reduce2 two_pow_k_le_2modulus.
+
+Section Canonicalization.
+ Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
+ (* allows caller to precompute k and c *)
+ (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
+ {int_width} (preconditions : freezePreconditions prm int_width).
+
+ Definition modulus_digits_opt_sig :
+ { b : digits | b = modulus_digits }.
+ Proof.
+ eexists.
+ cbv beta iota delta [modulus_digits modulus_digits' app].
+ change @max_bound with max_bound_opt.
+ rewrite c_subst.
+ change length with length_opt.
+ change minus with minus_opt.
+ change Z.add with Z_add_opt.
+ change Z.sub with Z_sub_opt.
+ change @base with base_opt.
+ reflexivity.
+ Defined.
+
+ Definition modulus_digits_opt : digits
+ := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig).
+
+ Definition modulus_digits_opt_correct
+ : modulus_digits_opt = modulus_digits
+ := proj2_sig (modulus_digits_opt_sig).
+
+
+ Definition carry_full_3_opt_cps_sig
+ {T} (f : digits -> T)
+ (us : digits)
+ : { d : T | d = f (carry_full (carry_full (carry_full us))) }.
+ Proof.
+ eexists.
+ transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us).
+ Focus 2. {
+ rewrite !carry_full_opt_cps_correct by assumption; reflexivity.
+ }
+ Unfocus.
+ reflexivity.
+ Defined.
+
+ Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T
+ := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us).
+
+ Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us :
+ carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) :=
+ proj2_sig (carry_full_3_opt_cps_sig f us).
+
+ Definition freeze_opt_sig (us : T) :
+ { b : digits | b = freeze us }.
Proof.
- intros is us vs x y; intros.
- change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)).
- apply carry_sequence_opt_cps_rep, mul_opt_rep; auto.
+ eexists.
+ cbv [freeze].
+ cbv [and_term].
+ let LHS := match goal with |- ?LHS = ?RHS => LHS end in
+ let RHS := match goal with |- ?LHS = ?RHS => RHS end in
+ let RHSf := match (eval pattern (isFull (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in
+ change (LHS = Let_In (isFull(carry_full (carry_full (carry_full us)))) RHSf).
+ let LHS := match goal with |- ?LHS = ?RHS => LHS end in
+ let RHS := match goal with |- ?LHS = ?RHS => RHS end in
+ let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in
+ rewrite <-carry_full_3_opt_cps_correct with (f := RHSf).
+ cbv beta iota delta [and_term isFull isFull'].
+ change length with length_opt.
+ change @max_bound with max_bound_opt.
+ rewrite c_subst.
+ change @max_ones with max_ones_opt.
+ change @base with base_opt.
+ change minus with minus_opt.
+ change @map with @map_opt.
+ change Z.sub with Z_sub_opt at 1.
+ rewrite <-modulus_digits_opt_correct.
+ reflexivity.
+ Defined.
+
+ Definition freeze_opt (us : T) : digits
+ := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us).
+
+ Definition freeze_opt_correct us
+ : freeze_opt us = freeze us
+ := proj2_sig (freeze_opt_sig us).
+
+ Lemma freeze_opt_canonical: forall us vs x,
+ @pre_carry_bounds _ _ int_width us -> PseudoMersenneBaseRep.rep us x ->
+ @pre_carry_bounds _ _ int_width vs -> PseudoMersenneBaseRep.rep vs x ->
+ freeze_opt us = freeze_opt vs.
+ Proof.
+ intros.
+ rewrite !freeze_opt_correct.
+ change PseudoMersenneBaseRep.rep with rep in *.
+ eapply freeze_canonical with (B := int_width); eauto.
Qed.
-End Multiplication. \ No newline at end of file
+
+ Lemma freeze_opt_preserves_rep : forall us x, PseudoMersenneBaseRep.rep us x ->
+ PseudoMersenneBaseRep.rep (freeze_opt us) x.
+ Proof.
+ intros.
+ rewrite freeze_opt_correct.
+ change PseudoMersenneBaseRep.rep with rep in *.
+ eapply freeze_preserves_rep; eauto.
+ Qed.
+
+ Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x ->
+ @pre_carry_bounds _ _ int_width us ->
+ @pre_carry_bounds _ _ int_width vs ->
+ (PseudoMersenneBaseRep.rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs).
+ Proof.
+ split; eauto using freeze_opt_canonical.
+ auto using freeze_opt_preserves_rep.
+ Qed.
+
+End Canonicalization. \ No newline at end of file