From 6d27149299a6aaaca3d82480c1b0e90a98cc18a7 Mon Sep 17 00:00:00 2001 From: jadep Date: Thu, 6 Oct 2016 11:59:06 -0400 Subject: Moved conversion logic out of Pow2BaseProofs into its own file --- src/ModularArithmetic/Conversion.v | 292 +++++++++++++++++++++ src/ModularArithmetic/ModularBaseSystemList.v | 13 +- .../ModularBaseSystemListProofs.v | 5 +- src/ModularArithmetic/ModularBaseSystemOpt.v | 1 + src/ModularArithmetic/Pow2BaseProofs.v | 279 -------------------- 5 files changed, 303 insertions(+), 287 deletions(-) create mode 100644 src/ModularArithmetic/Conversion.v (limited to 'src/ModularArithmetic') diff --git a/src/ModularArithmetic/Conversion.v b/src/ModularArithmetic/Conversion.v new file mode 100644 index 000000000..8ad19c4c6 --- /dev/null +++ b/src/ModularArithmetic/Conversion.v @@ -0,0 +1,292 @@ +Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Lists.List. +Require Import Coq.funind.Recdef. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.Tactics. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs Crypto.BaseSystemProofs. +Require Import Crypto.Util.Notations. +Require Export Crypto.Util.FixCoqMistakes. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section ConversionHelper. + Local Hint Resolve in_eq in_cons. + + (* concatenates first n bits of a with all bits of b *) + Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n). + + Lemma concat_bits_spec : forall a b n i, 0 <= n -> + Z.testbit (concat_bits n a b) i = + if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n). + Proof. + repeat match goal with + | |- _ => progress cbv [concat_bits]; intros + | |- _ => progress autorewrite with Ztestbit + | |- _ => rewrite Z.testbit_pow2_mod by omega + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => break_if + | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i) + | |- (?a || ?b)%bool = ?a => replace b with false + | |- _ => reflexivity + end. + Qed. + + Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits. + +End ConversionHelper. + +Section Conversion. + Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) + {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w). + Local Notation bitsIn lw := (sum_firstn lw (length lw)). + Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB). + Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)). + Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)). + Local Notation "u # i" := (nth_default 0 u i). + Local Hint Resolve in_eq in_cons nth_default_limb_widths_nonneg sum_firstn_limb_widths_nonneg Nat2Z.is_nonneg. + Local Opaque bounded. + + Function convert' inp i out + {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}:= + if Z_le_dec (bitsIn limb_widthsA) (Z.of_nat i) + then out + else + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + generalize limb_widthsA_nonneg; intros _. (* don't drop this from the proof in 8.4 *) + generalize limb_widthsB_nonneg; intros _. (* don't drop this from the proof in 8.4 *) + repeat match goal with + | |- _ => progress intros + | |- appcontext [bit_index (Z.of_nat ?i)] => + unique pose proof (Nat2Z.is_nonneg i) + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique pose proof (bit_index_not_done lw i) + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique assert (0 <= i < bitsIn lw -> i + ((lw # digit_index lw i) - bit_index lw i) <= bitsIn lw) by auto using rem_bits_in_digit_le_rem_bits + | |- _ => rewrite Z2Nat.id + | |- _ => rewrite Nat2Z.inj_add + | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt + | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l + | |- appcontext [Z.min ?a ?b] => unique assert (0 < Z.min a b) by (specialize_by lia; lia) + | |- _ => lia + end. + Defined. + + Definition convert'_invariant inp i out := + length out = length limb_widthsB + /\ bounded limb_widthsB out + /\ Z.of_nat i <= bitsIn limb_widthsA + /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false. + + Ltac subst_lia := subst_let; subst; lia. + + Lemma convert'_bounded_step : forall inp i out, + bounded limb_widthsB out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress autorewrite with Ztestbit + | |- _ => rewrite update_nth_nth_default_full + | |- _ => rewrite Z.testbit_pow2_mod + | |- _ => break_if + | |- _ => progress cbv [update_by_concat_bits]; + rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) + | |- bounded _ _ => apply pow2_mod_bounded_iff + | |- Z.pow2_mod _ _ = _ => apply Z.bits_inj' + | |- false = Z.testbit _ _ => symmetry + | x := _ |- Z.testbit ?x _ = _ => subst x + | |- Z.testbit _ _ = false => eapply testbit_bounded_high; eauto; lia + | |- _ => solve [auto] + | |- _ => subst_lia + end. + Qed. + + Lemma convert'_index_step : forall inp i out, + bounded limb_widthsB out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + Z.of_nat i < bitsIn limb_widthsA -> + Z.of_nat i + dist <= bitsIn limb_widthsA. + Proof. + pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). + pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). + repeat match goal with + | |- _ => progress intros + | H : forall x : Z, In x ?lw -> x = ?y, H0 : 0 < ?y |- _ => + unique pose proof (uniform_limb_widths_nonneg H0 lw H) + | |- _ => progress specialize_by assumption + | H : _ /\ _ |- _ => destruct H + | |- _ => break_if + | |- _ => split + | a := digit_index _ ?i, H : forall x, 0 <= x < bitsIn _ -> _ |- _ => specialize (H i); forward H + | |- _ => subst_lia + | |- _ => apply bit_index_pos_iff; auto + | |- _ => apply Nat2Z.is_nonneg + end. + Qed. + + Lemma convert'_invariant_step : forall inp i out, + length inp = length limb_widthsA -> + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + let digitA := digit_index limb_widthsA (Z.of_nat i) in + let digitB := digit_index limb_widthsB (Z.of_nat i) in + let indexA := bit_index limb_widthsA (Z.of_nat i) in + let indexB := bit_index limb_widthsB (Z.of_nat i) in + let dist := Z.min ((limb_widthsA # digitA) - indexA) + ((limb_widthsB # digitB) - indexB) in + let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in + 0 < dist -> + Z.of_nat i < bitsIn limb_widthsA -> + convert'_invariant inp (i + Z.to_nat dist)%nat + (update_nth digitB (update_by_concat_bits indexB bitsA) out). + Proof. + Time + repeat match goal with + | |- _ => progress intros; cbv [convert'_invariant] in * + | |- _ => progress autorewrite with Ztestbit + | H : forall x, In x ?lw -> 0 <= x |- appcontext[digit_index ?lw ?i] => + unique pose proof (digit_index_lt_length lw H i) + | |- _ => rewrite Nat2Z.inj_add + | |- _ => rewrite Z2Nat.id in * + | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ => + specialize (H n) + | H0 : ?n < ?i, H1 : ?n < ?i + ?d, + H : Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n |- _ = Z.testbit (decodeA _) ?n => + rewrite <-H + | H : _ /\ _ |- _ => destruct H + | |- _ => break_if + | |- _ => split + | |- _ => rewrite testbit_decode_full + | |- _ => rewrite update_nth_nth_default_full + | |- _ => rewrite nth_default_out_of_bounds by omega + | H : ~ (0 <= ?n ) |- appcontext[Z.testbit ?a ?n] => rewrite (Z.testbit_neg_r a n) by omega + | |- _ => progress cbv [update_by_concat_bits]; + rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) + | |- _ => solve [distr_length] + | |- _ => eapply convert'_bounded_step; solve [auto] + | |- _ => etransitivity; [ | eapply convert'_index_step]; subst_let; eauto; lia + | H : digit_index limb_widthsB ?i = digit_index limb_widthsB ?j |- _ => + unique assert (digit_index limb_widthsA i = digit_index limb_widthsA j) by + (symmetry; apply same_digit; assumption || lia); + pose proof (same_digit_bit_index_sub limb_widthsA j i) as X; + forward X; [ | lia | lia | lia ] + | d := digit_index ?lw ?j, + H : digit_index ?lw ?i <> ?d |- _ => + exfalso; apply H; symmetry; apply same_digit; assumption || subst_lia + | d := digit_index ?lw ?j, + H : digit_index ?lw ?i = ?d |- _ => + let X := fresh "H" in + ((pose proof (same_digit_bit_index_sub lw i j) as X; + forward X; [ subst_let | subst_lia | lia | lia ]) || + (pose proof (same_digit_bit_index_sub lw j i) as X; + forward X; [ subst_let | subst_lia | lia | lia ])) + | |- Z.testbit _ (bit_index ?lw _ - bit_index ?lw ?i + _) = false => + apply (@testbit_bounded_high limb_widthsA); auto; + rewrite (same_digit_bit_index_sub) by subst_lia; + rewrite <-(split_index_eqn limb_widthsA i) at 2 by lia + | |- ?lw # ?b <= ?a - ((sum_firstn ?lw ?b) + ?c) + ?c => replace (a - (sum_firstn lw b + c) + c) with (a - sum_firstn lw b) by ring; apply Z.le_add_le_sub_r + | |- (?lw # ?n) + sum_firstn ?lw ?n <= _ => + rewrite <-sum_firstn_succ_default; transitivity (bitsIn lw); [ | lia]; + apply sum_firstn_prefix_le; auto; lia + | |- _ => lia + | |- _ => assumption + | |- _ => solve [auto] + | |- _ => rewrite <-testbit_decode by (assumption || lia || auto); assumption + | |- _ => repeat (f_equal; try congruence); lia + end. + Qed. + + Lemma convert'_invariant_holds : forall inp i out, + length inp = length limb_widthsA -> + bounded limb_widthsA inp -> + convert'_invariant inp i out -> + convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out). + Proof. + intros until 2; functional induction (convert' inp i out); + repeat match goal with + | |- _ => progress intros + | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => + unique pose proof (bit_index_not_done lw i) + | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) => + eapply convert'_invariant_step in H; solve [auto; specialize_by lia; lia] + | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in * + | H : _ /\ _ |- _ => destruct H + | |- _ => rewrite Z2Nat.id + | |- _ => split + | |- _ => assumption + | |- _ => lia + | |- _ => solve [eauto] + | |- _ => replace (bitsIn limb_widthsA) with (Z.of_nat i) by (apply Z.le_antisymm; assumption) + end. + Qed. + + Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)). + + Lemma convert_correct : forall us, length us = length limb_widthsA -> + bounded limb_widthsA us -> + decodeA us = decodeB (convert us). + Proof. + repeat match goal with + | |- _ => progress intros + | |- _ => progress cbv [convert convert'_invariant] in * + | |- _ => progress change (Z.of_nat 0) with 0 in * + | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l + | H : length _ = length limb_widthsA |- _ => rewrite H + | |- _ => rewrite Z.testbit_neg_r by omega + | |- _ => rewrite nth_default_zeros + | |- _ => break_if + | |- _ => split + | H : _ /\ _ |- _ => destruct H + | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H + | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c) + | |- _ => apply testbit_decode_high + | |- _ => assumption + | |- _ => reflexivity + | |- _ => lia + | |- _ => solve [auto using sum_firstn_limb_widths_nonneg] + | |- _ => solve [apply nth_default_preserves_properties; auto; lia] + | |- _ => rewrite Z2Nat.id in * + | |- bounded _ _ => apply bounded_iff + | |- 0 < 2 ^ _ => zero_bounds + end. + Qed. + + (* This is part of convert'_invariant, but proving it separately strips preconditions *) + Lemma length_convert' : forall inp i out, + length (convert' inp i out) = length out. + Proof. + intros; functional induction (convert' inp i out); distr_length. + Qed. + + Lemma length_convert : forall us, length (convert us) = length limb_widthsB. + Proof. + cbv [convert]; intros. + rewrite length_convert', length_zeros. + reflexivity. + Qed. +End Conversion. \ No newline at end of file diff --git a/src/ModularArithmetic/ModularBaseSystemList.v b/src/ModularArithmetic/ModularBaseSystemList.v index a472c3534..6d0848151 100644 --- a/src/ModularArithmetic/ModularBaseSystemList.v +++ b/src/ModularArithmetic/ModularBaseSystemList.v @@ -10,6 +10,7 @@ Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.Util.Notations. Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Conversion. Local Open Scope Z_scope. Section Defs. @@ -77,12 +78,12 @@ Section Defs. (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)). - Definition pack := @Pow2BaseProofs.convert limb_widths limb_widths_nonneg - target_widths target_widths_nonneg - (Z.eq_le_incl _ _ bits_eq). + Definition pack := @convert limb_widths limb_widths_nonneg + target_widths target_widths_nonneg + (Z.eq_le_incl _ _ bits_eq). - Definition unpack := @Pow2BaseProofs.convert target_widths target_widths_nonneg - limb_widths limb_widths_nonneg - (Z.eq_le_incl _ _ (Z.eq_sym bits_eq)). + Definition unpack := @convert target_widths target_widths_nonneg + limb_widths limb_widths_nonneg + (Z.eq_le_incl _ _ (Z.eq_sym bits_eq)). End Defs. diff --git a/src/ModularArithmetic/ModularBaseSystemListProofs.v b/src/ModularArithmetic/ModularBaseSystemListProofs.v index 11b28769b..93b39e89a 100644 --- a/src/ModularArithmetic/ModularBaseSystemListProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemListProofs.v @@ -4,6 +4,7 @@ Require Import Coq.Lists.List. Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. +Require Import Crypto.ModularArithmetic.Conversion. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. @@ -133,7 +134,7 @@ Section LengthProofs. length (pack target_widths_nonneg pf us) = length target_widths. Proof. cbv [pack]; intros. - apply Pow2BaseProofs.length_convert. + apply length_convert. Qed. Lemma length_unpack : forall {target_widths} @@ -142,7 +143,7 @@ Section LengthProofs. length (unpack target_widths_nonneg pf us) = length limb_widths. Proof. cbv [pack]; intros. - apply Pow2BaseProofs.length_convert. + apply length_convert. Qed. End LengthProofs. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index dba1afd29..3eef0901e 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -2,6 +2,7 @@ Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. +Require Import Crypto.ModularArithmetic.Conversion. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.BaseSystem. diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index ebf14a00e..c28ee2bc7 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -1209,285 +1209,6 @@ Section SplitIndex. End SplitIndex. -Section ConversionHelper. - Local Hint Resolve in_eq in_cons. - - (* concatenates first n bits of a with all bits of b *) - Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n). - - Lemma concat_bits_spec : forall a b n i, 0 <= n -> - Z.testbit (concat_bits n a b) i = - if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n). - Proof. - repeat match goal with - | |- _ => progress cbv [concat_bits]; intros - | |- _ => progress autorewrite with Ztestbit - | |- _ => rewrite Z.testbit_pow2_mod by omega - | |- _ => rewrite Z.testbit_neg_r by omega - | |- _ => break_if - | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i) - | |- (?a || ?b)%bool = ?a => replace b with false - | |- _ => reflexivity - end. - Qed. - - Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits. - -End ConversionHelper. - -Section Conversion. - Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) - {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w). - Local Notation bitsIn lw := (sum_firstn lw (length lw)). - Context (bits_fit : bitsIn limb_widthsA <= bitsIn limb_widthsB). - Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)). - Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)). - Local Notation "u # i" := (nth_default 0 u i). - Local Hint Resolve in_eq in_cons nth_default_limb_widths_nonneg sum_firstn_limb_widths_nonneg Nat2Z.is_nonneg. - Local Opaque bounded. - - Function convert' inp i out - {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}:= - if Z_le_dec (bitsIn limb_widthsA) (Z.of_nat i) - then out - else - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof. - generalize limb_widthsA_nonneg; intros _. (* don't drop this from the proof in 8.4 *) - generalize limb_widthsB_nonneg; intros _. (* don't drop this from the proof in 8.4 *) - repeat match goal with - | |- _ => progress intros - | |- appcontext [bit_index (Z.of_nat ?i)] => - unique pose proof (Nat2Z.is_nonneg i) - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique pose proof (bit_index_not_done lw i) - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique assert (0 <= i < bitsIn lw -> i + ((lw # digit_index lw i) - bit_index lw i) <= bitsIn lw) by auto using rem_bits_in_digit_le_rem_bits - | |- _ => rewrite Z2Nat.id - | |- _ => rewrite Nat2Z.inj_add - | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt - | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l - | |- appcontext [Z.min ?a ?b] => unique assert (0 < Z.min a b) by (specialize_by lia; lia) - | |- _ => lia - end. - Defined. - - Definition convert'_invariant inp i out := - length out = length limb_widthsB - /\ bounded limb_widthsB out - /\ Z.of_nat i <= bitsIn limb_widthsA - /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false. - - Ltac subst_lia := subst_let; subst; lia. - - Lemma convert'_bounded_step : forall inp i out, - bounded limb_widthsB out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof. - repeat match goal with - | |- _ => progress intros - | |- _ => progress autorewrite with Ztestbit - | |- _ => rewrite update_nth_nth_default_full - | |- _ => rewrite Z.testbit_pow2_mod - | |- _ => break_if - | |- _ => progress cbv [update_by_concat_bits]; - rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) - | |- bounded _ _ => apply pow2_mod_bounded_iff - | |- Z.pow2_mod _ _ = _ => apply Z.bits_inj' - | |- false = Z.testbit _ _ => symmetry - | x := _ |- Z.testbit ?x _ = _ => subst x - | |- Z.testbit _ _ = false => eapply testbit_bounded_high; eauto; lia - | |- _ => solve [auto] - | |- _ => subst_lia - end. - Qed. - - Lemma convert'_index_step : forall inp i out, - bounded limb_widthsB out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - Z.of_nat i < bitsIn limb_widthsA -> - Z.of_nat i + dist <= bitsIn limb_widthsA. - Proof. - pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). - pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA). - repeat match goal with - | |- _ => progress intros - | H : forall x : Z, In x ?lw -> x = ?y, H0 : 0 < ?y |- _ => - unique pose proof (uniform_limb_widths_nonneg H0 lw H) - | |- _ => progress specialize_by assumption - | H : _ /\ _ |- _ => destruct H - | |- _ => break_if - | |- _ => split - | a := digit_index _ ?i, H : forall x, 0 <= x < bitsIn _ -> _ |- _ => specialize (H i); forward H - | |- _ => subst_lia - | |- _ => apply bit_index_pos_iff; auto - | |- _ => apply Nat2Z.is_nonneg - end. - Qed. - - Lemma convert'_invariant_step : forall inp i out, - length inp = length limb_widthsA -> - bounded limb_widthsA inp -> - convert'_invariant inp i out -> - let digitA := digit_index limb_widthsA (Z.of_nat i) in - let digitB := digit_index limb_widthsB (Z.of_nat i) in - let indexA := bit_index limb_widthsA (Z.of_nat i) in - let indexB := bit_index limb_widthsB (Z.of_nat i) in - let dist := Z.min ((limb_widthsA # digitA) - indexA) - ((limb_widthsB # digitB) - indexB) in - let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in - 0 < dist -> - Z.of_nat i < bitsIn limb_widthsA -> - convert'_invariant inp (i + Z.to_nat dist)%nat - (update_nth digitB (update_by_concat_bits indexB bitsA) out). - Proof. - Time - repeat match goal with - | |- _ => progress intros; cbv [convert'_invariant] in * - | |- _ => progress autorewrite with Ztestbit - | H : forall x, In x ?lw -> 0 <= x |- appcontext[digit_index ?lw ?i] => - unique pose proof (digit_index_lt_length lw H i) - | |- _ => rewrite Nat2Z.inj_add - | |- _ => rewrite Z2Nat.id in * - | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ => - specialize (H n) - | H0 : ?n < ?i, H1 : ?n < ?i + ?d, - H : Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n |- _ = Z.testbit (decodeA _) ?n => - rewrite <-H - | H : _ /\ _ |- _ => destruct H - | |- _ => break_if - | |- _ => split - | |- _ => rewrite testbit_decode_full - | |- _ => rewrite update_nth_nth_default_full - | |- _ => rewrite nth_default_out_of_bounds by omega - | H : ~ (0 <= ?n ) |- appcontext[Z.testbit ?a ?n] => rewrite (Z.testbit_neg_r a n) by omega - | |- _ => progress cbv [update_by_concat_bits]; - rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg) - | |- _ => solve [distr_length] - | |- _ => eapply convert'_bounded_step; solve [auto] - | |- _ => etransitivity; [ | eapply convert'_index_step]; subst_let; eauto; lia - | H : digit_index limb_widthsB ?i = digit_index limb_widthsB ?j |- _ => - unique assert (digit_index limb_widthsA i = digit_index limb_widthsA j) by - (symmetry; apply same_digit; assumption || lia); - pose proof (same_digit_bit_index_sub limb_widthsA j i) as X; - forward X; [ | lia | lia | lia ] - | d := digit_index ?lw ?j, - H : digit_index ?lw ?i <> ?d |- _ => - exfalso; apply H; symmetry; apply same_digit; assumption || subst_lia - | d := digit_index ?lw ?j, - H : digit_index ?lw ?i = ?d |- _ => - let X := fresh "H" in - ((pose proof (same_digit_bit_index_sub lw i j) as X; - forward X; [ subst_let | subst_lia | lia | lia ]) || - (pose proof (same_digit_bit_index_sub lw j i) as X; - forward X; [ subst_let | subst_lia | lia | lia ])) - | |- Z.testbit _ (bit_index ?lw _ - bit_index ?lw ?i + _) = false => - apply (@testbit_bounded_high limb_widthsA); auto; - rewrite (same_digit_bit_index_sub) by subst_lia; - rewrite <-(split_index_eqn limb_widthsA i) at 2 by lia - | |- ?lw # ?b <= ?a - ((sum_firstn ?lw ?b) + ?c) + ?c => replace (a - (sum_firstn lw b + c) + c) with (a - sum_firstn lw b) by ring; apply Z.le_add_le_sub_r - | |- (?lw # ?n) + sum_firstn ?lw ?n <= _ => - rewrite <-sum_firstn_succ_default; transitivity (bitsIn lw); [ | lia]; - apply sum_firstn_prefix_le; auto; lia - | |- _ => lia - | |- _ => assumption - | |- _ => solve [auto] - | |- _ => rewrite <-testbit_decode by (assumption || lia || auto); assumption - | |- _ => repeat (f_equal; try congruence); lia - end. - Qed. - - Lemma convert'_invariant_holds : forall inp i out, - length inp = length limb_widthsA -> - bounded limb_widthsA inp -> - convert'_invariant inp i out -> - convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out). - Proof. - intros until 2; functional induction (convert' inp i out); - repeat match goal with - | |- _ => progress intros - | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] => - unique pose proof (bit_index_not_done lw i) - | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) => - eapply convert'_invariant_step in H; solve [auto; specialize_by lia; lia] - | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in * - | H : _ /\ _ |- _ => destruct H - | |- _ => rewrite Z2Nat.id - | |- _ => split - | |- _ => assumption - | |- _ => lia - | |- _ => solve [eauto] - | |- _ => replace (bitsIn limb_widthsA) with (Z.of_nat i) by (apply Z.le_antisymm; assumption) - end. - Qed. - - Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)). - - Lemma convert_correct : forall us, length us = length limb_widthsA -> - bounded limb_widthsA us -> - decodeA us = decodeB (convert us). - Proof. - repeat match goal with - | |- _ => progress intros - | |- _ => progress cbv [convert convert'_invariant] in * - | |- _ => progress change (Z.of_nat 0) with 0 in * - | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l - | H : length _ = length limb_widthsA |- _ => rewrite H - | |- _ => rewrite Z.testbit_neg_r by omega - | |- _ => rewrite nth_default_zeros - | |- _ => break_if - | |- _ => split - | H : _ /\ _ |- _ => destruct H - | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H - | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c) - | |- _ => apply testbit_decode_high - | |- _ => assumption - | |- _ => reflexivity - | |- _ => lia - | |- _ => solve [auto using sum_firstn_limb_widths_nonneg] - | |- _ => solve [apply nth_default_preserves_properties; auto; lia] - | |- _ => rewrite Z2Nat.id in * - | |- bounded _ _ => apply bounded_iff - | |- 0 < 2 ^ _ => zero_bounds - end. - Qed. - - (* This is part of convert'_invariant, but proving it separately strips preconditions *) - Lemma length_convert' : forall inp i out, - length (convert' inp i out) = length out. - Proof. - intros; functional induction (convert' inp i out); distr_length. - Qed. - - Lemma length_convert : forall us, length (convert us) = length limb_widthsB. - Proof. - cbv [convert]; intros. - rewrite length_convert', length_zeros. - reflexivity. - Qed. -End Conversion. - Section carrying_helper. Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). Local Notation base := (base_from_limb_widths limb_widths). -- cgit v1.2.3