From 91c0f39a5c7236489e268de0e5fa97b055698e4c Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sun, 22 Oct 2017 15:34:25 -0400 Subject: Factor out fold_right_cps2_specialized_step, add mapi_with'_cps2 --- src/Util/CPSUtil.v | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) (limited to 'src/Util/CPSUtil.v') diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v index 9b249b460..6a97946ae 100644 --- a/src/Util/CPSUtil.v +++ b/src/Util/CPSUtil.v @@ -202,11 +202,16 @@ Qed. Hint Rewrite @combine_cps_correct: uncps. (* differs from fold_right_cps in that the functional argument `g` is also a CPS function *) -Fixpoint fold_right_cps2_specialized {T A B} (g : B -> A -> (A->T)->T) (a0 : A) (l : list B) (f : A -> T) := +Definition fold_right_cps2_specialized_step + (fold_right_cps2_specialized + : forall {T A B} (g : B -> A -> (A->T)->T) (a0 : A) (l : list B) (f : A -> T), _) + {T A B} (g : B -> A -> (A->T)->T) (a0 : A) (l : list B) (f : A -> T) := match l with | nil => f a0 | b :: tl => fold_right_cps2_specialized g a0 tl (fun r => g b r f) end. +Fixpoint fold_right_cps2_specialized {T A B} (g : B -> A -> (A->T)->T) (a0 : A) (l : list B) (f : A -> T) := + @fold_right_cps2_specialized_step (@fold_right_cps2_specialized) T A B g a0 l f. Definition fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) := @fold_right_cps2_specialized T A B (fun b a => @g b a T) a0 l f. Lemma unfold_fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) @@ -355,6 +360,59 @@ Module Tuple. Hint Rewrite @mapi_with_cps_correct @mapi_with'_cps_correct using (intros; autorewrite with uncps; auto): uncps. + Section internal_mapi_with_cps2. + (* We define fixpoints with fewer parameters to the internal [fix] to allow unfolding to partially specialize them *) + Context {R T A B : Type} + (f: nat->T->A->(T*B->R)->R). + + Fixpoint mapi_with'_cps2_specialized {n} i + (start:T) + : Tuple.tuple' A n -> (T * tuple' B n -> R) -> R := + match n as n0 return (tuple' A n0 -> (T * tuple' B n0->R)->R) with + | O => fun ys ret => f i start ys ret + | S n' => fun ys ret => + f i start (hd ys) (fun sb => + mapi_with'_cps2_specialized (S i) (fst sb) (tl ys) + (fun r => ret (fst r, (snd r, snd sb)))) + end. + End internal_mapi_with_cps2. + + Definition mapi_with'_cps2 {T A B n} i + (f: nat->T->A->forall {R}, (T*B->R)->R) (start:T) + : Tuple.tuple' A n -> forall {R}, (T * tuple' B n -> R) -> R + := fun ts R => @mapi_with'_cps2_specialized R T A B (fun n t a => @f n t a R) n i start ts. + + Definition mapi_with_cps2 {S A B n} + (f: nat->S->A->forall {T}, (S*B->T)->T) (start:S) (ys:tuple A n) {T} + : (S * tuple B n->T)->T := + match n as n0 return (tuple A n0 -> (S * tuple B n0->T)->T) with + | O => fun ys ret => ret (start, tt) + | S n' => fun ys ret => mapi_with'_cps2 0%nat f start ys ret + end ys. + + Lemma unfold_mapi_with'_cps2 {T A B n} i + (f: nat->T->A->forall {R}, (T*B->R)->R) (start:T) + : @mapi_with'_cps2 T A B n i f start + = match n as n0 return (tuple' A n0 -> forall {R}, (T * tuple' B n0->R)->R) with + | O => fun ys {T} ret => f i start ys ret + | S n' => fun ys {T} ret => + f i start (hd ys) (fun sb => + mapi_with'_cps2 (S i) f (fst sb) (tl ys) + (fun r => ret (fst r, (snd r, snd sb)))) + end. + Proof. destruct n; reflexivity. Qed. + + Lemma mapi_with'_cps2_correct {S A B n} : forall i f start xs T ret, + (forall i s a R (ret:_->R), f i s a R ret = ret (f i s a _ id)) -> + @mapi_with'_cps2 S A B n i f start xs T ret = ret (mapi_with' i (fun i s a => f i s a _ id) start xs). + Proof. induction n as [|n IHn]; intros i f start xs T ret H; simpl; rewrite H, ?IHn by assumption; reflexivity. Qed. + Lemma mapi_with_cps2_correct {S A B n} f start xs T ret + (H:forall i s a R (ret:_->R), f i s a R ret = ret (f i s a _ id)) + : @mapi_with_cps2 S A B n f start xs T ret = ret (mapi_with (fun i s a => f i s a _ id) start xs). + Proof. destruct n; simpl; rewrite ?mapi_with'_cps2_correct by assumption; reflexivity. Qed. + Hint Rewrite @mapi_with_cps2_correct @mapi_with'_cps2_correct + using (intros; autorewrite with uncps; auto): uncps. + Fixpoint left_append_cps {A n} (x:A) (xs:tuple A n) {R} : (tuple A (S n) -> R) -> R := match @@ -467,5 +525,5 @@ Module Tuple. Qed. End Tuple. Hint Rewrite @Tuple.map_cps_correct @Tuple.left_append_cps_correct @Tuple.left_tl_cps_correct @Tuple.left_hd_cps_correct @Tuple.tl_cps_correct @Tuple.hd_cps_correct : uncps. -Hint Rewrite @Tuple.mapi_with_cps_correct @Tuple.mapi_with'_cps_correct +Hint Rewrite @Tuple.mapi_with_cps_correct @Tuple.mapi_with'_cps_correct @Tuple.mapi_with_cps2_correct @Tuple.mapi_with'_cps2_correct using (intros; autorewrite with uncps; auto): uncps. -- cgit v1.2.3