(* path: root/src/Util/ForLoop/Unrolling.v
   blob: e0518f39adb19b60c418e2d47b5da347091d40b7 *)
(** * Proving properties of for-loops via loop-unrolling *)
Require Import Coq.micromega.Psatz.
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.ForLoop.
Require Import Crypto.Util.ForLoop.Instances.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Bool.
Require Import Crypto.Util.Tactics.RewriteHyp.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.Notations.

(** Generic unrolling lemmas for [repeat_function], stated for an arbitrary
    [body] indexed by [nat]. *)
Section with_body.
  Context {stateT : Type}
          (body : nat -> stateT -> stateT).

  (** One-step unfolding of [repeat_function].
      - Zero iterations return [st] unchanged.
      - [S count'] iterations first apply [body] at index [count], then run
        the remaining [count'] iterations.
      So the index passed to [body] counts down. *)
  Lemma unfold_repeat_function (count : nat) (st : stateT)
    : repeat_function body count st
      = match count with
        | O => st
        | S count' => repeat_function body count' (body count st)
        end.
  Proof using Type. destruct count; reflexivity. Qed.

  (** Peel off the first iteration (index [S count]) of a non-empty run. *)
  Lemma repeat_function_unroll1_start (count : nat) (st : stateT)
    : repeat_function body (S count) st
      = repeat_function body count (body (S count) st).
  Proof using Type. rewrite unfold_repeat_function; reflexivity. Qed.

  (** Peel off the last iteration instead. [body] runs at index [1], after
      the other [count] iterations. Those iterations are expressed with the
      index-shifted body [fun count => body (S count)]. *)
  Lemma repeat_function_unroll1_end (count : nat) (st : stateT)
    : repeat_function body (S count) st
      = body 1 (repeat_function (fun count => body (S count)) count st).
  Proof using Type.
    (* Induction on [count], generalized over the starting state [st]. *)
    revert st; induction count; [ reflexivity | ].
    intros; simpl in *; rewrite <- IHcount; reflexivity.
  Qed.

  (** Case-split form of [repeat_function_unroll1_start] that also covers
      [count = 0].
      NOTE(review): the statement is identical to [unfold_repeat_function]. *)
  Lemma repeat_function_unroll1_start_match (count : nat) (st : stateT)
    : repeat_function body count st
      = match count with
        | 0 => st
        | S count' => repeat_function body count' (body count st)
        end.
  Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_start ]. Qed.

  (** Case-split form of [repeat_function_unroll1_end] that also covers
      [count = 0]. *)
  Lemma repeat_function_unroll1_end_match (count : nat) (st : stateT)
    : repeat_function body count st
      = match count with
        | 0 => st
        | S count' => body 1 (repeat_function (fun count => body (S count)) count' st)
        end.
  Proof using Type. destruct count; [ reflexivity | apply repeat_function_unroll1_end ]. Qed.
End with_body.

(* From here on, notations are read in [Z] by default, and the boolean
   operators (e.g. [&&]) stay available. [Z_scope] is opened last, so it
   takes precedence wherever the two scopes overlap; keep this order. *)
Local Open Scope bool_scope.
Local Open Scope Z_scope.

