diff options
Diffstat (limited to 'src/ModularArithmetic/Conversion.v')
-rw-r--r-- | src/ModularArithmetic/Conversion.v | 292 |
1 files changed, 292 insertions, 0 deletions
diff --git a/src/ModularArithmetic/Conversion.v b/src/ModularArithmetic/Conversion.v new file mode 100644 index 000000000..8ad19c4c6 --- /dev/null +++ b/src/ModularArithmetic/Conversion.v @@ -0,0 +1,292 @@ +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Lists.List. +Require Import Coq.funind.Recdef. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.Tactics. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs Crypto.BaseSystemProofs. +Require Import Crypto.Util.Notations. +Require Export Crypto.Util.FixCoqMistakes. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section ConversionHelper. + Local Hint Resolve in_eq in_cons. + + (* concatenates first n bits of a with all bits of b *) + Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n). + + Lemma concat_bits_spec : forall a b n i, 0 <= n -> + Z.testbit (concat_bits n a b) i = + if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n). + Proof. + repeat match goal with + | |- _ => progress cbv [concat_bits]; intros + | |- _ => progress autorewrite with Ztestbit + | |- _ => rewrite Z.testbit_pow2_mod by omega + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => break_if + | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i) + | |- (?a || ?b)%bool = ?a => replace b with false + | |- _ => reflexivity + end. + Qed. + + Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits. + +End ConversionHelper. + +Section Conversion. + Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) + {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w). + Local Notation bitsIn lw := (sum_firstn lw (length lw)). + Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB). + Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)). + Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)). + Local Notation "u # i" := (nth_default 0 u i). + Local Hint Resolve in_eq in_cons nth_default_limb_widths_nonneg sum_firstn_limb_widths_nonneg Nat2Z.is_nonneg. + Local Opaque bounded. + + Function convert' inp i out + {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}:= + if Z_le_dec (bitsIn limb_widthsA) (Z.of_nat i) + then out + else + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + generalize limb_widthsA_nonneg; intros _. (* don't drop this from the proof in 8.4 *) + generalize limb_widthsB_nonneg; intros _. (* don't drop this from the proof in 8.4 *) + repeat match goal with + | |- _ => progress intros + | |- appcontext [bit_index (Z.of_nat ?i)] => + unique pose proof (Nat2Z.is_nonneg i) + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique pose proof (bit_index_not_done lw i) + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique assert (0 <= i < bitsIn lw -> i + ((lw # digit_index lw i) - bit_index lw i) <= bitsIn lw) by auto using rem_bits_in_digit_le_rem_bits + | |- _ => rewrite Z2Nat.id + | |- _ => rewrite Nat2Z.inj_add + | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt + | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l + | |- appcontext [Z.min ?a ?b] => unique assert (0 < Z.min a b) by (specialize_by lia; lia) + | |- _ => lia + end. + Defined. + + Definition convert'_invariant inp i out := + length out = length limb_widthsB + /\ bounded limb_widthsB out + /\ Z.of_nat i <= bitsIn limb_widthsA + /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false. + + Ltac subst_lia := subst_let; subst; lia. + + Lemma convert'_bounded_step : forall inp i out, + bounded limb_widthsB out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress autorewrite with Ztestbit + | |- _ => rewrite update_nth_nth_default_full + | |- _ => rewrite Z.testbit_pow2_mod + | |- _ => break_if + | |- _ => progress cbv [update_by_concat_bits]; + rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) + | |- bounded _ _ => apply pow2_mod_bounded_iff + | |- Z.pow2_mod _ _ = _ => apply Z.bits_inj' + | |- false = Z.testbit _ _ => symmetry + | x := _ |- Z.testbit ?x _ = _ => subst x + | |- Z.testbit _ _ = false => eapply testbit_bounded_high; eauto; lia + | |- _ => solve [auto] + | |- _ => subst_lia + end. + Qed. + + Lemma convert'_index_step : forall inp i out, + bounded limb_widthsB out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + Z.of_nat i < bitsIn limb_widthsA -> + Z.of_nat i + dist <= bitsIn limb_widthsA. + Proof. + pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). + pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). + repeat match goal with + | |- _ => progress intros + | H : forall x : Z, In x ?lw -> x = ?y, H0 : 0 < ?y |- _ => + unique pose proof (uniform_limb_widths_nonneg H0 lw H) + | |- _ => progress specialize_by assumption + | H : _ /\ _ |- _ => destruct H + | |- _ => break_if + | |- _ => split + | a := digit_index _ ?i, H : forall x, 0 <= x < bitsIn _ -> _ |- _ => specialize (H i); forward H + | |- _ => subst_lia + | |- _ => apply bit_index_pos_iff; auto + | |- _ => apply Nat2Z.is_nonneg + end. + Qed. + + Lemma convert'_invariant_step : forall inp i out, + length inp = length limb_widthsA -> + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + Z.of_nat i < bitsIn limb_widthsA -> + convert'_invariant inp (i + Z.to_nat dist)%nat + (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + Time + repeat match goal with + | |- _ => progress intros; cbv [convert'_invariant] in * + | |- _ => progress autorewrite with Ztestbit + | H : forall x, In x ?lw -> 0 <= x |- appcontext[digit_index ?lw ?i] => + unique pose proof (digit_index_lt_length lw H i) + | |- _ => rewrite Nat2Z.inj_add + | |- _ => rewrite Z2Nat.id in * + | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ => + specialize (H n) + | H0 : ?n < ?i, H1 : ?n < ?i + ?d, + H : Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n |- _ = Z.testbit (decodeA _) ?n => + rewrite <-H + | H : _ /\ _ |- _ => destruct H + | |- _ => break_if + | |- _ => split + | |- _ => rewrite testbit_decode_full + | |- _ => rewrite update_nth_nth_default_full + | |- _ => rewrite nth_default_out_of_bounds by omega + | H : ~ (0 <= ?n ) |- appcontext[Z.testbit ?a ?n] => rewrite (Z.testbit_neg_r a n) by omega + | |- _ => progress cbv [update_by_concat_bits]; + rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) + | |- _ => solve [distr_length] + | |- _ => eapply convert'_bounded_step; solve [auto] + | |- _ => etransitivity; [ | eapply convert'_index_step]; subst_let; eauto; lia + | H : digit_index limb_widthsB ?i = digit_index limb_widthsB ?j |- _ => + unique assert (digit_index limb_widthsA i = digit_index limb_widthsA j) by + (symmetry; apply same_digit; assumption || lia); + pose proof (same_digit_bit_index_sub limb_widthsA j i) as X; + forward X; [ | lia | lia | lia ] + | d := digit_index ?lw ?j, + H : digit_index ?lw ?i <> ?d |- _ => + exfalso; apply H; symmetry; apply same_digit; assumption || subst_lia + | d := digit_index ?lw ?j, + H : digit_index ?lw ?i = ?d |- _ => + let X := fresh "H" in + ((pose proof (same_digit_bit_index_sub lw i j) as X; + forward X; [ subst_let | subst_lia | lia | lia ]) || + (pose proof (same_digit_bit_index_sub lw j i) as X; + forward X; [ subst_let | subst_lia | lia | lia ])) + | |- Z.testbit _ (bit_index ?lw _ - bit_index ?lw ?i + _) = false => + apply (@testbit_bounded_high limb_widthsA); auto; + rewrite (same_digit_bit_index_sub) by subst_lia; + rewrite <-(split_index_eqn limb_widthsA i) at 2 by lia + | |- ?lw # ?b <= ?a - ((sum_firstn ?lw ?b) + ?c) + ?c => replace (a - (sum_firstn lw b + c) + c) with (a - sum_firstn lw b) by ring; apply Z.le_add_le_sub_r + | |- (?lw # ?n) + sum_firstn ?lw ?n <= _ => + rewrite <-sum_firstn_succ_default; transitivity (bitsIn lw); [ | lia]; + apply sum_firstn_prefix_le; auto; lia + | |- _ => lia + | |- _ => assumption + | |- _ => solve [auto] + | |- _ => rewrite <-testbit_decode by (assumption || lia || auto); assumption + | |- _ => repeat (f_equal; try congruence); lia + end. + Qed. + + Lemma convert'_invariant_holds : forall inp i out, + length inp = length limb_widthsA -> + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out). + Proof. + intros until 2; functional induction (convert' inp i out); + repeat match goal with + | |- _ => progress intros + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique pose proof (bit_index_not_done lw i) + | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) => + eapply convert'_invariant_step in H; solve [auto; specialize_by lia; lia] + | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in * + | H : _ /\ _ |- _ => destruct H + | |- _ => rewrite Z2Nat.id + | |- _ => split + | |- _ => assumption + | |- _ => lia + | |- _ => solve [eauto] + | |- _ => replace (bitsIn limb_widthsA) with (Z.of_nat i) by (apply Z.le_antisymm; assumption) + end. + Qed. + + Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)). + + Lemma convert_correct : forall us, length us = length limb_widthsA -> + bounded limb_widthsA us -> + decodeA us = decodeB (convert us). + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress cbv [convert convert'_invariant] in * + | |- _ => progress change (Z.of_nat 0) with 0 in * + | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l + | H : length _ = length limb_widthsA |- _ => rewrite H + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => rewrite nth_default_zeros + | |- _ => break_if + | |- _ => split + | H : _ /\ _ |- _ => destruct H + | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H + | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c) + | |- _ => apply testbit_decode_high + | |- _ => assumption + | |- _ => reflexivity + | |- _ => lia + | |- _ => solve [auto using sum_firstn_limb_widths_nonneg] + | |- _ => solve [apply nth_default_preserves_properties; auto; lia] + | |- _ => rewrite Z2Nat.id in * + | |- bounded _ _ => apply bounded_iff + | |- 0 < 2 ^ _ => zero_bounds + end. + Qed. + + (* This is part of convert'_invariant, but proving it separately strips preconditions *) + Lemma length_convert' : forall inp i out, + length (convert' inp i out) = length out. + Proof. + intros; functional induction (convert' inp i out); distr_length. + Qed. + + Lemma length_convert : forall us, length (convert us) = length limb_widthsB. + Proof. + cbv [convert]; intros. + rewrite length_convert', length_zeros. + reflexivity. + Qed. +End Conversion.
\ No newline at end of file |