diff options
author | 2016-08-12 19:00:54 -0400 | |
---|---|---|
committer | 2016-08-12 19:00:54 -0400 | |
commit | 079b0f4b019d9bd6773c9f6d07256aa92fe01146 (patch) | |
tree | d8c603d6b033548f87baa736a4b5b61e188379f3 /src/ModularArithmetic | |
parent | 8c106350250c61b06afeb64d580212abd6c63ab2 (diff) |
New and improved conversion proofs (final conditions proven, invariant step unproven)
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/Pow2BaseProofs.v | 303 |
1 files changed, 294 insertions, 9 deletions
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index c863c2e8f..b05b2f78d 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -35,16 +35,17 @@ Hint Rewrite Bool.andb_true_r Bool.andb_false_r Bool.orb_true_r Bool.orb_false_r Bool.andb_true_l Bool.andb_false_l Bool.orb_true_l Bool.orb_false_l : Ztestbit. (* TODO : move *) -Lemma testbit_pow2_mod : forall a n i, 0 <= i -> 0 <= n -> +Lemma testbit_pow2_mod : forall a n i, 0 <= n -> Z.testbit (Z.pow2_mod a n) i = if Z_lt_dec i n then Z.testbit a i else false. Proof. -cbv [Z.pow2_mod]; intros. -repeat match goal with + cbv [Z.pow2_mod]; intros; destruct (Z_le_dec 0 i); + repeat match goal with + | |- _ => rewrite Z.testbit_neg_r by omega | |- _ => break_if | |- _ => omega | |- _ => reflexivity | |- _ => progress autorewrite with Ztestbit - end. + end. Qed. Hint Rewrite testbit_pow2_mod using omega : Ztestbit. @@ -783,10 +784,16 @@ Section UniformBase. End UniformBase. -Section TestbitDecode. +Section SplitIndex. + (* This section defines [split_index], which for a list of bounded digits + splits a bit index in the decoded value into a digit index and a bit + index within the digit. Examples: + limb_widths [4;4] : split_index 6 = (1,2) + limb_widths [26,25,26] : split_index 30 = (1,4) + limb_widths [26,25,26] : split_index 51 = (2,0) + *) Local Notation "u # i" := (nth_default 0 u i) (at level 30). - (* splits a bit index into a digit index and an index within the digit*) Function split_index' i index lw := match lw with | nil => (index, i) @@ -803,7 +810,32 @@ Section TestbitDecode. end. Qed. - Lemma snd_split_index'_nonneg : forall i index lw, (0 <= i) -> + Lemma split_index'_done_case : forall i index lw, 0 <= i -> + (forall x, In x lw -> 0 <= x) -> + if Z_lt_dec i (sum_firstn lw (length lw)) + then (fst (split_index' i index lw) - index < length lw)%nat + else (fst (split_index' i index lw) - index = length lw)%nat. + Proof. + intros; functional induction (split_index' i index lw); + repeat match goal with + | |- _ => break_if + | |- _ => rewrite sum_firstn_nil in * + | |- _ => rewrite sum_firstn_succ_cons in * + | |- _ => progress distr_length + | |- _ => progress (simpl fst; simpl snd) + | H : appcontext [split_index' ?a ?b ?c] |- _ => + unique pose proof (split_index'_ge_index a b c) + | H : appcontext [sum_firstn ?l ?i] |- _ => + let H0 := fresh "H" in + assert (forall x, In x l -> 0 <= x) by auto using in_cons; + unique pose proof (sum_firstn_limb_widths_nonneg H0 i) + | |- _ => progress specialize_by assumption + | |- _ => progress specialize_by omega + | |- _ => omega + end. + Qed. + + Lemma snd_split_index'_nonneg : forall index lw i, (0 <= i) -> (0 <= snd (split_index' i index lw)). Proof. intros; functional induction (split_index' i index lw); @@ -865,7 +897,7 @@ Section TestbitDecode. Proof. cbv [digit_index bit_index split_index]; intros. pose proof (split_index'_correct n 0 limb_widths). - pose proof (snd_split_index'_nonneg n 0 limb_widths). + pose proof (snd_split_index'_nonneg 0 limb_widths n). specialize_by assumption. repeat match goal with | |- _ => progress autorewrite with Ztestbit natsimplify in * @@ -887,7 +919,260 @@ Section TestbitDecode. } Qed. -End TestbitDecode. + Lemma split_index_eqn : forall i, 0 <= i -> + sum_firstn limb_widths (digit_index i) + bit_index i = i. + Proof. + cbv [digit_index bit_index split_index]; intros. + erewrite <-split_index'_correct. + repeat f_equal; omega. + Qed. + + Lemma bit_index_nonneg : forall i, 0 <= i -> 0 <= bit_index i. + Proof. + cbv [bit_index split_index]. + exact (snd_split_index'_nonneg _ _). + Qed. + + Lemma digit_index_done_case : forall i, 0 <= i -> + if Z_lt_dec i (sum_firstn limb_widths (length limb_widths)) + then (digit_index i < length limb_widths)%nat + else (digit_index i = length limb_widths). + Admitted. + + Lemma digit_index_not_done : forall i, 0 <= i -> + i < (sum_firstn limb_widths (length limb_widths)) -> + (digit_index i < length limb_widths)%nat. + Admitted. + + Lemma bit_index_pos_iff : forall i, 0 <= i -> + 0 < limb_widths # (digit_index i) - bit_index i <-> + i < sum_firstn limb_widths (length limb_widths). + + Admitted. + + Lemma le_remaining_bits : forall i, 0 <= i < sum_firstn limb_widths (length limb_widths) -> + 0 <= sum_firstn limb_widths (length limb_widths) + - (i + (limb_widths # (digit_index i) - bit_index i)). + Admitted. + + Lemma same_digit_bit_index_sub : forall i j, 0 <= i <= j -> + digit_index i = digit_index j -> + bit_index j - bit_index i = j - i. + Admitted. + +End SplitIndex. + +Section ConversionHelper. + Local Hint Resolve in_eq in_cons. + + (* TODO : ZUtil? *) + (* concatenates first n bits of a with all bits of b *) + Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n). + + Lemma concat_bits_spec : forall a b n i, 0 <= n -> + Z.testbit (concat_bits n a b) i = + if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n). + Proof. + repeat match goal with + | |- _ => progress cbv [concat_bits]; intros + | |- _ => progress autorewrite with Ztestbit + | |- _ => rewrite testbit_pow2_mod by omega + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => break_if + | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i) + | |- (?a || ?b)%bool = ?a => replace b with false + | |- _ => reflexivity + end. + Qed. + + Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits. + + Ltac pair_destruct := + match goal with H : ?t = (?f,?s) |- _ => + replace t with (fst t, snd t) in H by (destruct t; reflexivity); + inversion H; subst; clear H + end. + +End ConversionHelper. + +Section Conversion. + Context {widthB : Z} (widthB_pos : 0 < widthB). + Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) + {limb_widthsB} (limb_widthsB_uniform : forall w, In w limb_widthsB -> w = widthB). + Local Notation bitsIn lw := (sum_firstn lw (length lw)). + Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB). + Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)). + Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)). + Local Notation "u # i" := (nth_default 0 u i) (at level 30). + Local Hint Resolve in_eq in_cons. + Local Opaque bounded. + Check digit_index. + Check bit_index. + + Function convert' inp i out + {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i} + := + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min (limb_widthsA # digitA - indexA) (limb_widthsB # digitB - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + if Z_le_dec dist 0 then out + else convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + intros. + assert (0 <= Z.of_nat i < bitsIn limb_widthsA) as range_i. { + split; try apply bit_index_pos_iff; auto using Nat2Z.is_nonneg. + lia. + } + repeat match goal with + | |- _ => progress intros + | |- _ => rewrite Z2Nat.id + | |- _ => rewrite Nat2Z.inj_add + | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt + | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l + | |- _ => lia + end. + apply Z.min_case_strong; intros; + (etransitivity; + [ apply le_remaining_bits with (limb_widths := limb_widthsA) (i := Z.of_nat i); auto | ]); + lia. + Defined. + + Definition convert'_invariant inp i out := + length out = length limb_widthsB + /\ bounded limb_widthsB out + /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false. + + Lemma convert'_bounded_step : forall inp i out, + bounded limb_widthsB out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min (limb_widthsA # digitA - indexA) + (limb_widthsB # digitB - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + Admitted. + + Lemma convert'_invariant_step : forall inp i out, + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min (limb_widthsA # digitA - indexA) + (limb_widthsB # digitB - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 <= dist -> + convert'_invariant inp (i + Z.to_nat dist)%nat + (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + Check testbit_decode. + SearchAbout update_nth. + Time repeat match goal with + | |- _ => progress intros; cbv [convert'_invariant] in * + | H : length _ = length limb_widthsB |- _ => rewrite H + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => rewrite Nat2Z.inj_add + | |- _ => rewrite Z2Nat.id in * + | |- _ => rewrite update_nth_nth_default_full + | |- _ => rewrite nth_default_out_of_bounds by omega + | |- _ => rewrite testbit_decode by eauto using uniform_limb_widths_nonneg + | |- _ => progress cbv [update_by_concat_bits]; + rewrite concat_bits_spec by (apply bit_index_nonneg, Nat2Z.is_nonneg) + | H : _ /\ _ |- _ => destruct H + | |- _ => break_if + | |- _ => split + | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ => + specialize (H n) + | H : _ = Z.testbit (decodeA _) ?n |- Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n => + rewrite <-H + | H : 0 <= ?n |- appcontext[Z.testbit (BaseSystem.decode _ _) ?n] => + rewrite testbit_decode by + (distr_length; eauto using uniform_limb_widths_nonneg, convert'_bounded_step) + | |- Z.testbit (decodeB _) ?n = Z.testbit _ ?n => + destruct (Z_le_dec 0 n) + | |- _ => solve [distr_length] + | |- _ => eapply convert'_bounded_step; solve [eauto] + | |- _ => lia + end. + match goal with + | d := digit_index ?lw ?j, + b := bit_index ?lw ?j, + H : digit_index ?lw ?i = ?d |- _ => + let A := fresh "H" in + let B := fresh "H" in + ( (assert (0 <= i <= j) as A by omega) + || (assert (0 <= j <= i) as A by omega; symmetry in H)); + assert (forall w, In w lw -> 0 <= w) as B by eauto using uniform_limb_widths_nonneg; + pose proof (same_digit_bit_index_sub lw B _ _ A H); subst b; lia + end. + Qed. + + Lemma convert'_termination_condition : forall i, 0 <= i -> + let digitA := digit_index limb_widthsA i in + let digitB := digit_index limb_widthsB i in + let indexA := bit_index limb_widthsA i in + let indexB := bit_index limb_widthsB i in + let dist := Z.min (limb_widthsA # digitA - indexA) + (limb_widthsB # digitB - indexB) in + dist <= 0 -> bitsIn limb_widthsA = i. + Proof. + Admitted. + + Lemma convert'_invariant_holds : forall inp i out, + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out). + Proof. + intros until 1; functional induction (convert' inp i out); + repeat match goal with + | |- _ => progress intros + | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) => + eapply convert'_invariant_step in H; solve [eauto] + | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in * + | |- _ => rewrite Z2Nat.id + | H : _ /\ _ |- _ => destruct H + | |- _ => split + | |- _ => erewrite convert'_termination_condition by (eassumption || eauto using Nat2Z.is_nonneg) + | |- _ => assumption + | |- _ => lia + | |- _ => solve [eauto] + end. + Qed. + + Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)). + + Lemma convert_correct : forall us, length us = length limb_widthsA -> + bounded limb_widthsA us -> + decodeA us = decodeB (convert us). + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress cbv [convert convert'_invariant] in * + | |- _ => progress change (Z.of_nat 0) with 0 in * + | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l + | H : length _ = length limb_widthsA |- _ => rewrite H + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => break_if + | |- _ => split + | H : _ /\ _ |- _ => destruct H + | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H + | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c) + | |- _ => apply testbit_decode_high + | |- _ => assumption + | |- _ => reflexivity + | |- _ => lia + | |- _ => solve [auto using sum_firstn_limb_widths_nonneg] + | |- _ => rewrite Z2Nat.id in * + end. + Qed. +End Conversion. Section carrying_helper. Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). |