aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
diff options
context:
space:
mode:
authorGravatar Jason Gross <jasongross9@gmail.com>2016-08-08 15:13:01 -0700
committerGravatar GitHub <noreply@github.com>2016-08-08 15:13:01 -0700
commitef4656a95a449fdd857bcf2a62cea87f7457844b (patch)
tree683df6a5b74020e5bea6affa3288d13bf2fa0ae0 /src/ModularArithmetic/ModularBaseSystemOpt.v
parent4eb7a0cdf53f26cd3597e66d99389645baec45c6 (diff)
Massively speed up construct_params (#48)
By using reflection, we can speed up the overall build time by about half a minute. By fully reducing [base_from_limb_widths] once we plug in arguments, and not before, we can get about another half-minute in 8.5pl2 (and a great deal more in 8.6, where vm_compute no longer is slow; see https://coq.inria.fr/bugs/show_bug.cgi?id=5004). Times in 8.5pl2: After | File Name | Before || Change --------------------------------------------------------------------------- 0m27.80s | Total | 1m19.59s || -0m51.78s --------------------------------------------------------------------------- 0m04.71s | Experiments/SpecificCurve25519 | 0m26.78s || -0m22.07s 0m17.13s | Specific/GF25519 | 0m39.10s || -0m21.97s 0m02.27s | Specific/GF1305 | 0m09.02s || -0m06.75s 0m02.75s | ModularArithmetic/ModularBaseSystemOpt | 0m03.77s || -0m01.02s 0m00.95s | ModularArithmetic/ModularBaseSystemField | 0m00.93s || +0m00.01s
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v86
1 files changed, 64 insertions, 22 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 884c0ea72..21295687b 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -59,39 +59,81 @@ Ltac opt_step :=
destruct e
end.
-Ltac brute_force_indices limb_widths :=
- intros; unfold sum_firstn, limb_widths; cbv [length limb_widths] in *;
- repeat match goal with
- | _ => progress simpl in *
- | [H : (0 + _ < _)%nat |- _ ] => simpl in H
- | [H : (S _ + _ < S _)%nat |- _ ] => simpl in H
- | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H
- | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x
- | [H : (?x < _)%nat |- _ ] => is_var x; destruct x
- | _ => omega
- end.
-
-
-Definition limb_widths_from_len len k := Eval compute in
- (fix loop i prev :=
+Definition limb_widths_from_len_step loop len k :=
+ (fun i prev =>
match i with
| O => nil
- | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
+ | S i' => let x := (if (Z.eqb ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
then (k * Z.of_nat (len - i + 1)) / Z.of_nat len
else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in
x - prev:: (loop i' x)
- end) len 0.
+ end).
+Definition limb_widths_from_len len k :=
+ (fix loop i prev := limb_widths_from_len_step loop len k i prev) len 0.
+
+Definition brute_force_indices0 lw : bool
+ := List.fold_right
+ andb true
+ (List.map
+ (fun i
+ => List.fold_right
+ andb true
+ (List.map
+ (fun j
+ => sum_firstn lw (i + j) <=? sum_firstn lw i + sum_firstn lw j)
+ (seq 0 (length lw - i))))
+ (seq 0 (length lw))).
+
+Lemma brute_force_indices_correct0 lw
+ : brute_force_indices0 lw = true -> forall i j : nat,
+ (i + j < length lw)%nat -> sum_firstn lw (i + j) <= sum_firstn lw i + sum_firstn lw j.
+Proof.
+ unfold brute_force_indices0.
+ progress repeat setoid_rewrite fold_right_andb_true_map_iff.
+ setoid_rewrite in_seq.
+ setoid_rewrite Z.leb_le.
+ eauto with omega.
+Qed.
+
+Definition brute_force_indices1 lw : bool
+ := List.fold_right
+ andb true
+ (List.map
+ (fun i
+ => List.fold_right
+ andb true
+ (List.map
+ (fun j
+ => let w_sum := sum_firstn lw in
+ sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <=? w_sum i + w_sum j)
+ (seq (length lw - i) (length lw - (length lw - i)))))
+ (seq 1 (length lw - 1))).
+
+Lemma brute_force_indices_correct1 lw
+ : brute_force_indices1 lw = true -> forall i j : nat,
+ (i < length lw)%nat ->
+ (j < length lw)%nat ->
+ (i + j >= length lw)%nat ->
+ let w_sum := sum_firstn lw in
+ sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <= w_sum i + w_sum j.
+Proof.
+ unfold brute_force_indices1.
+ progress repeat setoid_rewrite fold_right_andb_true_map_iff.
+ setoid_rewrite in_seq.
+ setoid_rewrite Z.leb_le.
+ eauto with omega.
+Qed.
Ltac construct_params prime_modulus len k :=
- let lw := fresh "lw" in set (lw := limb_widths_from_len len k);
- cbv in lw;
+ let lwv := (eval cbv in (limb_widths_from_len len k)) in
+ let lw := fresh "lw" in pose lwv as lw;
eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
[ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
| abstract (cbv; congruence)
- | abstract brute_force_indices lw
+ | abstract (refine (@brute_force_indices_correct0 lw _); vm_cast_no_check (eq_refl true))
| abstract apply prime_modulus
| abstract (cbv; congruence)
- | abstract brute_force_indices lw].
+ | abstract (refine (@brute_force_indices_correct1 lw _); vm_cast_no_check (eq_refl true))].
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
match limb_widths with
@@ -217,7 +259,7 @@ Section Carries.
eexists. intros H.
rewrite <-carry_gen_opt_correct by assumption.
cbv beta iota delta [carry_gen_opt].
- match goal with |- appcontext[?a & Z_ones_opt _] =>
+ match goal with |- appcontext[?a & Z_ones_opt _] =>
let LHS := match goal with |- ?LHS = ?RHS => LHS end in
let RHS := match goal with |- ?LHS = ?RHS => RHS end in
let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in