aboutsummaryrefslogtreecommitdiff
path: root/src/Util/CPSUtil.v
blob: 15782fa61018be832f77ecd5493ad49a7c8704af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
Require Import Coq.Lists.List. Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.omega.Omega.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.DestructHead.
Require Import Crypto.Util.ListUtil.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
Local Open Scope Z_scope.

Lemma push_id {A} (a:A) : id a = a. reflexivity. Qed.
Create HintDb push_id discriminated. Hint Rewrite @push_id : push_id.

Lemma update_nth_id {T} i (xs:list T) : ListUtil.update_nth i id xs = xs.
Proof.
  revert xs; induction i; destruct xs; simpl; solve [ trivial | congruence ].
Qed.

Lemma map_fst_combine {A B} (xs:list A) (ys:list B) : List.map fst (List.combine xs ys) = List.firstn (length ys) xs.
Proof.
  revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
Qed.

Lemma map_snd_combine {A B} (xs:list A) (ys:list B) : List.map snd (List.combine xs ys) = List.firstn (length xs) ys.
Proof.
  revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
Qed.

Lemma nth_default_seq_inbouns d s n i (H:(i < n)%nat) :
  List.nth_default d (List.seq s n) i = (s+i)%nat.
Proof.
  progress cbv [List.nth_default].
  rewrite ListUtil.nth_error_seq.
  break_innermost_match; solve [ trivial | omega ].
Qed.

Lemma mod_add_mul_full a b c k m : m <> 0 -> c mod m = k mod m ->
                                   (a + b * c) mod m = (a + b * k) mod m.
Proof.
  intros; rewrite Z.add_mod, Z.mul_mod by auto.
  match goal with H : _ mod _ = _ mod _ |- _ => rewrite H end.
  rewrite <-Z.mul_mod, <-Z.add_mod by auto; reflexivity.
Qed.

Fixpoint map_cps {A B} (g : A->B) ls
         {T} (f:list B->T):=
  match ls with
  | nil => f nil
  | a :: t => map_cps g t (fun r => f (g a :: r))
  end.
Lemma map_cps_correct {A B} g ls: forall {T} f,
    @map_cps A B g ls T f = f (map g ls).
Proof. induction ls as [|?? IHls]; simpl; intros; rewrite ?IHls; reflexivity. Qed.
Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps.

Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T)  :=
  match ls with
  | nil => f nil
  | (x::tl)%list => g x (fun r => flat_map_cps g tl (fun rr => f (r ++ rr))%list)
  end.
Lemma flat_map_cps_correct {A B} (g:A->forall {T}, (list B->T)->T) ls :
  forall {T} (f:list B->T),
    (forall x T h, @g x T h = h (g x id)) ->
    @flat_map_cps A B g ls T f = f (List.flat_map (fun x => g x id) ls).
Proof.
  induction ls as [|?? IHls]; intros T f H; [reflexivity|].
  simpl flat_map_cps. simpl flat_map.
  rewrite H; erewrite IHls by eassumption.
  reflexivity.
Qed.
Hint Rewrite @flat_map_cps_correct using (intros; autorewrite with uncps; auto): uncps.