(** Unrolling lemmas for [for_loop] over [Z] indices. *)
Section for_loop.
  (* [Hgood] requires [step] and [finish - i0] to have the same sign: both
     zero, both positive, or both negative. *)
  Context (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT)
          (Hgood : Z.sgn step = Z.sgn (finish - i0)).

  (* Helpers relating a natural-number counter to [Z] loop indices.
     - [countZ] is [Z.quot (finish - i0 + step - Z.sgn step) step]. When
       [step <> 0] and [Hgood] holds, this is [(finish - i0) / step] rounded
       up. It presumably mirrors the iteration count computed inside
       [for_loop]; the proof of [for_loop_empty] rewrites the same
       [Z.quot] shape.
     - [of_nat_count c] is [i0 + step * (count - c)]. For a counter that
       counts down from [count], [c = count] maps to [i0].
     NOTE(review): these [Let]s are not referenced by name in the proofs
     below. They may still appear in proof contexts touched by [in *]
     tactics, so confirm before removing them. *)
  Let countZ := (Z.quot (finish - i0 + step - Z.sgn step) step).
  Let count := Z.to_nat countZ.
  Let of_nat_count c := (i0 + step * Z.of_nat (count - c)).
  Let nat_body := (fun c => body (of_nat_count c)).

  (** If the range is empty ([finish = i0]), the loop returns [initial]. *)
  Lemma for_loop_empty
        (Heq : finish = i0)
    : for_loop i0 finish step initial body = initial.
  Proof.
    subst; unfold for_loop.
    rewrite Z.sub_diag, Z.quot_sub_sgn; autorewrite with zsimplify_const.
    reflexivity.
  Qed.

  (** Unroll the first iteration of [for_loop]:
      - An empty range ([finish = i0]) returns [initial].
      - Otherwise [body] runs once at [i0].
      - That was the only iteration if [|finish - i0| <= |step|].
      - Otherwise the loop continues from [i0 + step] with the updated
        state. *)
  Lemma for_loop_unroll1
    : for_loop i0 finish step initial body
      = if finish =? i0
        then initial
        else let initial' := body i0 initial in
             if Z.abs (finish - i0) <=? Z.abs step
             then initial'
             else for_loop (i0 + step) finish step initial' body.
  Proof.
    break_innermost_match_step; Z.ltb_to_lt.
    { (* finish = i0: the loop is empty. *)
      apply for_loop_empty; assumption. }
    { (* finish <> i0: expose one [repeat_function] step.
         The sign of [step] is then split (zero vs. nonzero). The goals are
         closed by a brute-force search that normalizes [Z.quot]/[Z.abs]/
         [Z.sgn] facts and nat/Z conversions. The order of alternatives in
         this search matters. *)
      unfold for_loop.
      rewrite repeat_function_unroll1_start_match.
      destruct (Z_zerop step);
        repeat first [ progress break_innermost_match
                     | congruence
                     | lia
                     | progress Z.ltb_to_lt
                     | progress subst
                     | progress rewrite Nat.sub_diag
                     | progress autorewrite with zsimplify_const in *
                     | progress rewrite Z.quot_small_iff in * by omega
                     | progress rewrite Z.quot_small_abs in * by lia
                     | rewrite Nat.sub_succ_l by omega
                     | progress destruct_head' and
                     | rewrite !Z.sub_add_distr
                     | match goal with
                       | [ H : ?x = Z.of_nat _ |- context[?x] ] => rewrite H
                       | [ H : Z.abs ?x <= 0 |- _ ] => assert (x = 0) by lia; clear H
                       | [ H : 0 = Z.sgn ?x |- _ ] => assert (x = 0) by lia; clear H
                       | [ H : ?x - ?y = 0 |- _ ] => is_var x; assert (x = y) by omega; subst x
                       | [ H : Z.to_nat _ = _ |- _ ] => apply Nat2Z.inj_iff in H
                       | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega
                       | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega
                       | [ H : Z.of_nat (Z.to_nat _) = _ |- _ ]
                         => rewrite Z2Nat.id in H by (apply Z.quot_nonneg_same_sgn; lia)
                       | [ H : _ = Z.of_nat (S ?x) |- _ ]
                         => is_var x; destruct x; [ reflexivity | ]
                       | [ H : ?x + 1 = Z.of_nat (S ?y) |- _ ]
                         => assert (x = Z.of_nat y) by lia; clear H
                       | [ |- repeat_function _ ?x ?y = repeat_function _ ?x ?y ]
                         => apply repeat_function_Proper_le; intros
                       | [ |- ?f _ ?x = ?f _ ?x ]
                         => is_var f; apply f_equal2; [ | reflexivity ]
                       end
                     | progress rewrite Z.quot_add_sub_sgn_small in * |- by lia
                     | progress autorewrite with zsimplify ]. }
  Qed.
End for_loop.

(** A [for_loop_notation] loop whose start [i0] equals its bound [finish]
    returns [initial] unchanged. The proof unfolds the notation and the
    goodness condition, then applies [for_loop_empty]. *)
