aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-08-12 19:00:54 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-08-12 19:00:54 -0400
commit079b0f4b019d9bd6773c9f6d07256aa92fe01146 (patch)
treed8c603d6b033548f87baa736a4b5b61e188379f3 /src/ModularArithmetic
parent8c106350250c61b06afeb64d580212abd6c63ab2 (diff)
New and improved conversion proofs (final conditions proven, invariant step unproven)
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v303
1 files changed, 294 insertions, 9 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index c863c2e8f..b05b2f78d 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -35,16 +35,17 @@ Hint Rewrite Bool.andb_true_r Bool.andb_false_r Bool.orb_true_r Bool.orb_false_r
Bool.andb_true_l Bool.andb_false_l Bool.orb_true_l Bool.orb_false_l : Ztestbit.
(* TODO : move *)
-Lemma testbit_pow2_mod : forall a n i, 0 <= i -> 0 <= n ->
+Lemma testbit_pow2_mod : forall a n i, 0 <= n ->
Z.testbit (Z.pow2_mod a n) i = if Z_lt_dec i n then Z.testbit a i else false.
Proof.
-cbv [Z.pow2_mod]; intros.
-repeat match goal with
+ cbv [Z.pow2_mod]; intros; destruct (Z_le_dec 0 i);
+ repeat match goal with
+ | |- _ => rewrite Z.testbit_neg_r by omega
| |- _ => break_if
| |- _ => omega
| |- _ => reflexivity
| |- _ => progress autorewrite with Ztestbit
- end.
+ end.
Qed.
Hint Rewrite testbit_pow2_mod using omega : Ztestbit.
@@ -783,10 +784,16 @@ Section UniformBase.
End UniformBase.
-Section TestbitDecode.
+Section SplitIndex.
+ (* This section defines [split_index], which for a list of bounded digits
+ splits a bit index in the decoded value into a digit index and a bit
+ index within the digit. Examples:
+ limb_widths [4;4] : split_index 6 = (1,2)
+ limb_widths [26,25,26] : split_index 30 = (1,4)
+ limb_widths [26,25,26] : split_index 51 = (2,0)
+ *)
Local Notation "u # i" := (nth_default 0 u i) (at level 30).
- (* splits a bit index into a digit index and an index within the digit*)
Function split_index' i index lw :=
match lw with
| nil => (index, i)
@@ -803,7 +810,32 @@ Section TestbitDecode.
end.
Qed.
- Lemma snd_split_index'_nonneg : forall i index lw, (0 <= i) ->
+ Lemma split_index'_done_case : forall i index lw, 0 <= i ->
+ (forall x, In x lw -> 0 <= x) ->
+ if Z_lt_dec i (sum_firstn lw (length lw))
+ then (fst (split_index' i index lw) - index < length lw)%nat
+ else (fst (split_index' i index lw) - index = length lw)%nat.
+ Proof.
+ intros; functional induction (split_index' i index lw);
+ repeat match goal with
+ | |- _ => break_if
+ | |- _ => rewrite sum_firstn_nil in *
+ | |- _ => rewrite sum_firstn_succ_cons in *
+ | |- _ => progress distr_length
+ | |- _ => progress (simpl fst; simpl snd)
+ | H : appcontext [split_index' ?a ?b ?c] |- _ =>
+ unique pose proof (split_index'_ge_index a b c)
+ | H : appcontext [sum_firstn ?l ?i] |- _ =>
+ let H0 := fresh "H" in
+ assert (forall x, In x l -> 0 <= x) by auto using in_cons;
+ unique pose proof (sum_firstn_limb_widths_nonneg H0 i)
+ | |- _ => progress specialize_by assumption
+ | |- _ => progress specialize_by omega
+ | |- _ => omega
+ end.
+ Qed.
+
+ Lemma snd_split_index'_nonneg : forall index lw i, (0 <= i) ->
(0 <= snd (split_index' i index lw)).
Proof.
intros; functional induction (split_index' i index lw);
@@ -865,7 +897,7 @@ Section TestbitDecode.
Proof.
cbv [digit_index bit_index split_index]; intros.
pose proof (split_index'_correct n 0 limb_widths).
- pose proof (snd_split_index'_nonneg n 0 limb_widths).
+ pose proof (snd_split_index'_nonneg 0 limb_widths n).
specialize_by assumption.
repeat match goal with
| |- _ => progress autorewrite with Ztestbit natsimplify in *
@@ -887,7 +919,260 @@ Section TestbitDecode.
}
Qed.
-End TestbitDecode.
+ Lemma split_index_eqn : forall i, 0 <= i ->
+ sum_firstn limb_widths (digit_index i) + bit_index i = i.
+ Proof.
+ cbv [digit_index bit_index split_index]; intros.
+ erewrite <-split_index'_correct.
+ repeat f_equal; omega.
+ Qed.
+
+ Lemma bit_index_nonneg : forall i, 0 <= i -> 0 <= bit_index i.
+ Proof.
+ cbv [bit_index split_index].
+ exact (snd_split_index'_nonneg _ _).
+ Qed.
+
+ Lemma digit_index_done_case : forall i, 0 <= i ->
+ if Z_lt_dec i (sum_firstn limb_widths (length limb_widths))
+ then (digit_index i < length limb_widths)%nat
+ else (digit_index i = length limb_widths).
+ Admitted.
+
+ Lemma digit_index_not_done : forall i, 0 <= i ->
+ i < (sum_firstn limb_widths (length limb_widths)) ->
+ (digit_index i < length limb_widths)%nat.
+ Admitted.
+
+ Lemma bit_index_pos_iff : forall i, 0 <= i ->
+ 0 < limb_widths # (digit_index i) - bit_index i <->
+ i < sum_firstn limb_widths (length limb_widths).
+
+ Admitted.
+
+ Lemma le_remaining_bits : forall i, 0 <= i < sum_firstn limb_widths (length limb_widths) ->
+ 0 <= sum_firstn limb_widths (length limb_widths)
+ - (i + (limb_widths # (digit_index i) - bit_index i)).
+ Admitted.
+
+ Lemma same_digit_bit_index_sub : forall i j, 0 <= i <= j ->
+ digit_index i = digit_index j ->
+ bit_index j - bit_index i = j - i.
+ Admitted.
+
+End SplitIndex.
+
+Section ConversionHelper.
+ Local Hint Resolve in_eq in_cons.
+
+ (* TODO : ZUtil? *)
+ (* concatenates first n bits of a with all bits of b *)
+ Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n).
+
+ Lemma concat_bits_spec : forall a b n i, 0 <= n ->
+ Z.testbit (concat_bits n a b) i =
+ if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n).
+ Proof.
+ repeat match goal with
+ | |- _ => progress cbv [concat_bits]; intros
+ | |- _ => progress autorewrite with Ztestbit
+ | |- _ => rewrite testbit_pow2_mod by omega
+ | |- _ => rewrite Z.testbit_neg_r by omega
+ | |- _ => break_if
+ | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i)
+ | |- (?a || ?b)%bool = ?a => replace b with false
+ | |- _ => reflexivity
+ end.
+ Qed.
+
+ Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits.
+
+ Ltac pair_destruct :=
+ match goal with H : ?t = (?f,?s) |- _ =>
+ replace t with (fst t, snd t) in H by (destruct t; reflexivity);
+ inversion H; subst; clear H
+ end.
+
+End ConversionHelper.
+
+Section Conversion.
+ Context {widthB : Z} (widthB_pos : 0 < widthB).
+ Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w)
+ {limb_widthsB} (limb_widthsB_uniform : forall w, In w limb_widthsB -> w = widthB).
+ Local Notation bitsIn lw := (sum_firstn lw (length lw)).
+ Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB).
+ Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)).
+ Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)).
+ Local Notation "u # i" := (nth_default 0 u i) (at level 30).
+ Local Hint Resolve in_eq in_cons.
+ Local Opaque bounded.
+ Check digit_index.
+ Check bit_index.
+
+ Function convert' inp i out
+ {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}
+ :=
+ let digitA := digit_index limb_widthsA (Z.of_nat i) in
+ let digitB := digit_index limb_widthsB (Z.of_nat i) in
+ let indexA := bit_index limb_widthsA (Z.of_nat i) in
+ let indexB := bit_index limb_widthsB (Z.of_nat i) in
+ let dist := Z.min (limb_widthsA # digitA - indexA) (limb_widthsB # digitB - indexB) in
+ let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
+ if Z_le_dec dist 0 then out
+ else convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out).
+ Proof.
+ intros.
+ assert (0 <= Z.of_nat i < bitsIn limb_widthsA) as range_i. {
+ split; try apply bit_index_pos_iff; auto using Nat2Z.is_nonneg.
+ lia.
+ }
+ repeat match goal with
+ | |- _ => progress intros
+ | |- _ => rewrite Z2Nat.id
+ | |- _ => rewrite Nat2Z.inj_add
+ | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt
+ | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l
+ | |- _ => lia
+ end.
+ apply Z.min_case_strong; intros;
+ (etransitivity;
+ [ apply le_remaining_bits with (limb_widths := limb_widthsA) (i := Z.of_nat i); auto | ]);
+ lia.
+ Defined.
+
+ Definition convert'_invariant inp i out :=
+ length out = length limb_widthsB
+ /\ bounded limb_widthsB out
+ /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false.
+
+ Lemma convert'_bounded_step : forall inp i out,
+ bounded limb_widthsB out ->
+ let digitA := digit_index limb_widthsA (Z.of_nat i) in
+ let digitB := digit_index limb_widthsB (Z.of_nat i) in
+ let indexA := bit_index limb_widthsA (Z.of_nat i) in
+ let indexB := bit_index limb_widthsB (Z.of_nat i) in
+ let dist := Z.min (limb_widthsA # digitA - indexA)
+ (limb_widthsB # digitB - indexB) in
+ let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
+ bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out).
+ Proof.
+ Admitted.
+
+ Lemma convert'_invariant_step : forall inp i out,
+ bounded limb_widthsA inp ->
+ convert'_invariant inp i out ->
+ let digitA := digit_index limb_widthsA (Z.of_nat i) in
+ let digitB := digit_index limb_widthsB (Z.of_nat i) in
+ let indexA := bit_index limb_widthsA (Z.of_nat i) in
+ let indexB := bit_index limb_widthsB (Z.of_nat i) in
+ let dist := Z.min (limb_widthsA # digitA - indexA)
+ (limb_widthsB # digitB - indexB) in
+ let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
+ 0 <= dist ->
+ convert'_invariant inp (i + Z.to_nat dist)%nat
+ (update_nth digitB (update_by_concat_bits indexB bitsA) out).
+ Proof.
+ Check testbit_decode.
+ SearchAbout update_nth.
+ Time repeat match goal with
+ | |- _ => progress intros; cbv [convert'_invariant] in *
+ | H : length _ = length limb_widthsB |- _ => rewrite H
+ | |- _ => rewrite Z.testbit_neg_r by omega
+ | |- _ => rewrite Nat2Z.inj_add
+ | |- _ => rewrite Z2Nat.id in *
+ | |- _ => rewrite update_nth_nth_default_full
+ | |- _ => rewrite nth_default_out_of_bounds by omega
+ | |- _ => rewrite testbit_decode by eauto using uniform_limb_widths_nonneg
+ | |- _ => progress cbv [update_by_concat_bits];
+ rewrite concat_bits_spec by (apply bit_index_nonneg, Nat2Z.is_nonneg)
+ | H : _ /\ _ |- _ => destruct H
+ | |- _ => break_if
+ | |- _ => split
+ | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ =>
+ specialize (H n)
+ | H : _ = Z.testbit (decodeA _) ?n |- Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n =>
+ rewrite <-H
+ | H : 0 <= ?n |- appcontext[Z.testbit (BaseSystem.decode _ _) ?n] =>
+ rewrite testbit_decode by
+ (distr_length; eauto using uniform_limb_widths_nonneg, convert'_bounded_step)
+ | |- Z.testbit (decodeB _) ?n = Z.testbit _ ?n =>
+ destruct (Z_le_dec 0 n)
+ | |- _ => solve [distr_length]
+ | |- _ => eapply convert'_bounded_step; solve [eauto]
+ | |- _ => lia
+ end.
+ match goal with
+ | d := digit_index ?lw ?j,
+ b := bit_index ?lw ?j,
+ H : digit_index ?lw ?i = ?d |- _ =>
+ let A := fresh "H" in
+ let B := fresh "H" in
+ ( (assert (0 <= i <= j) as A by omega)
+ || (assert (0 <= j <= i) as A by omega; symmetry in H));
+ assert (forall w, In w lw -> 0 <= w) as B by eauto using uniform_limb_widths_nonneg;
+ pose proof (same_digit_bit_index_sub lw B _ _ A H); subst b; lia
+ end.
+ Qed.
+
+ Lemma convert'_termination_condition : forall i, 0 <= i ->
+ let digitA := digit_index limb_widthsA i in
+ let digitB := digit_index limb_widthsB i in
+ let indexA := bit_index limb_widthsA i in
+ let indexB := bit_index limb_widthsB i in
+ let dist := Z.min (limb_widthsA # digitA - indexA)
+ (limb_widthsB # digitB - indexB) in
+ dist <= 0 -> bitsIn limb_widthsA = i.
+ Proof.
+ Admitted.
+
+ Lemma convert'_invariant_holds : forall inp i out,
+ bounded limb_widthsA inp ->
+ convert'_invariant inp i out ->
+ convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out).
+ Proof.
+ intros until 1; functional induction (convert' inp i out);
+ repeat match goal with
+ | |- _ => progress intros
+ | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) =>
+ eapply convert'_invariant_step in H; solve [eauto]
+ | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in *
+ | |- _ => rewrite Z2Nat.id
+ | H : _ /\ _ |- _ => destruct H
+ | |- _ => split
+ | |- _ => erewrite convert'_termination_condition by (eassumption || eauto using Nat2Z.is_nonneg)
+ | |- _ => assumption
+ | |- _ => lia
+ | |- _ => solve [eauto]
+ end.
+ Qed.
+
+ Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)).
+
+ Lemma convert_correct : forall us, length us = length limb_widthsA ->
+ bounded limb_widthsA us ->
+ decodeA us = decodeB (convert us).
+ Proof.
+ repeat match goal with
+ | |- _ => progress intros
+ | |- _ => progress cbv [convert convert'_invariant] in *
+ | |- _ => progress change (Z.of_nat 0) with 0 in *
+ | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l
+ | H : length _ = length limb_widthsA |- _ => rewrite H
+ | |- _ => rewrite Z.testbit_neg_r by omega
+ | |- _ => break_if
+ | |- _ => split
+ | H : _ /\ _ |- _ => destruct H
+ | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H
+ | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c)
+ | |- _ => apply testbit_decode_high
+ | |- _ => assumption
+ | |- _ => reflexivity
+ | |- _ => lia
+ | |- _ => solve [auto using sum_firstn_limb_widths_nonneg]
+ | |- _ => rewrite Z2Nat.id in *
+ end.
+ Qed.
+End Conversion.
Section carrying_helper.
Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w).