diff options
author | Jason Gross <jasongross9@gmail.com> | 2016-08-08 15:13:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-08 15:13:01 -0700 |
commit | ef4656a95a449fdd857bcf2a62cea87f7457844b (patch) | |
tree | 683df6a5b74020e5bea6affa3288d13bf2fa0ae0 /src/ModularArithmetic/ModularBaseSystemOpt.v | |
parent | 4eb7a0cdf53f26cd3597e66d99389645baec45c6 (diff) |
Massively speed up construct_params (#48)
By using reflection, we can speed up the overall build time by about
half a minute. By fully reducing [base_from_limb_widths] once we plug
in arguments, and not before, we can get about another half-minute in
8.5pl2 (and a great deal more in 8.6, where vm_compute no longer is
slow; see https://coq.inria.fr/bugs/show_bug.cgi?id=5004).
Times in 8.5pl2:
After | File Name | Before || Change
---------------------------------------------------------------------------
0m27.80s | Total | 1m19.59s || -0m51.78s
---------------------------------------------------------------------------
0m04.71s | Experiments/SpecificCurve25519 | 0m26.78s || -0m22.07s
0m17.13s | Specific/GF25519 | 0m39.10s || -0m21.97s
0m02.27s | Specific/GF1305 | 0m09.02s || -0m06.75s
0m02.75s | ModularArithmetic/ModularBaseSystemOpt | 0m03.77s || -0m01.02s
0m00.95s | ModularArithmetic/ModularBaseSystemField | 0m00.93s || +0m00.01s
Diffstat (limited to 'src/ModularArithmetic/ModularBaseSystemOpt.v')
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 86 |
1 files changed, 64 insertions, 22 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 884c0ea72..21295687b 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -59,39 +59,81 @@ Ltac opt_step := destruct e end. -Ltac brute_force_indices limb_widths := - intros; unfold sum_firstn, limb_widths; cbv [length limb_widths] in *; - repeat match goal with - | _ => progress simpl in * - | [H : (0 + _ < _)%nat |- _ ] => simpl in H - | [H : (S _ + _ < S _)%nat |- _ ] => simpl in H - | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H - | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x - | [H : (?x < _)%nat |- _ ] => is_var x; destruct x - | _ => omega - end. - - -Definition limb_widths_from_len len k := Eval compute in - (fix loop i prev := +Definition limb_widths_from_len_step loop len k := + (fun i prev => match i with | O => nil - | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0) + | S i' => let x := (if (Z.eqb ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0) then (k * Z.of_nat (len - i + 1)) / Z.of_nat len else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in x - prev:: (loop i' x) - end) len 0. + end). +Definition limb_widths_from_len len k := + (fix loop i prev := limb_widths_from_len_step loop len k i prev) len 0. + +Definition brute_force_indices0 lw : bool + := List.fold_right + andb true + (List.map + (fun i + => List.fold_right + andb true + (List.map + (fun j + => sum_firstn lw (i + j) <=? sum_firstn lw i + sum_firstn lw j) + (seq 0 (length lw - i)))) + (seq 0 (length lw))). + +Lemma brute_force_indices_correct0 lw + : brute_force_indices0 lw = true -> forall i j : nat, + (i + j < length lw)%nat -> sum_firstn lw (i + j) <= sum_firstn lw i + sum_firstn lw j. +Proof. + unfold brute_force_indices0. + progress repeat setoid_rewrite fold_right_andb_true_map_iff. + setoid_rewrite in_seq. + setoid_rewrite Z.leb_le. + eauto with omega. +Qed. + +Definition brute_force_indices1 lw : bool + := List.fold_right + andb true + (List.map + (fun i + => List.fold_right + andb true + (List.map + (fun j + => let w_sum := sum_firstn lw in + sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <=? w_sum i + w_sum j) + (seq (length lw - i) (length lw - (length lw - i))))) + (seq 1 (length lw - 1))). + +Lemma brute_force_indices_correct1 lw + : brute_force_indices1 lw = true -> forall i j : nat, + (i < length lw)%nat -> + (j < length lw)%nat -> + (i + j >= length lw)%nat -> + let w_sum := sum_firstn lw in + sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <= w_sum i + w_sum j. +Proof. + unfold brute_force_indices1. + progress repeat setoid_rewrite fold_right_andb_true_map_iff. + setoid_rewrite in_seq. + setoid_rewrite Z.leb_le. + eauto with omega. +Qed. Ltac construct_params prime_modulus len k := - let lw := fresh "lw" in set (lw := limb_widths_from_len len k); - cbv in lw; + let lwv := (eval cbv in (limb_widths_from_len len k)) in + let lw := fresh "lw" in pose lwv as lw; eapply Build_PseudoMersenneBaseParams with (limb_widths := lw); [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto) | abstract (cbv; congruence) - | abstract brute_force_indices lw + | abstract (refine (@brute_force_indices_correct0 lw _); vm_cast_no_check (eq_refl true)) | abstract apply prime_modulus | abstract (cbv; congruence) - | abstract brute_force_indices lw]. + | abstract (refine (@brute_force_indices_correct1 lw _); vm_cast_no_check (eq_refl true))]. Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := match limb_widths with @@ -217,7 +259,7 @@ Section Carries. eexists. intros H. rewrite <-carry_gen_opt_correct by assumption. cbv beta iota delta [carry_gen_opt]. - match goal with |- appcontext[?a & Z_ones_opt _] => + match goal with |- appcontext[?a & Z_ones_opt _] => let LHS := match goal with |- ?LHS = ?RHS => LHS end in let RHS := match goal with |- ?LHS = ?RHS => RHS end in let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in |