aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-06-28 21:22:39 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-06-28 21:22:39 -0400
commitce8862f3df1d7a9a961e8d0823fff2353e3ac7c2 (patch)
tree458d3421e5e340a808387a250687790c778adbda /src
parentb1b4493de3522c71ac3f40081eb95aeba5361dd0 (diff)
BaseSystem encode function is no longer naive; it does a mod/div loop rather than sticking the value of the Z input in the first digit. The condition that c is positive has been added to PseudoMersenneBaseParams--it is necessary for this encode and for canonicalization, for which it was previously a section variable.
Diffstat (limited to 'src')
-rw-r--r--src/BaseSystem.v13
-rw-r--r--src/BaseSystemProofs.v69
-rw-r--r--src/ModularArithmetic/ModularBaseSystem.v2
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v3
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v63
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParams.v1
6 files changed, 125 insertions, 26 deletions
diff --git a/src/BaseSystem.v b/src/BaseSystem.v
index 1985520f0..a37932de0 100644
--- a/src/BaseSystem.v
+++ b/src/BaseSystem.v
@@ -34,8 +34,17 @@ Section BaseSystem.
Definition accumulate p acc := fst p * snd p + acc.
Definition decode' bs u := fold_right accumulate 0 (combine u bs).
Definition decode := decode' base.
- (* Does not carry; z becomes the lowest and only digit. *)
- Definition encode (z : Z) := z :: nil.
+
+ (* i is current index, counts down *)
+ Fixpoint encode' z max i : digits :=
+ match i with
+ | O => nil
+ | S i' => let b := nth_default max base in
+ encode' z max i' ++ ((z mod (b i)) / (b i')) :: nil
+ end.
+
+ (* max must be greater than input; this is used to truncate last digit *)
+ Definition encode z max := encode' z max (length base).
Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us).
Proof.
diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v
index 85835aabe..a0372c60b 100644
--- a/src/BaseSystemProofs.v
+++ b/src/BaseSystemProofs.v
@@ -78,10 +78,24 @@ Section BaseSystemProofs.
induction bs; destruct us; destruct vs; boring; ring.
Qed.
- Lemma encode_rep : forall z, decode base (encode z) = z.
+ Lemma nth_default_base_nonzero : forall d, d <> 0 ->
+ forall i, nth_default d base i <> 0.
Proof.
- pose proof base_eq_1cons.
- unfold decode, encode; destruct z; boring.
+ intros.
+ rewrite nth_default_eq.
+ destruct (nth_in_or_default i base d).
+ + auto using positive_is_nonzero, base_positive.
+ + congruence.
+ Qed.
+
+ Lemma nth_default_base_pos : forall d, 0 < d ->
+ forall i, 0 < nth_default d base i.
+ Proof.
+ intros.
+ rewrite nth_default_eq.
+ destruct (nth_in_or_default i base d).
+ + rewrite <-gt_lt_symmetry; auto using base_positive.
+ + congruence.
Qed.
Lemma mul_each_base : forall us bs c,
@@ -544,5 +558,52 @@ Section BaseSystemProofs.
apply length0_nil; rewrite <-rev_length, rev_nil.
reflexivity.
Qed.
+ Definition encode'_zero z max : encode' base z max 0%nat = nil := eq_refl.
+ Definition encode'_succ z max i : encode' base z max (S i) =
+ encode' base z max i ++ ((z mod (nth_default max base (S i))) / (nth_default max base i)) :: nil := eq_refl.
+ Opaque encode'.
+ Hint Resolve encode'_zero encode'_succ.
+
+ Lemma encode'_length : forall z max i, length (encode' base z max i) = i.
+ Proof.
+ induction i; auto.
+ rewrite encode'_succ, app_length, IHi.
+ cbv [length].
+ omega.
+ Qed.
+
+ (* States that each element of the base is a positive integer multiple of the previous
+ element, and that max is a positive integer multiple of the last element. Ideally this
+ would have a better name. *)
+ Definition base_max_succ_divide max := forall i, (S i <= length base)%nat ->
+ Z.divide (nth_default max base i) (nth_default max base (S i)).
+
+ Lemma encode'_spec : forall z max, 0 < max ->
+ base_max_succ_divide max -> forall i, (i <= length base)%nat ->
+ decode' base (encode' base z max i) = z mod (nth_default max base i).
+ Proof.
+ induction i; intros.
+ + rewrite encode'_zero, b0_1, Z.mod_1_r.
+ apply decode_nil.
+ + rewrite encode'_succ, set_higher.
+ rewrite IHi by omega.
+ rewrite encode'_length, (Z.add_comm (z mod nth_default max base i)).
+ replace (nth_default 0 base i) with (nth_default max base i) by
+ (rewrite !nth_default_eq; apply nth_indep; omega).
+ match goal with H1 : base_max_succ_divide _, H2 : (S i <= length base)%nat, H3 : 0 < max |- _ =>
+ specialize (H1 i H2);
+ rewrite (Znumtheory.Zmod_div_mod _ _ _ (nth_default_base_pos _ H _)
+ (nth_default_base_pos _ H _) H0) end.
+ rewrite <-Z.div_mod by (apply positive_is_nonzero, gt_lt_symmetry; auto using nth_default_base_pos).
+ reflexivity.
+ Qed.
+
+ Lemma encode_rep : forall z max, 0 <= z < max ->
+ base_max_succ_divide max -> decode base (encode base z max) = z.
+ Proof.
+ unfold encode; intros.
+ rewrite encode'_spec, nth_default_out_of_bounds by (omega || auto).
+ apply Z.mod_small; omega.
+ Qed.
-End BaseSystemProofs.
+End BaseSystemProofs. \ No newline at end of file
diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v
index 08545bdb4..f637cfbba 100644
--- a/src/ModularArithmetic/ModularBaseSystem.v
+++ b/src/ModularArithmetic/ModularBaseSystem.v
@@ -19,7 +19,7 @@ Section PseudoMersenneBase.
Local Notation "u ~= x" := (rep u x).
Local Hint Unfold rep.
- Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat.
+ Definition encode (x : F modulus) := encode base x (2 ^ k).
(* Converts from length of extended base to length of base by reduction modulo M.*)
Definition reduce (us : digits) : digits :=
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index 116fe10e5..4f918e147 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -71,9 +71,10 @@ Ltac construct_params prime_modulus len k :=
cbv in lw;
eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
[ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
- | abstract (unfold limb_widths; cbv; congruence)
+ | abstract (cbv; congruence)
| abstract brute_force_indices lw
| abstract apply prime_modulus
+ | abstract (cbv; congruence)
| abstract brute_force_indices lw].
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index 6f82a8950..6b6ebc136 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -28,19 +28,47 @@ Section PseudoMersenneProofs.
autounfold; intuition.
Qed.
+ Lemma lt_modulus_2k : modulus < 2 ^ k.
+ Proof.
+ replace (2 ^ k) with (modulus + c) by (unfold c; ring).
+ pose proof c_pos; omega.
+ Qed. Hint Resolve lt_modulus_2k.
+
+ Lemma modulus_pos : 0 < modulus.
+ Proof.
+ pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega.
+ Qed. Hint Resolve modulus_pos.
+
Lemma encode_rep : forall x : F modulus, encode x ~= x.
Proof.
intros. unfold encode, rep.
split. {
- unfold encode; simpl.
- rewrite length_zeros.
- pose proof base_length_nonzero; omega.
+ unfold BaseSystem.encode.
+ auto using encode'_length.
} {
unfold decode.
- rewrite decode_highzeros.
rewrite encode_rep.
- apply ZToField_FieldToZ.
- apply bv.
+ + apply ZToField_FieldToZ.
+ + apply bv.
+ + split; [ | etransitivity]; try (apply FieldToZ_range; auto using modulus_pos); auto.
+ + unfold base_max_succ_divide; intros.
+ match goal with H : (_ <= length base)%nat |- _ =>
+ apply le_lt_or_eq in H; destruct H end.
+ - apply Z.mod_divide.
+ * apply nth_default_base_nonzero; auto using bv, two_k_nonzero.
+ * rewrite !nth_default_eq.
+ do 2 (erewrite nth_indep with (d := 2 ^ k) (d' := 0) by omega).
+ rewrite <-!nth_default_eq.
+ apply base_succ; omega.
+ - rewrite nth_default_out_of_bounds with (n := S i) by omega.
+ rewrite nth_default_base by omega.
+ unfold k.
+ match goal with H : S _ = length base |- _ =>
+ rewrite base_length in H; rewrite <-H end.
+ erewrite sum_firstn_succ by (apply nth_error_Some_nth_default with (x0 := 0); omega).
+ rewrite Z.pow_add_r by (auto using sum_firstn_limb_widths_nonneg;
+ apply limb_widths_nonneg; rewrite nth_default_eq; apply nth_In; omega).
+ apply Z.divide_factor_r.
}
Qed.
@@ -421,7 +449,6 @@ End CarryProofs.
Section CanonicalizationProofs.
Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat)
{B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B)
- (c_pos : 0 < c)
(* on the first reduce step, we add at most one bit of width to the first digit *)
(c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1)
(* on the second reduce step, we add at most one bit of width to the first digit,
@@ -780,7 +807,7 @@ Section CanonicalizationProofs.
do 2 match goal with H : appcontext[S (pred (length base))] |- _ =>
erewrite <-(S_pred (length base)) in H by eauto end.
unfold carry; break_if; [ unfold carry_and_reduce | omega ].
- clear_obvious.
+ clear_obvious. pose proof c_pos.
add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega.
Qed.
@@ -869,11 +896,11 @@ Section CanonicalizationProofs.
simpl.
unfold carry, carry_and_reduce; break_if; try omega.
clear_obvious; add_set_nth.
- split; [zero_bounds; carry_seq_lower_bound | ].
+ split; [pose proof c_pos; zero_bounds; carry_seq_lower_bound | ].
rewrite Z.add_comm.
apply Z.add_le_mono.
+ apply carry_bounds_0_upper; auto; omega.
- + apply Z.mul_le_mono_pos_l; auto.
+ + apply Z.mul_le_mono_pos_l; auto using c_pos.
apply Z_shiftr_ones; auto;
[ | pose proof (B_compat_log_cap (pred (length base))); omega ].
split.
@@ -938,13 +965,13 @@ Section CanonicalizationProofs.
unfold carry, carry_and_reduce; break_if; try omega.
clear_obvious; add_set_nth.
split.
- + zero_bounds; [ | carry_seq_lower_bound].
+ + pose proof c_pos; zero_bounds; [ | carry_seq_lower_bound].
apply carry_sequence_carry_full_bounds_same; auto; omega.
+ rewrite Z.add_comm.
apply Z.add_le_mono.
- apply carry_bounds_0_upper; carry_length_conditions.
- etransitivity; [ | replace c with (c * 1) by ring; reflexivity ].
- apply Z.mul_le_mono_pos_l; try omega.
+ apply Z.mul_le_mono_pos_l; try (pose proof c_pos; omega).
rewrite Z.shiftr_div_pow2 by auto.
apply Z.div_le_upper_bound; auto.
ring_simplify.
@@ -1000,7 +1027,7 @@ Section CanonicalizationProofs.
eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ].
replace (Z.succ 1) with (2 ^ 1) by ring.
rewrite <-max_bound_log_cap.
- ring_simplify. omega.
+ ring_simplify. pose proof c_pos; omega.
- apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound.
Qed.
@@ -1096,7 +1123,7 @@ Section CanonicalizationProofs.
pose proof carry_full_2_bounds_0.
apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0));
[ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto;
- ring_simplify | ]; omega.
+ ring_simplify | ]; pose proof c_pos; omega.
+ rewrite carry_unaffected_low by carry_length_conditions.
assert (0 < S i < length base)%nat by omega.
intuition; right.
@@ -1122,7 +1149,7 @@ Section CanonicalizationProofs.
replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto).
rewrite <- base_length.
unfold carry, carry_and_reduce; break_if; try omega; intros.
- add_set_nth.
+ add_set_nth. pose proof c_pos.
split.
+ zero_bounds.
- eapply carry_full_2_bounds_same; eauto; omega.
@@ -1495,7 +1522,7 @@ Section CanonicalizationProofs.
intros.
rewrite nth_default_modulus_digits.
break_if; [ | split; auto; omega].
- break_if; subst; split; auto; try rewrite <- max_bound_log_cap; omega.
+ break_if; subst; split; auto; try rewrite <- max_bound_log_cap; pose proof c_pos; omega.
Qed.
Local Hint Resolve carry_done_modulus_digits.
@@ -1595,7 +1622,7 @@ Section CanonicalizationProofs.
f_equal.
apply land_max_ones_noop with (i := 0%nat).
rewrite <-max_bound_log_cap.
- omega.
+ pose proof c_pos; omega.
+ unfold modulus_digits'; fold modulus_digits'.
rewrite map_app.
f_equal; [ apply IHi; omega | ].
@@ -2007,7 +2034,7 @@ Section CanonicalizationProofs.
pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done.
rewrite carry_done_bounds in cf3_done by simpl_lengths.
specialize (cf3_done 0%nat).
- omega.
+ pose proof c_pos; omega.
- assert ((0 < i <= length base - 1)%nat) as i_range by
(simpl_lengths; apply lt_min_l in l; omega).
specialize (high_digits i i_range).
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v
index e20a7ed09..02d409b68 100644
--- a/src/ModularArithmetic/PseudoMersenneBaseParams.v
+++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v
@@ -15,6 +15,7 @@ Class PseudoMersenneBaseParams (modulus : Z) := {
prime_modulus : Znumtheory.prime modulus;
k := sum_firstn limb_widths (length limb_widths);
c := 2 ^ k - modulus;
+ c_pos : 0 < c;
limb_widths_match_modulus : forall i j,
(i < length limb_widths)%nat ->
(j < length limb_widths)%nat ->