aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-07 10:37:52 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 13:46:48 +0200
commit0d3ae3b975eb56d3247646d0a2023f675c14d759 (patch)
tree2185d30ecaa1d19cecaccc335bdb627bd871a27e /src
parent60c83608df9ef701ad559381288931dd61749f38 (diff)
proofs for straightline pass (with admits for some depth stuff that should be unneeded soon)
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v313
1 files changed, 290 insertions, 23 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 1d1f6b20e..e7c46ee9a 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -7958,7 +7958,7 @@ Module Straightline.
Definition dummy t : expr t := Scalar (dummy_scalar t).
End with_ident.
- Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d)
+ Definition of_uncurried_scalar_ident {s d} (idc : ident.ident s d)
: scalar s -> option (scalar d) :=
match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with
| ident.Z.cast r => fun args => Some (Cast r args)
@@ -7973,18 +7973,18 @@ Module Straightline.
| _ => fun _ => None
end.
- Fixpoint scalar_of_uncurried {t} (e : uexpr t) : option (scalar t) :=
+ Fixpoint of_uncurried_scalar {t} (e : uexpr t) : option (scalar t) :=
match e in Uncurried.expr.expr t return option (scalar t) with
| expr.Var t v as e => Some (Var t v)
| expr.TT as e => Some TT
| expr.Pair A B a b
- => match scalar_of_uncurried a, scalar_of_uncurried b with
+ => match of_uncurried_scalar a, of_uncurried_scalar b with
| Some x, Some y => Some (Pair x y)
| _, _ => None
end
| expr.AppIdent _ _ idc args
- => match scalar_of_uncurried args with
- | Some x => scalar_of_uncurried_ident idc x
+ => match of_uncurried_scalar args with
+ | Some x => of_uncurried_scalar_ident idc x
| None => None
end
| _ => None
@@ -8027,7 +8027,7 @@ Module Straightline.
| Some (r, x, e) =>
match invert_AppIdent x with
| Some (existT s idc_x') =>
- match scalar_of_uncurried (snd idc_x') with
+ match of_uncurried_scalar (snd idc_x') with
| Some x'' =>
match invert_Abs e with
| Some k => Some (existT _ s (r, fst idc_x', x'', k))
@@ -8066,7 +8066,7 @@ Module Straightline.
fun (args : uexpr _) default =>
match invert_AppIdent args with
| Some (existT s idc_x') =>
- match scalar_of_uncurried (snd idc_x') with
+ match of_uncurried_scalar (snd idc_x') with
| Some x'' =>
@mk_LetInAppIdent s type.Z type.Z default r (fst idc_x') x'' (fun y => Scalar (Var _ y))
| None => default
@@ -8077,7 +8077,7 @@ Module Straightline.
fun (args : uexpr _) default =>
match invert_AppIdent args with
| Some (existT s idc_x') =>
- match scalar_of_uncurried (snd idc_x') with
+ match of_uncurried_scalar (snd idc_x') with
| Some x'' =>
@mk_LetInAppIdent s (type.Z*type.Z) (type.Z*type.Z) default r (fst idc_x') x'' (fun y => Scalar (Var _ y))
| None => default
@@ -8091,14 +8091,20 @@ Module Straightline.
(of_uncurried : forall t, uexpr t -> expr t)
: expr t -> expr t :=
match e in Uncurried.expr.expr t return expr t -> expr t with
- | AppIdent s d idc args => of_uncurried_ident of_uncurried idc args
+ | AppIdent s d idc args =>
+ fun default =>
+ of_uncurried_ident of_uncurried idc args
+ (match of_uncurried_scalar (AppIdent idc args) with
+ | Some s => Scalar s
+ | None => default
+ end)
| _ as e =>
(fun default =>
- match scalar_of_uncurried e with
+ match of_uncurried_scalar e with
| Some s => Scalar s
| None => default
end)
- end.
+ end.
(* TODO : uses fuel; ideally want a cleaner termination proof *)
Fixpoint of_uncurried (fuel : nat) {t} (e : uexpr t)
@@ -8107,24 +8113,285 @@ Module Straightline.
| S fuel' => of_uncurried_step e (@of_uncurried fuel') (dummy t)
| O => dummy t
end.
-
End with_var.
- End expr.
- Fixpoint depth {var t} (dummy_var : forall t, var t) (e : @Uncurried.expr.expr ident.ident var t) : nat :=
- match e with
- | Uncurried.expr.Var _ _ => O
- | Uncurried.expr.TT => O
- | Uncurried.expr.AppIdent _ _ idc args => S (depth dummy_var args)
- | Uncurried.expr.App _ _ f x => S (Nat.max (depth dummy_var f) (depth dummy_var x))
- | Uncurried.expr.Pair _ _ x y => S (Nat.max (depth dummy_var x) (depth dummy_var y))
- | Uncurried.expr.Abs _ _ f => S (depth dummy_var (f (dummy_var _)))
- end.
+ Section depth.
+ Context (var : type -> Type) (dummy_var : forall t, var t).
+ Fixpoint depth {t} (e : @Uncurried.expr.expr ident var t) : nat :=
+ match e with
+ | Uncurried.expr.Var _ _ => 1
+ | Uncurried.expr.TT => 1
+ | Uncurried.expr.AppIdent _ _ idc args => S (depth args)
+ | Uncurried.expr.App _ _ f x => S (Nat.max (depth f) (depth x))
+ | Uncurried.expr.Pair _ _ x y => S (Nat.max (depth x) (depth y))
+ | Uncurried.expr.Abs _ _ f => S (depth (f (dummy_var _)))
+ end.
+
+ Definition Expr_depth {t} (e : Expr t) : nat := depth (e _).
+ End depth.
+
+ Section interp.
+ Local Notation scalar := (@scalar type.interp default.ident).
+ Local Notation expr := (@expr type.interp default.ident).
+
+ Definition interp_cast (r : zrange) (x : Z) : Z := ident.cast ident.cast_outside_of_range r x.
+
+ Fixpoint interp_scalar {t} (s : scalar t) : type.interp t :=
+ match s with
+ | Var t v => v
+ | TT => tt
+ | Nil _ => []
+ | Pair _ _ x y => (interp_scalar x, interp_scalar y)
+ | Cast r x => interp_cast r (interp_scalar x)
+ | Cast2 r x =>
+ let '(a,b) := interp_scalar x in (interp_cast (fst r) a, interp_cast (snd r) b)
+ | Fst _ _ p => fst (interp_scalar p)
+ | Snd _ _ p => snd (interp_scalar p)
+ | Shiftr n x => Z.shiftr (interp_scalar x) n
+ | Shiftl n x => Z.shiftl (interp_scalar x) n
+ | Land n x => Z.land (interp_scalar x) n
+ | CC_m n x => Z.cc_m n (interp_scalar x)
+ | Primitive _ x => x
+ end.
+
+ Fixpoint interp {t} (e : expr t) : type.interp t :=
+ match e with
+ | Scalar _ s => interp_scalar s
+ | LetInAppIdentZ _ _ r idc x f =>
+ interp (f (interp_cast r (ident.interp idc (interp_scalar x))))
+ | LetInAppIdentZZ _ _ (r1,r2) idc x f =>
+ let '(a,b) := ident.interp idc (interp_scalar x) in
+ interp (f (interp_cast r1 a, interp_cast r2 b))
+ end.
+
+ Inductive ok_scalar_ident : forall {s d}, ident.ident s d -> Prop :=
+ | ok_si_cast : forall r, ok_scalar_ident (ident.Z.cast r)
+ | ok_si_cast2 : forall r, ok_scalar_ident (ident.Z.cast2 r)
+ | ok_si_fst : forall A B, ok_scalar_ident (@ident.fst A B)
+ | ok_si_snd : forall A B, ok_scalar_ident (@ident.snd A B)
+ | ok_si_shiftr : forall n, ok_scalar_ident (@ident.Z.shiftr n)
+ | ok_si_shiftl : forall n, ok_scalar_ident (@ident.Z.shiftl n)
+ | ok_si_land : forall n, ok_scalar_ident (@ident.Z.land n)
+ | ok_si_cc_m : forall n, ok_scalar_ident (@ident.Z.cc_m_concrete n)
+ | ok_prim : forall p x, ok_scalar_ident (@ident.primitive p x)
+ .
+
+ Inductive ok_scalar: forall {t}, @Uncurried.expr.expr ident type.interp t -> Prop :=
+ | ok_Var : forall t v, @ok_scalar t (Uncurried.expr.Var v)
+ | ok_TT : ok_scalar Uncurried.expr.TT
+ | ok_AppIdent :
+ forall s d idc args,
+ ok_scalar args ->
+ @ok_scalar_ident s d idc ->
+ ok_scalar (AppIdent idc args)
+ | ok_Pair :
+ forall A B a b,
+ @ok_scalar A a ->
+ @ok_scalar B b ->
+ ok_scalar (Uncurried.expr.Pair a b)
+ .
+
+ Inductive ok_expr : forall {t}, Uncurried.expr.expr t -> Prop :=
+ | ok_LetInAppIdentZ :
+ forall tC r s (idc : ident s type.Z) x k,
+ ok_scalar x -> (forall y, @ok_expr tC (k y)) ->
+ @ok_expr tC (AppIdent (@ident.Let_In _ tC) (Uncurried.expr.Pair (AppIdent (ident.Z.cast r) (AppIdent idc x)) (Abs k)))
+ | ok_LetInAppIdentZZ :
+ forall tC r s (idc : ident s (type.prod type.Z type.Z)) x k,
+ ok_scalar x -> (forall y, @ok_expr tC (k y)) ->
+ @ok_expr tC (AppIdent (@ident.Let_In _ tC) (Uncurried.expr.Pair (AppIdent (ident.Z.cast2 r) (AppIdent idc x)) (Abs k)))
+ | ok_scalar_cast :
+ forall r s (idc : ident s _) x,
+ ok_scalar x ->
+ ok_scalar_ident idc ->
+ @ok_expr type.Z (AppIdent (ident.Z.cast r) (AppIdent idc x))
+ | ok_scalar_cast2 :
+ forall r s (idc : ident s _) x,
+ ok_scalar x ->
+ ok_scalar_ident idc ->
+ @ok_expr (type.prod type.Z type.Z) (AppIdent (ident.Z.cast2 r) (AppIdent idc x))
+ | ok_scalar_nocast :
+ forall t x, @ok_scalar t x -> @ok_expr t x
+ .
+
+ Lemma interp_cast_correct {t} r (idc : ident t type.Z) x :
+ interp_cast r (ident.interp idc (expr.interp (@ident.interp) x))
+ = expr.interp (@ident.interp) (AppIdent (ident.Z.cast r) (AppIdent idc x)).
+ Proof. reflexivity. Qed.
+
+ Lemma interp_cast2_correct {t} r1 r2 (idc : ident t (type.prod type.Z type.Z)) x :
+ (interp_cast r1 (fst (ident.interp idc (expr.interp (@ident.interp) x))),
+ interp_cast r2 (snd (ident.interp idc (expr.interp (@ident.interp) x))))
+ = expr.interp (@ident.interp) (AppIdent (ident.Z.cast2 (r1, r2)) (AppIdent idc x)).
+ Proof. cbn; break_match; reflexivity. Qed.
+
+ Definition primitive_eq_dec (a b : type.primitive) : {a = b} + {a <> b}.
+ Proof. destruct a,b; auto; right; congruence. Defined.
+ Fixpoint type_eq_dec (a b : type) : {a = b} + {a <> b}.
+ Proof.
+ destruct a, b; try solve [right; congruence]; [ | | | ].
+ { destruct (primitive_eq_dec p p0); subst; [left | right]; congruence. }
+ { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
+ left; congruence. }
+ { destruct (type_eq_dec a1 b1); destruct (type_eq_dec a2 b2); subst; try solve [right; congruence].
+ left; congruence. }
+ { destruct (type_eq_dec a b); [left | right]; congruence. }
+ Defined.
+
+ Ltac invert H :=
+ inversion H; subst;
+ repeat match goal with
+ | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type_eq_dec) in H; subst
+ end.
+
+ Ltac invert_ok_expr :=
+ match goal with H : ok_expr _ |- _ => invert H end.
+ Ltac invert_ok_scalar :=
+ match goal with H : ok_scalar _ |- _ => invert H end.
+ Ltac invert_ok_scalar_ident :=
+ match goal with H : ok_scalar_ident _ |- _ => invert H end.
+ Ltac simpl_inversions :=
+ cbn [invert_LetInAppIdent invert_LetInCast invert_Pair invert_cast invert_AppIdent invert_Abs].
+
+ Lemma invert_AppIdent_correct {d} (e : @Uncurried.expr.expr ident type.interp d) x p :
+ invert_AppIdent e = Some (existT (fun s : type => (ident s d * default.expr s)%type) x p) ->
+ e = AppIdent (fst p) (snd p).
+ Proof.
+ cbv [invert_AppIdent].
+ break_match; try discriminate.
+ intro H; invert H. reflexivity.
+ Qed.
+
+ Lemma depth_positive {var t} dummy_var (e : Uncurried.expr.expr t) : 0 < depth var dummy_var e.
+ Proof. destruct e; cbn [depth]; rewrite Nat2Z.inj_succ; omega. Qed.
+
+ Lemma of_uncurried_scalar_ident_correct {s d} (idc : ident s d) args args':
+ ok_scalar_ident idc ->
+ of_uncurried_scalar args = Some args' ->
+ interp_scalar args' = expr.interp (@ident.interp) args ->
+ exists s,
+ of_uncurried_scalar_ident idc args' = Some s
+ /\ interp_scalar s = expr.interp (@ident.interp) (AppIdent idc args).
+ Proof.
+ destruct 1; intros;
+ repeat match goal with
+ | _ => eexists; split; [ reflexivity | cbn [interp_scalar] ]
+ | H : interp_scalar _ = _ |- _ => rewrite H
+ | _ => reflexivity
+ end.
+ { cbn; break_match; reflexivity. }
+ { cbn; break_match; reflexivity. }
+ Qed.
+
+ Lemma of_uncurried_scalar_correct {t} (e : Uncurried.expr.expr t) :
+ ok_scalar e ->
+ exists s,
+ of_uncurried_scalar e = Some s
+ /\ interp_scalar s = expr.interp (@ident.interp) e.
+ Proof.
+ induction 1; cbn [of_uncurried_scalar]; intros;
+ repeat match goal with
+ | IH : exists _, _ /\ _ |- _ => destruct IH as [? [? ?] ]
+ | H : of_uncurried_scalar _ = _ |- _ => rewrite H
+ | _ => apply of_uncurried_scalar_ident_correct; solve [auto]
+ | _ => eexists; tauto
+ end; [ ].
+ eexists; split; [ reflexivity | ].
+ cbn [interp_scalar expr.interp].
+ congruence.
+ Qed.
+
+ Ltac rewrite_ok_scalar :=
+ match goal with H : ok_scalar _ |- _ =>
+ let P := fresh in destruct (of_uncurried_scalar_correct _ H) as [? [P ?] ]; rewrite P in *
+ end;
+ repeat match goal with
+ | H : Some _ = Some _ |- _ => inversion H; progress subst
+ | _ => progress break_match;
+ match goal with | H: Some _ = Some _ |- _ => inversion H; progress subst end
+ end.
+
+ Lemma of_uncurried_correct dummy_arrow fuel dummy_var :
+ forall {t} (e : @Uncurried.expr.expr ident type.interp t),
+ (depth _ dummy_var e <= fuel)%nat ->
+ ok_expr e ->
+ expr.interp (@ident.interp) e = interp (@of_uncurried _ dummy_arrow fuel _ e).
+ Proof.
+ induction fuel; intros; [ pose proof (depth_positive dummy_var e); omega | ].
+ destruct e; cbn [depth of_uncurried expr.interp interp]; intros; invert_ok_expr;
+ repeat match goal with
+ | |- context [of_uncurried_scalar _ ] => progress rewrite_ok_scalar
+ | _ => progress (cbn [of_uncurried_step of_uncurried_ident fst snd mk_LetInAppIdent expr.interp interp depth] in * )
+ | _ => progress simpl_inversions
+ | _ => congruence
+ end; [ | | | | ].
+ {
+ match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ rewrite <-IHfuel; rewrite interp_cast_correct.
+ { reflexivity. }
+ { cbn [depth] in *.
+ (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *)
+ admit. }
+ { auto. } }
+ {
+ match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ break_match.
+ match goal with H : ?p = (?x, ?y) |- _ =>
+ replace x with (fst p) by (rewrite H; reflexivity);
+ replace y with (snd p) by (rewrite H; reflexivity)
+ end.
+ rewrite <-IHfuel. rewrite interp_cast2_correct.
+ { cbn; break_match; reflexivity. }
+ { cbn [depth] in *.
+ (* here we have to reason about the depth calculation for arrows; this will probably be unnecessary with new compilers setup *)
+ admit. }
+ { auto. } }
+ {
+ match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ rewrite interp_cast_correct.
+ reflexivity. }
+ {
+ match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ break_match.
+ match goal with H : ?p = (?x, ?y) |- _ =>
+ replace x with (fst p) by (rewrite H; reflexivity);
+ replace y with (snd p) by (rewrite H; reflexivity)
+ end.
+ rewrite interp_cast2_correct.
+ cbn; break_match; reflexivity. }
+ { invert_ok_scalar.
+ rewrite <-H2.
+ invert_ok_scalar_ident; try reflexivity.
+ { match goal with H : context [of_uncurried_scalar _ = Some _ ] |- _ => cbn in H end.
+ rewrite_ok_scalar.
+ cbn [of_uncurried_ident].
+ cbn [interp_scalar].
+ cbn.
+ break_match; cbn; auto.
+ match goal with H : _ |- _ => apply invert_AppIdent_correct in H end.
+ subst.
+ invert_ok_scalar.
+ rewrite_ok_scalar.
+ repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ reflexivity. }
+ { match goal with H : context [of_uncurried_scalar _ = Some _ ] |- _ => cbn in H end.
+ rewrite_ok_scalar.
+ cbn [of_uncurried_ident].
+ break_match; cbn; auto.
+ match goal with H : _ |- _ => apply invert_AppIdent_correct in H end.
+ subst.
+ invert_ok_scalar.
+ rewrite_ok_scalar.
+ repeat match goal with H : interp_scalar _ = _ |- _ => rewrite H end.
+ destruct r; reflexivity. }
+ Admitted.
+ End interp.
+ End expr.
Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d
:=
match invert_Abs (e var) with
- | Some f => expr.of_uncurried (dummy_arrow:=dummy_arrow) (depth (var:=fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x)
+ | Some f => expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x)
| None => expr.dummy (dummy_arrow:=dummy_arrow) d
end.