(* Source: src/ModularArithmetic/Conversion.v
   (blob 4f28d7eecbd3309bc70374bebd245e1fbc9e423b) *)
Require Import Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.micromega.Psatz.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.Lists.List.
Require Import Coq.funind.Recdef.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil.
Require Import Crypto.Tactics.VerdiTactics.
Require Import Crypto.Util.Tactics.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs Crypto.BaseSystemProofs.
Require Import Crypto.Util.Notations.
Require Export Crypto.Util.FixCoqMistakes.
Require Crypto.BaseSystem.
Local Open Scope Z_scope.

Section ConversionHelper.
  (* Register the list-membership lemmas [in_eq] and [in_cons] as local
     [auto] hints for this section. *)
  Local Hint Resolve in_eq in_cons.

  (* [concat_bits n a b] is the bitwise OR of [Z.pow2_mod a n] and [b << n].
     As [concat_bits_spec] states for [0 <= n], bit [i] of the result is bit
     [i] of [a] when [i < n] and bit [i - n] of [b] otherwise, i.e. the low
     [n] bits of [a] followed by all bits of [b]. *)
  Definition concat_bits n a b := Z.lor (Z.pow2_mod a n) (b << n).

  (* Bit-level specification of [concat_bits]; requires [0 <= n]. *)
  Lemma concat_bits_spec : forall a b n i, 0 <= n ->
                                           Z.testbit (concat_bits n a b) i =
                                           if Z_lt_dec i n then Z.testbit a i else Z.testbit b (i - n).
  Proof.
    (* Saturating tactic loop: unfold [concat_bits], rewrite with the
       [Ztestbit] database, and case-split on the [if] ([break_if]) and on
       whether [0 <= i] whenever a shifted operand is tested. *)
    repeat match goal with
           | |- _ => progress cbv [concat_bits]; intros
           | |- _ => progress autorewrite with Ztestbit
           | |- _ => rewrite Z.testbit_pow2_mod by omega
           | |- _ => rewrite Z.testbit_neg_r by omega
           | |- _ => break_if
           | |- appcontext [Z.testbit (?a << ?b) ?i] => destruct (Z_le_dec 0 i)
           | |- (?a || ?b)%bool = ?a => replace b with false
           | |- _ => reflexivity
           end.
  Qed.

  (* [concat_bits] with its arguments reordered so that the partial
     application [update_by_concat_bits num_low_bits bits] is a function of
     the old value [x]; [convert'] passes it to [update_nth] in that form.
     The low [num_low_bits] bits come from [x], the rest from [bits]. *)
  Definition update_by_concat_bits num_low_bits bits x := concat_bits num_low_bits x bits.

End ConversionHelper.

Section Conversion.
  (* Section parameters: two lists of limb widths, A (the representation
     being converted from) and B (the one being converted to), each with the
     hypothesis that every width in the list is non-negative. *)
  Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w)
          {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w).
  (* [bitsIn lw] abbreviates [sum_firstn lw (length lw)] -- presumably the sum
     of all widths in [lw], i.e. the total bit count; confirm against the
     definition of [sum_firstn]. *)
  Local Notation bitsIn lw := (sum_firstn lw (length lw)).
  (* Section hypothesis: [bitsIn] of A does not exceed [bitsIn] of B. *)
  Context (bits_fit : bitsIn limb_widthsA <=  bitsIn limb_widthsB).
  (* [BaseSystem.decode] specialised to the base built from each width list. *)
  Local Notation decodeA := (BaseSystem.decode (base_from_limb_widths limb_widthsA)).
  Local Notation decodeB := (BaseSystem.decode (base_from_limb_widths limb_widthsB)).
  (* [u # i] is [nth_default 0 u i]: the [i]-th element of [u], or 0 when
     [i] is out of range. *)
  Local Notation "u # i" := (nth_default 0 u i).
  (* Lemmas made available to [auto] for the proofs in this section. *)
  Local Hint Resolve in_eq in_cons nth_default_limb_widths_nonneg sum_firstn_limb_widths_nonneg Nat2Z.is_nonneg.
  (* Prevent [bounded] from being unfolded inside this section. *)
  Local Opaque bounded.

  (* [convert' inp i out] is the worker loop behind [convert].  [i] is a bit
     position (a [nat], used through [Z.of_nat i]).  When
     [bitsIn limb_widthsA <= Z.of_nat i] it returns [out]; otherwise it
       - looks up [digit_index] and [bit_index] of position [i] in both
         width lists,
       - sets [dist] to the smaller of [width - bit_index] over the two
         limbs just located,
       - takes [bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist] from
         the input limb,
       - rewrites limb [digitB] of [out] with
         [update_by_concat_bits indexB bitsA], and
       - recurses at position [i + Z.to_nat dist].
     Termination is justified by the measure
     [Z.to_nat (bitsIn limb_widthsA - Z.of_nat i)]. *)
  Function convert' inp i out
           {measure (fun x => Z.to_nat ((bitsIn limb_widthsA) - Z.of_nat x)) i}:=
    if Z_le_dec (bitsIn limb_widthsA) (Z.of_nat i)
    then out
    else
      let digitA := digit_index limb_widthsA (Z.of_nat i) in
      let digitB := digit_index limb_widthsB (Z.of_nat i) in
      let indexA :=   bit_index limb_widthsA (Z.of_nat i) in
      let indexB :=   bit_index limb_widthsB (Z.of_nat i) in
      let dist := Z.min ((limb_widthsA # digitA) - indexA) ((limb_widthsB # digitB) - indexB) in
      let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
      convert' inp (i + Z.to_nat dist)%nat (update_nth digitB (update_by_concat_bits indexB bitsA) out).
  Proof.
    (* Termination obligation: the measure strictly decreases on the
       recursive call.  The loop below derives [0 < Z.min _ _] (i.e.
       [0 < dist]) from [bit_index_not_done] and closes the arithmetic with
       [lia]. *)
    generalize limb_widthsA_nonneg; intros _. (* don't drop this from the proof in 8.4 *)
    generalize limb_widthsB_nonneg; intros _. (* don't drop this from the proof in 8.4 *)
    repeat match goal with
           | |- _ => progress intros
           | |- appcontext [bit_index (Z.of_nat ?i)] =>
             unique pose proof (Nat2Z.is_nonneg i)
           | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] =>
             unique pose proof (bit_index_not_done lw i)
           | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] =>
             unique assert (0 <= i < bitsIn lw -> i + ((lw # digit_index lw i) - bit_index lw i) <= bitsIn lw) by auto using rem_bits_in_digit_le_rem_bits
           | |- _ => rewrite Z2Nat.id
           | |- _ => rewrite Nat2Z.inj_add
           | |- (Z.to_nat _ < Z.to_nat _)%nat => apply Z2Nat.inj_lt
           | |- (?a - _ < ?a - _) => apply Z.sub_lt_mono_l
           | |- appcontext [Z.min ?a ?b] => unique assert (0 < Z.min a b) by (specialize_by lia; lia)
           | |- _ => lia
           end.
  Defined.

  (* Loop invariant of [convert'] at bit position [i]:
       1. [out] has as many limbs as [limb_widthsB];
       2. [out] is [bounded] with respect to [limb_widthsB];
       3. [Z.of_nat i] does not exceed [bitsIn limb_widthsA];
       4. for every [n], bit [n] of [decodeB out] equals bit [n] of
          [decodeA inp] when [n < Z.of_nat i], and is [false] otherwise. *)
  Definition convert'_invariant inp i out :=
    length out = length limb_widthsB
    /\ bounded limb_widthsB out
    /\ Z.of_nat i <= bitsIn limb_widthsA
    /\ forall n, Z.testbit (decodeB out) n = if Z_lt_dec n (Z.of_nat i) then Z.testbit (decodeA inp) n else false.

  (* Finisher used by the step lemmas: run [subst_let] (presumably
     substitutes the local [let]-bound definitions -- confirm in the tactics
     library), then [subst], then solve the goal with [lia]. *)
  Ltac subst_lia := subst_let; subst; lia.

  (* One iteration of [convert'] preserves boundedness of the output: if
     [out] is [bounded] w.r.t. [limb_widthsB] and [0 < dist], then so is
     [out] after limb [digitB] is rewritten with
     [update_by_concat_bits indexB bitsA].  The [let]-bindings mirror those
     in the body of [convert']. *)
  Lemma convert'_bounded_step : forall inp i out,
    bounded limb_widthsB out ->
    let digitA := digit_index limb_widthsA (Z.of_nat i) in
    let digitB := digit_index limb_widthsB (Z.of_nat i) in
    let indexA :=   bit_index limb_widthsA (Z.of_nat i) in
    let indexB :=   bit_index limb_widthsB (Z.of_nat i) in
    let dist := Z.min ((limb_widthsA # digitA) - indexA)
                      ((limb_widthsB # digitB) - indexB) in
    let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
    0 < dist ->
    bounded limb_widthsB (update_nth digitB (update_by_concat_bits indexB bitsA) out).
  Proof.
    (* Reduce [bounded] to a per-limb [Z.pow2_mod] equation
       ([pow2_mod_bounded_iff]), prove it bitwise ([Z.bits_inj']), and
       discharge the resulting bit tests using [concat_bits_spec] and
       [testbit_bounded_high]. *)
    repeat match goal with
           | |- _ => progress intros
           | |- _ => progress autorewrite with Ztestbit
           | |- _ => rewrite update_nth_nth_default_full
           | |- _ => rewrite Z.testbit_pow2_mod
           | |- _ => break_if
           | |- _ =>  progress cbv [update_by_concat_bits];
                        rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg)
           | |- bounded _ _ => apply pow2_mod_bounded_iff
           | |- Z.pow2_mod _ _ = _ => apply Z.bits_inj'
           | |- false = Z.testbit _ _ => symmetry
           | x := _ |- Z.testbit ?x _ = _ => subst x
           | |- Z.testbit _ _ = false => eapply testbit_bounded_high; eauto; lia
           | |- _ => solve [auto]
           | |- _ => subst_lia
    end.
  Qed.

  (* One iteration of [convert'] does not overshoot the input: if
     [0 < dist] and [Z.of_nat i < bitsIn limb_widthsA], then
     [Z.of_nat i + dist <= bitsIn limb_widthsA].  The [let]-bindings mirror
     those in the body of [convert']; the conclusion mentions only [i],
     [dist] and [limb_widthsA]. *)
  Lemma convert'_index_step : forall inp i out,
    bounded limb_widthsB out ->
    let digitA := digit_index limb_widthsA (Z.of_nat i) in
    let digitB := digit_index limb_widthsB (Z.of_nat i) in
    let indexA :=   bit_index limb_widthsA (Z.of_nat i) in
    let indexB :=   bit_index limb_widthsB (Z.of_nat i) in
    let dist := Z.min ((limb_widthsA # digitA) - indexA)
                      ((limb_widthsB # digitB) - indexB) in
    let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
    0 < dist ->
    Z.of_nat i < bitsIn limb_widthsA ->
    Z.of_nat i + dist <= bitsIn limb_widthsA.
  Proof.
    (* NOTE(review): the next two lines are identical, so the second only
       adds a duplicate hypothesis.  It may be redundant, or limb_widthsB may
       have been intended -- confirm by recompiling before changing it, since
       the tactic loop below matches on hypotheses. *)
    pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA).
    pose proof (rem_bits_in_digit_le_rem_bits limb_widthsA).
    (* Instantiate the posed facts at position [i], then finish with linear
       arithmetic after substituting the [let]-bound names. *)
    repeat match goal with
           | |- _ => progress intros
           | H : forall x : Z, In x ?lw -> x = ?y, H0 : 0 < ?y |- _ =>
              unique pose proof (uniform_limb_widths_nonneg H0 lw H)
           | |- _ => progress specialize_by assumption
           | H : _ /\ _ |- _ => destruct H
           | |- _ => break_if
           | |- _ => split
           | a := digit_index _ ?i, H : forall x, 0 <= x < bitsIn _ -> _ |- _ => specialize (H i); forward H
           | |- _ => subst_lia
           | |- _ => apply bit_index_pos_iff; auto
           | |- _ => apply Nat2Z.is_nonneg
    end.
  Qed.

  (* One iteration of [convert'] preserves [convert'_invariant].
     Given an input [inp] of the right length that is [bounded] w.r.t.
     [limb_widthsA], and the invariant at position [i], the invariant holds
     again at position [i + Z.to_nat dist] for [out] with limb [digitB]
     rewritten by [update_by_concat_bits indexB bitsA], provided [0 < dist]
     and [Z.of_nat i < bitsIn limb_widthsA].  The [let]-bindings mirror those
     in the body of [convert']. *)
  Lemma convert'_invariant_step : forall inp i out,
    length inp = length limb_widthsA ->
    bounded limb_widthsA inp ->
    convert'_invariant inp i out ->
    let digitA := digit_index limb_widthsA (Z.of_nat i) in
    let digitB := digit_index limb_widthsB (Z.of_nat i) in
    let indexA :=   bit_index limb_widthsA (Z.of_nat i) in
    let indexB :=   bit_index limb_widthsB (Z.of_nat i) in
    let dist := Z.min ((limb_widthsA # digitA) - indexA)
                      ((limb_widthsB # digitB) - indexB) in
    let bitsA := Z.pow2_mod ((inp # digitA) >> indexA) dist in
    0 < dist ->
    Z.of_nat i < bitsIn limb_widthsA ->
    convert'_invariant inp (i + Z.to_nat dist)%nat
                       (update_nth digitB (update_by_concat_bits indexB bitsA) out).
  Proof.
    (* Saturating tactic loop.  It unfolds the invariant and splits it into
       its four conjuncts; the length, boundedness and index conjuncts are
       closed by [distr_length], [convert'_bounded_step] and
       [convert'_index_step] respectively.  The bitwise conjunct is handled
       by rewriting with [testbit_decode_full] and [concat_bits_spec] and
       case-splitting on which limb the tested bit falls in. *)
    repeat match goal with
           | |- _ => progress intros; cbv [convert'_invariant] in *
           | |- _ => progress autorewrite with Ztestbit
           | H : forall x, In x ?lw -> 0 <= x |- appcontext[digit_index ?lw ?i]  =>
              unique pose proof (digit_index_lt_length lw H i)
           | |- _ => rewrite Nat2Z.inj_add
           | |- _ => rewrite Z2Nat.id in *
           | H : forall n, Z.testbit (decodeB _) n = _ |- Z.testbit (decodeB _) ?n = _ =>
             specialize (H n)
           | H0 : ?n < ?i, H1 : ?n < ?i + ?d,
             H : Z.testbit (decodeB _) ?n = Z.testbit (decodeA _) ?n |- _ = Z.testbit (decodeA _) ?n =>
             rewrite <-H
           | H : _ /\ _ |- _ => destruct H
           | |- _ => break_if
           | |- _ => split
           | |- _ => rewrite testbit_decode_full
           | |- _ => rewrite update_nth_nth_default_full
           | |- _ => rewrite nth_default_out_of_bounds by omega
           | H : ~ (0 <= ?n ) |- appcontext[Z.testbit ?a ?n] => rewrite (Z.testbit_neg_r a n) by omega
           | |- _ =>  progress cbv [update_by_concat_bits];
                        rewrite concat_bits_spec by (apply bit_index_nonneg; auto using Nat2Z.is_nonneg)
           | |- _ => solve [distr_length]
           | |- _ => eapply convert'_bounded_step; solve [auto]
           | |- _ => etransitivity; [ | eapply convert'_index_step]; subst_let; eauto; lia
           | H : digit_index limb_widthsB ?i = digit_index limb_widthsB ?j |- _ =>
             unique assert (digit_index limb_widthsA i = digit_index limb_widthsA j) by
               (symmetry; apply same_digit; assumption || lia);
                  pose proof (same_digit_bit_index_sub limb_widthsA j i) as X;
                     forward X; [  | lia | lia | lia ]
           | d := digit_index ?lw ?j,
               H : digit_index ?lw ?i <> ?d |- _ =>
               exfalso; apply H; symmetry; apply same_digit; assumption || subst_lia
           | d := digit_index ?lw ?j,
              H : digit_index ?lw ?i = ?d |- _ =>
                let X := fresh "H" in
                  ((pose proof (same_digit_bit_index_sub lw i j) as X;
                      forward X; [ subst_let | subst_lia | lia | lia ]) ||
                   (pose proof (same_digit_bit_index_sub lw j i) as X;
                      forward X; [ subst_let | subst_lia | lia | lia ]))
           | |- Z.testbit _ (bit_index ?lw _ - bit_index ?lw ?i + _) = false =>
           apply (@testbit_bounded_high limb_widthsA); auto;
               rewrite (same_digit_bit_index_sub) by subst_lia;
               rewrite <-(split_index_eqn limb_widthsA i) at 2 by lia
           | |- ?lw # ?b <= ?a - ((sum_firstn ?lw ?b) + ?c) + ?c => replace (a - (sum_firstn lw b + c) + c) with (a - sum_firstn lw b) by ring; apply Z.le_add_le_sub_r
           | |- (?lw # ?n) + sum_firstn ?lw ?n <= _ =>
            rewrite <-sum_firstn_succ_default; transitivity (bitsIn lw); [ | lia];
            apply sum_firstn_prefix_le; auto; lia
           | |- _ => lia
           | |- _ => assumption
           | |- _ => solve [auto]
           | |- _ => rewrite <-testbit_decode by (assumption || lia || auto); assumption
           | |- _ => repeat (f_equal; try congruence); lia
           end.
  Qed.

  (* Running [convert'] to completion carries the invariant to the end of the
     input: if [inp] has the right length, is [bounded] w.r.t.
     [limb_widthsA], and the invariant holds at position [i] for [out], then
     it holds at position [Z.to_nat (bitsIn limb_widthsA)] for
     [convert' inp i out]. *)
  Lemma convert'_invariant_holds : forall inp i out,
    length inp = length limb_widthsA ->
    bounded limb_widthsA inp ->
    convert'_invariant inp i out ->
    convert'_invariant inp (Z.to_nat (bitsIn limb_widthsA)) (convert' inp i out).
  Proof.
    (* Functional induction on [convert'].  In the recursive case the
       invariant is advanced with [convert'_invariant_step]; in the base case
       [bitsIn limb_widthsA] is replaced by [Z.of_nat i], justified by
       [Z.le_antisymm]. *)
    intros until 2; functional induction (convert' inp i out);
      repeat match goal with
           | |- _ => progress intros
           | H : forall x : Z, In x ?lw -> 0 <= x |- appcontext [bit_index ?lw ?i] =>
              unique pose proof (bit_index_not_done lw i)
           | H : convert'_invariant _ _ _ |- convert'_invariant _ _ (convert' _ _ _) =>
             eapply convert'_invariant_step in H; solve [auto; specialize_by lia; lia]
           | H : convert'_invariant _ _ ?out |- convert'_invariant _ _ ?out => progress cbv [convert'_invariant] in *
           | H : _ /\ _ |- _ => destruct H
           | |- _ => rewrite Z2Nat.id
           | |- _ => split
           | |- _ => assumption
           | |- _ => lia
           | |- _ => solve [eauto]
           | |- _ => replace (bitsIn limb_widthsA) with (Z.of_nat i) by (apply Z.le_antisymm; assumption)
             end.
  Qed.

  (* Entry point: convert the limb list [us] by running [convert'] from bit
     position 0 with an initial output of
     [BaseSystem.zeros (length limb_widthsB)].  [convert_correct] shows that
     the result decodes (w.r.t. B) to the same value as [us] does (w.r.t. A),
     for [us] of the right length and [bounded] w.r.t. [limb_widthsA]. *)
  Definition convert us := convert' us 0 (BaseSystem.zeros (length limb_widthsB)).

  (* Correctness of [convert]: for [us] with [length us = length limb_widthsA]
     that is [bounded] w.r.t. [limb_widthsA], decoding [us] in base A yields
     the same integer as decoding [convert us] in base B. *)
  Lemma convert_correct : forall us, length us = length limb_widthsA ->
                                     bounded limb_widthsA us ->
                                     decodeA us = decodeB (convert us).
  Proof.
    (* Obtain the final invariant from [convert'_invariant_holds] (its
       premise is the invariant for the all-zero output at position 0), then
       prove the equality bitwise via [Z.bits_inj'], using
       [testbit_decode_high] for the bits where the invariant gives
       [false]. *)
    repeat match goal with
           | |- _ => progress intros
           | |- _ => progress cbv [convert convert'_invariant] in *
           | |- _ => progress change (Z.of_nat 0) with 0 in *
           | |- _ => progress rewrite ?length_zeros, ?zeros_rep, ?Z.testbit_0_l
           | H : length _ = length limb_widthsA |- _ => rewrite H
           | |- _ => rewrite Z.testbit_neg_r by omega
           | |- _ => rewrite nth_default_zeros
           | |- _ => break_if
           | |- _ => split
           | H : _ /\ _ |- _ => destruct H
           | H : forall n, Z.testbit ?x n = _ |- _ = ?x => apply Z.bits_inj'; intros; rewrite H
           | |- _ = decodeB (convert' ?a ?b ?c) => edestruct (convert'_invariant_holds a b c)
           | |- _ => apply testbit_decode_high
           | |- _ => assumption
           | |- _ => reflexivity
           | |- _ => lia
           | |- _ => solve [auto using sum_firstn_limb_widths_nonneg]
           | |- _ => solve [apply nth_default_preserves_properties; auto; lia]
           | |- _ => rewrite Z2Nat.id in *
           | |- bounded _ _ => apply bounded_iff
           | |- 0 < 2 ^ _ => zero_bounds
           end.
  Qed.

  (* [convert'] preserves boundedness of its accumulator.  This fact is also
     one conjunct of [convert'_invariant], but stating it separately avoids
     that invariant's other preconditions: only [bounded limb_widthsB out] is
     assumed here, with nothing required of [inp] or [i]. *)
  Lemma convert'_bounded : forall inp i out,
    bounded limb_widthsB out ->
    bounded limb_widthsB (convert' inp i out).
  Proof.
    (* Functional induction on [convert']; the base case is closed by [auto].
       In the recursive case apply [convert'_bounded_step], whose side
       condition [0 < dist] follows from [bit_index_not_done] for both width
       lists. *)
    intros; functional induction (convert' inp i out); auto.
    apply IHl.
    apply convert'_bounded_step; auto.
    clear IHl.
    pose proof (bit_index_not_done limb_widthsA (Z.of_nat i)).
    pose proof (bit_index_not_done limb_widthsB (Z.of_nat i)).
    specialize_by lia.
    lia.
  Qed.

  (* The output of [convert] is always [bounded] w.r.t. [limb_widthsB]; no
     hypothesis on [us] is needed. *)
  Lemma convert_bounded : forall us, bounded limb_widthsB (convert us).
  Proof.
    (* By [convert'_bounded], it suffices that the initial all-zero output is
       bounded, which is checked limb by limb via [bounded_iff]. *)
    intros; apply convert'_bounded.
    apply bounded_iff; intros.
    rewrite nth_default_zeros.
    split; zero_bounds.
  Qed.
    
  (* [convert'] does not change the length of its accumulator.  A length fact
     is also part of [convert'_invariant], but this standalone statement
     needs no preconditions at all. *)
  Lemma length_convert' : forall inp i out,
    length (convert' inp i out) = length out.
  Proof.
    (* Functional induction on [convert']; every case is closed by
       [distr_length]. *)
    intros; functional induction (convert' inp i out); distr_length.
  Qed.

  (* The output of [convert] has exactly [length limb_widthsB] limbs. *)
  Lemma length_convert : forall us, length (convert us) = length limb_widthsB.
  Proof.
    (* Unfold [convert], then combine [length_convert'] with [length_zeros]
       for the initial all-zero accumulator. *)
    cbv [convert]; intros.
    rewrite length_convert', length_zeros.
    reflexivity.
  Qed.
End Conversion.