Fixpoint from_list_default'_cps {A} (d y:A) n xs:
  forall {T}, (Tuple.tuple' A n -> T) -> T:=
  match n as n0 return (forall {T}, (Tuple.tuple' A n0 ->T) ->T) with
  | O => fun T f => f y
  | S n' => fun T f =>
              match xs with
              | nil => from_list_default'_cps d d n' nil (fun r => f (r, y))
              | x :: xs' => from_list_default'_cps d x n' xs' (fun r => f (r, y))
              end
  end.
Lemma from_list_default'_cps_correct {A} n : forall d y l {T} f,
    @from_list_default'_cps A d y n l T f = f (Tuple.from_list_default' d y n l).
Proof.
  induction n as [|? IHn]; intros; simpl; [reflexivity|].
  break_match; subst; apply IHn.
Qed.
Definition from_list_default_cps {A} (d:A) n (xs:list A) :
  forall {T}, (Tuple.tuple A n -> T) -> T:=
  match n as n0 return (forall {T}, (Tuple.tuple A n0 ->T) ->T) with
  | O => fun T f => f tt
  | S n' => fun T f =>
              match xs with
              | nil => from_list_default'_cps d d n' nil f
              | x :: xs' => from_list_default'_cps d x n' xs' f
              end
  end.
Lemma from_list_default_cps_correct {A} n : forall d l {T} f,
    @from_list_default_cps A d n l T f = f (Tuple.from_list_default d n l).
Proof.
  destruct n; intros; simpl; [reflexivity|].
  break_match; auto using from_list_default'_cps_correct.
Qed.
Hint Rewrite @from_list_default_cps_correct : uncps.
Fixpoint to_list'_cps {A} n
         {T} (f:list A -> T) : Tuple.tuple' A n -> T :=
  match n as n0 return (Tuple.tuple' A n0 -> T) with
  | O => fun x => f [x]
  | S n' => fun (xs: Tuple.tuple' A (S n')) =>
              let (xs', x) := xs in
              to_list'_cps n' (fun r => f (x::r)) xs'
  end.
Lemma to_list'_cps_correct {A} n: forall t {T} f,
    @to_list'_cps A n T f t = f (Tuple.to_list' n t).
Proof.
  induction n; simpl; intros; [reflexivity|].
  destruct_head prod. apply IHn.
Qed.
Definition to_list_cps' {A} n {T} (f:list A->T)
  : Tuple.tuple A n -> T :=
  match n as n0 return (Tuple.tuple A n0 ->T) with
  | O => fun _ => f nil
  | S n' => to_list'_cps n' f
  end.
Definition to_list_cps {A} n t {T} f :=
  @to_list_cps' A n T f t.
Lemma to_list_cps_correct {A} n t {T} f :
  @to_list_cps A n t T f = f (Tuple.to_list n t).
Proof. cbv [to_list_cps to_list_cps' Tuple.to_list]; break_match; auto using to_list'_cps_correct. Qed.
Hint Rewrite @to_list_cps_correct : uncps.

Definition on_tuple_cps {A B} (d:B) (g:list A ->forall {T},(list B->T)->T) {n m}
           (xs : Tuple.tuple A n) {T} (f:tuple B m ->T) :=
  to_list_cps n xs (fun r => g r (fun rr => from_list_default_cps d m rr f)).
Lemma on_tuple_cps_correct {A B} d (g:list A -> forall {T}, (list B->T)->T)
      {n m} xs {T} f
      (Hg : forall x {T} h, @g x T h = h (g x id)) : forall H,
    @on_tuple_cps A B d g n m xs T f = f (@Tuple.on_tuple A B (fun x => g x id) n m H xs).
Proof.
  cbv [on_tuple_cps Tuple.on_tuple]; intros H.
  rewrite to_list_cps_correct, Hg, from_list_default_cps_correct.
  rewrite (Tuple.from_list_default_eq _ _ _ (H _ (Tuple.length_to_list _))).
  reflexivity.
Qed.  Hint Rewrite @on_tuple_cps_correct using (intros; autorewrite with uncps; auto): uncps.

Fixpoint update_nth_cps {A} n (g:A->A) xs {T} (f:list A->T) :=
  match n with
  | O =>
    match xs with
    | [] => f []
    | x' :: xs' => f (g x' :: xs')
    end
  | S n' =>
    match xs with
    | [] => f []
    | x' :: xs' => update_nth_cps n' g xs' (fun r => f (x' :: r))
    end
  end.
Lemma update_nth_cps_correct {A} n g: forall xs T f,
    @update_nth_cps A n g xs T f = f (update_nth n g xs).
Proof. induction n; intros; simpl; break_match; try apply IHn; reflexivity. Qed.
Hint Rewrite @update_nth_cps_correct : uncps.

Fixpoint combine_cps {A B} (la :list A) (lb : list B)
         {T} (f:list (A*B)->T) :=
  match la with
  | nil => f nil
  | a :: tla =>
    match lb with
    | nil => f nil
    | b :: tlb => combine_cps tla tlb (fun lab => f ((a,b)::lab))
    end
  end.
Lemma combine_cps_correct {A B} la: forall lb {T} f,
    @combine_cps A B la lb T f = f (combine la lb).
Proof.
  induction la; simpl combine_cps; simpl combine; intros;
    try break_match; try apply IHla; reflexivity.
Qed.
Hint Rewrite @combine_cps_correct: uncps.

(* differs from fold_right_cps in that the functional argument `g` is also a CPS function *)
Fixpoint fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) :=
  match l with
  | nil => f a0
  | b :: tl => fold_right_cps2 g a0 tl (fun r => g b r f)
  end.
Lemma fold_right_cps2_correct {A B} g a0 l : forall {T} f,
  (forall b a T h, @g b a T h = h (@g b a A id)) ->
  @fold_right_cps2 A B g a0 l T f = f (List.fold_right (fun b a => @g b a A id) a0 l).
Proof.
  induction l as [|?? IHl]; intros T f H; [reflexivity|].
  simpl fold_right_cps2. simpl fold_right.
  rewrite H; erewrite IHl by eassumption.
  rewrite H; reflexivity.
Qed.
Hint Rewrite @fold_right_cps2_correct using (intros; autorewrite with uncps; auto): uncps.

Definition fold_right_no_starter {A} (f:A->A->A) ls : option A :=
  match ls with
  | nil => None
  | cons x tl => Some (List.fold_right f x tl)
  end.
Lemma fold_right_min ls x :
  x = List.fold_right Z.min x ls
  \/ List.In (List.fold_right Z.min x ls) ls.
Proof.
  induction ls; intros; simpl in *; try tauto.
  match goal with |- context [Z.min ?x ?y] =>
                  destruct (Z.min_spec x y) as [[? Hmin]|[? Hmin]]
  end; rewrite Hmin; tauto.
Qed.
Lemma fold_right_no_starter_min ls : forall x,
    fold_right_no_starter Z.min ls = Some x ->
    List.In x ls.
Proof.
  cbv [fold_right_no_starter]; intros; destruct ls; try discriminate.
  inversion H; subst; clear H.
  destruct (fold_right_min ls z);
    simpl List.In; tauto.
Qed.
Fixpoint fold_right_cps {A B} (g:B->A->A) (a0:A) (l:list B) {T} (f:A->T) :=
  match l with
  | nil => f a0
  | cons a tl => fold_right_cps g a0 tl (fun r => f (g a r))
  end.
Lemma fold_right_cps_correct {A B} g a0 l: forall {T} f,
    @fold_right_cps A B g a0 l T f = f (List.fold_right g a0 l).
Proof. induction l as [|? l IHl]; intros; simpl; rewrite ?IHl; auto. Qed.
Hint Rewrite @fold_right_cps_correct : uncps.

Definition fold_right_no_starter_cps {A} g ls {T} (f:option A->T) :=
  match ls with
  | nil => f None
  | cons x tl => f (Some (List.fold_right g x tl))
  end.
Lemma fold_right_no_starter_cps_correct {A} g ls {T} f :
  @fold_right_no_starter_cps A g ls T f = f (fold_right_no_starter g ls).
Proof.
  cbv [fold_right_no_starter_cps fold_right_no_starter]; break_match; reflexivity.
Qed.
Hint Rewrite @fold_right_no_starter_cps_correct : uncps.

Import Tuple.

Module Tuple.
  Fixpoint map_cps {A B n} (g : A->B) :
  tuple A n -> forall {T}, (tuple B n->T) -> T:=
  match n with
  | O => fun _ _ f => f tt
  | S n' => fun t T f => map_cps g (tl t) (fun r => f (append (g (hd t)) r))
  end.
  Lemma map_cps_correct {A B n} g t: forall {T} f,
      @map_cps A B n g t T f = f (map g t).
  Proof.
    induction n; simpl map_cps; intros; try destruct t;
    [|rewrite IHn, <-map_append,<-subst_append]; reflexivity.
  Qed. Hint Rewrite @map_cps_correct : uncps.

  Fixpoint mapi_with'_cps {T A B n} i
          (f: nat->T->A->forall {R}, (T*B->R)->R) (start:T)
  : Tuple.tuple' A n -> forall {R}, (T * tuple' B n -> R) -> R :=
  match n as n0 return (tuple' A n0 -> forall {R}, (T * tuple' B n0->R)->R) with
  | O => fun ys {T} ret => f i start ys ret
  | S n' => fun ys {T} ret =>
              f i start (hd ys) (fun sb =>
              mapi_with'_cps (S i) f (fst sb) (tl ys)
                  (fun r => ret (fst r, (snd r, snd sb))))
  end.

  Fixpoint mapi_with_cps {S A B n}
          (f: nat->S->A->forall {T}, (S*B->T)->T) (start:S)
  : tuple A n -> forall {T}, (S * tuple B n->T)->T :=
  match n as n0 return (tuple A n0 -> forall {T}, (S * tuple B n0->T)->T) with
  | O => fun ys {T} ret => ret (start, tt)
  | S n' => fun ys {T} ret => mapi_with'_cps 0%nat f start ys ret
  end.

  Lemma mapi_with'_cps_correct {S A B n} : forall i f start xs T ret,
  (forall i s a R (ret:_->R), f i s a R ret = ret (f i s a _ id)) ->
  @mapi_with'_cps S A B n i f start xs T ret = ret (mapi_with' i (fun i s a => f i s a _ id) start xs).
  Proof. induction n as [|n IHn]; intros i f start xs T ret H; simpl; rewrite H, ?IHn by assumption; reflexivity. Qed.
  Lemma mapi_with_cps_correct {S A B n} f start xs T ret
  (H:forall i s a R (ret:_->R), f i s a R ret = ret (f i s a _ id))
  : @mapi_with_cps S A B n f start xs T ret = ret (mapi_with (fun i s a => f i s a _ id) start xs).
  Proof. destruct n; simpl; rewrite ?mapi_with'_cps_correct by assumption; reflexivity. Qed.
  Hint Rewrite @mapi_with_cps_correct @mapi_with'_cps_correct
       using (intros; autorewrite with uncps; auto): uncps.
End Tuple.
Hint Rewrite @Tuple.map_cps_correct : uncps.
Hint Rewrite @Tuple.mapi_with_cps_correct @Tuple.mapi_with'_cps_correct
     using (intros; autorewrite with uncps; auto): uncps.