Lemma for_loop_notation_empty {stateT}
      {i0 : Z} {step : Z} {finish : Z} {initial : stateT}
      {cmp : Z -> Z -> bool}
      {step_expr finish_expr} (body : Z -> stateT -> stateT)
      {Hstep : class_eq (fun i => i = step) step_expr}
      {Hfinish : class_eq (fun i => cmp i finish) finish_expr}
      {Hgood : for_loop_is_good i0 step finish cmp}
      (Heq : i0 = finish)
  : @for_loop_notation i0 step finish _ initial cmp step_expr finish_expr body Hstep Hfinish Hgood = initial.
Proof.
  unfold for_loop_notation, for_loop_is_good in *; split_andb; Z.ltb_to_lt.
  apply for_loop_empty; auto.
Qed.

(* [adjust_bool b p] turns [p : b = true] into a proof of [b = true]. When
   [b] computes to [true], the match reduces and yields [eq_refl] instead
   of [p].
   NOTE(review): presumably this keeps the proof terms built by the hints
   below canonical once [b] evaluates; confirm the intended motivation. *)
Local Notation adjust_bool b p
  := (match b as b' return b' = true -> b' = true with
      | true => fun _ => eq_refl
      | false => fun x => x
      end p).

(** For [cmp] either [Z.ltb] or [Z.gtb]: if the loop starting at [i0] is
    good and [cmp (i0 + step) finish] holds, then the loop starting one
    step later, at [i0 + step], is good as well. *)
Lemma for_loop_is_good_step_gen
      cmp
      (Hcmp : cmp = Z.ltb \/ cmp = Z.gtb)
      {i0 step finish}
      {H : for_loop_is_good i0 step finish cmp}
      (H' : cmp (i0 + step) finish = true)
  : for_loop_is_good (i0 + step) step finish cmp.
Proof.
  unfold for_loop_is_good in *.
  rewrite H', Bool.andb_true_r.
  (* One goal per comparison. Each reduces to a sign fact about [step] and
     [finish - i0], which linear arithmetic closes. *)
  destruct Hcmp; subst;
    split_andb; Z.ltb_to_lt;
      [ rewrite (Z.sgn_pos (finish - i0)) in * by omega
      | rewrite (Z.sgn_neg (finish - i0)) in * by omega ];
      destruct step; simpl in *; try congruence;
        symmetry;
        [ apply Z.sgn_pos_iff | apply Z.sgn_neg_iff ]
        ; omega.
Qed.

(** [Z.ltb] instance of [for_loop_is_good_step_gen]: advancing the start by
    one [step] preserves goodness while [i0 + step < finish]. *)
Definition for_loop_is_good_step_lt
           {i0 step finish}
           {H : for_loop_is_good i0 step finish Z.ltb}
           (H' : Z.ltb (i0 + step) finish = true)
  : for_loop_is_good (i0 + step) step finish Z.ltb
  := @for_loop_is_good_step_gen
       Z.ltb (or_introl eq_refl) (* the [Z.ltb] alternative of [Hcmp] *)
       i0 step finish H H'.
(** [Z.gtb] instance of [for_loop_is_good_step_gen]: advancing the start by
    one [step] preserves goodness while [i0 + step > finish]. *)
Definition for_loop_is_good_step_gt
           {i0 step finish}
           {H : for_loop_is_good i0 step finish Z.gtb}
           (H' : Z.gtb (i0 + step) finish = true)
  : for_loop_is_good (i0 + step) step finish Z.gtb
  := @for_loop_is_good_step_gen
       Z.gtb (or_intror eq_refl) (* the [Z.gtb] alternative of [Hcmp] *)
       i0 step finish H H'.
(** Unit-step variant for loops whose bound has the shape [finish + 1]. This
    is the shape used by the [i <= finish] loops below. The side condition
    is [i0 < finish] rather than [i0 + 1 < finish + 1]. *)
Definition for_loop_is_good_step_lt'
           {i0 finish}
           {H : for_loop_is_good i0 1 (finish + 1) Z.ltb}
           (H' : Z.ltb i0 finish = true)
  : for_loop_is_good (i0 + 1) 1 (finish + 1) Z.ltb.
Proof.
  apply for_loop_is_good_step_lt; Z.ltb_to_lt; omega.
Qed.
(** Unit-step (decrementing) variant for loops whose bound has the shape
    [finish - 1]. This is the shape used by the [i >= finish] loops below.
    The side condition is [i0 > finish]. *)
Definition for_loop_is_good_step_gt'
           {i0 finish}
           {H : for_loop_is_good i0 (-1) (finish - 1) Z.gtb}
           (H' : Z.gtb i0 finish = true)
  : for_loop_is_good (i0 - 1) (-1) (finish - 1) Z.gtb.
Proof.
  apply for_loop_is_good_step_gt; Z.ltb_to_lt; omega.
Qed.

(* [typeclass_instances] hints that derive [for_loop_is_good] for a loop
   that starts one step later (at [i0 + step], [i0 - step'], [i0 + 1] or
   [i0 - 1]) from the step lemmas above. The proof is wrapped in
   [adjust_bool]. [try assumption] discharges the remaining comparison side
   condition when a matching hypothesis is in context.
   NOTE(review): presumably these fill the [_] goodness witnesses in the
   unrolling lemma statements below; confirm. *)
Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _)
=> refine (adjust_bool _ (for_loop_is_good_step_lt _)); try assumption : typeclass_instances.
Local Hint Extern 1 (for_loop_is_good (?i0 + ?step) ?step ?finish _)
=> refine (adjust_bool _ (for_loop_is_good_step_gt _)); try assumption : typeclass_instances.
Local Hint Extern 1 (for_loop_is_good (?i0 - ?step') _ ?finish _)
=> refine (adjust_bool _ (for_loop_is_good_step_gt (step:=-step') _)); try assumption : typeclass_instances.
Local Hint Extern 1 (for_loop_is_good (?i0 + 1) 1 ?finish _)
=> refine (adjust_bool _ (for_loop_is_good_step_lt' _)); try assumption : typeclass_instances.
Local Hint Extern 1 (for_loop_is_good (?i0 - 1) (-1) ?finish _)
=> refine (adjust_bool _ (for_loop_is_good_step_gt' _)); try assumption : typeclass_instances.

(* Brute-force tactic for [for_loop_lt_unroll1] / [for_loop_gt_unroll1].
   It unfolds [for_loop_is_good] and [for_loop_notation], rewrites with
   [for_loop_unroll1], and case-splits on boolean matches. [Z.abs] and
   [Z.sgn] are eliminated in hypotheses whenever [omega] can decide the
   sign of their argument. *)
Local Ltac t :=
  repeat match goal with
         | _ => progress unfold for_loop_is_good, for_loop_notation in *
         | _ => progress rewrite for_loop_unroll1 by auto
         | _ => omega
         | _ => progress subst
         | _ => reflexivity
         | _ => progress split_andb
         | _ => progress Z.ltb_to_lt
         | _ => progress break_innermost_match_step
         | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_eq x) in H by omega
         | [ H : context[Z.abs ?x] |- _ ] => rewrite (Z.abs_neq x) in H by omega
         | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_pos x) in H by omega
         | [ H : context[Z.sgn ?x] |- _ ] => rewrite (Z.sgn_neg x) in H by omega
         | [ H : Z.sgn _ = 1 |- _ ] => apply Z.sgn_pos_iff in H
         | [ H : Z.sgn _ = -1 |- _ ] => apply Z.sgn_neg_iff in H
         end.

(** Unroll one iteration of a [Z.ltb]-bounded [for_loop_notation] loop.
    [body] runs at [i0]. If [i0 + step < finish], the loop continues from
    [i0 + step]; otherwise the single updated state is the result. *)
Lemma for_loop_lt_unroll1 {stateT}
      {i0 : Z} {step : Z} {finish : Z} {initial : stateT}
      {step_expr finish_expr} (body : Z -> stateT -> stateT)
      {Hstep : class_eq (fun i => i = step) step_expr}
      {Hfinish : class_eq (fun i => Z.ltb i finish) finish_expr}
      {Hgood : for_loop_is_good i0 step finish Z.ltb}
  : (@for_loop_notation i0 step finish _ initial Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood)
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.ltb (i0 + step) finish)
      then @for_loop_notation (i0 + step) step finish _ initial' Z.ltb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _
      else initial'.
Proof. t. Qed.

(** Unroll one iteration of a [Z.gtb]-bounded [for_loop_notation] loop.
    [body] runs at [i0]. If [i0 + step > finish], the loop continues from
    [i0 + step]; otherwise the single updated state is the result. *)
Lemma for_loop_gt_unroll1 {stateT}
      {i0 : Z} {step : Z} {finish : Z} {initial : stateT}
      {step_expr finish_expr} (body : Z -> stateT -> stateT)
      {Hstep : class_eq (fun i => i = step) step_expr}
      {Hfinish : class_eq (fun i => Z.gtb i finish) finish_expr}
      {Hgood : for_loop_is_good i0 step finish Z.gtb}
  : (@for_loop_notation i0 step finish _ initial Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish Hgood)
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.gtb (i0 + step) finish)
      then @for_loop_notation (i0 + step) step finish _ initial' Z.gtb (fun i => step_expr i) (fun i => finish_expr i) (fun i st => body i st) Hstep Hfinish _
      else initial'.
Proof. t. Qed.

(** [for_loop_lt_unroll1] stated with the C-style notation for
    [i < finish; i++] loops. *)
Lemma for_loop_lt1_unroll1 {stateT}
      {i0 : Z} {finish : Z} {initial : stateT}
      {body : Z -> stateT -> stateT}
      {Hgood : for_loop_is_good i0 1 finish _}
  : for (int i = i0; i < finish; i++) updating (st = initial) {{
      body i st
    }}
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.ltb (i0 + 1) finish)
      then for (int i = i0+1; i < finish; i++) updating (st = initial') {{
             body i st
           }}
      else initial'.
Proof. apply for_loop_lt_unroll1. Qed.

(** [for_loop_gt_unroll1] stated with the C-style notation for
    [i > finish; i--] loops. *)
Lemma for_loop_gt1_unroll1 {stateT}
      {i0 : Z} {finish : Z} {initial : stateT}
      {body : Z -> stateT -> stateT}
      {Hgood : for_loop_is_good i0 (-1) finish _}
  : for (int i = i0; i > finish; i--) updating (st = initial) {{
      body i st
    }}
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.gtb (i0 - 1) finish)
      then for (int i = i0-1; i > finish; i--) updating (st = initial') {{
             body i st
           }}
      else initial'.
Proof. apply for_loop_gt_unroll1. Qed.

(** Unroll one iteration of an [i <= finish; i++] loop. The loop is encoded
    as a [Z.ltb] loop with bound [finish + 1]. The next iteration at
    [i0 + 1] runs exactly when [i0 < finish]. *)
Lemma for_loop_le1_unroll1 {stateT}
      {i0 : Z} {finish : Z} {initial : stateT}
      {body : Z -> stateT -> stateT}
      {Hgood : for_loop_is_good i0 1 (finish+1) _}
  : for (int i = i0; i <= finish; i++) updating (st = initial) {{
      body i st
    }}
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.ltb i0 finish)
      then for (int i = i0+1; i <= finish; i++) updating (st = initial') {{
             body i st
           }}
      else initial'.
Proof.
  (* Reuse the [<] lemma, then reconcile [i0 + 1 < finish + 1] with
     [i0 < finish] by case analysis and arithmetic. *)
  rewrite for_loop_lt_unroll1; unfold for_loop_notation.
  break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity.
Qed.

(** Unroll one iteration of an [i >= finish; i--] loop. The loop is encoded
    as a [Z.gtb] loop with bound [finish - 1]. The next iteration at
    [i0 - 1] runs exactly when [i0 > finish]. *)
Lemma for_loop_ge1_unroll1 {stateT}
      {i0 : Z} {finish : Z} {initial : stateT}
      {body : Z -> stateT -> stateT}
      {Hgood : for_loop_is_good i0 (-1) (finish-1) _}
  : for (int i = i0; i >= finish; i--) updating (st = initial) {{
      body i st
    }}
    = let initial' := body i0 initial in
      if Sumbool.sumbool_of_bool (Z.gtb i0 finish)
      then for (int i = i0-1; i >= finish; i--) updating (st = initial') {{
             body i st
           }}
      else initial'.
Proof.
  (* Reuse the [>] lemma, then reconcile [i0 + -1 > finish - 1] with
     [i0 > finish] by case analysis and arithmetic. *)
  rewrite for_loop_gt_unroll1; unfold for_loop_notation.
  break_innermost_match; Z.ltb_to_lt; try omega; try reflexivity.
Qed.