summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Adam Chlipala <adamc@hcoop.net>2008-07-03 11:04:25 -0400
committerGravatar Adam Chlipala <adamc@hcoop.net>2008-07-03 11:04:25 -0400
commitb2eb9f45b9b14e5c7f53d0ad7ca8e84aa7858b59 (patch)
treecd4847d16103c7bdbfba1ece0416497bb28d05d8
parente8002363e5d7764edf9a06ec0717f212ebbee26f (diff)
Fancier head normalization pushed inside of Disjoint
-rw-r--r--src/disjoint.sig9
-rw-r--r--src/disjoint.sml127
-rw-r--r--src/elaborate.sml94
-rw-r--r--tests/cfold_disj.lac5
4 files changed, 140 insertions, 95 deletions
diff --git a/src/disjoint.sig b/src/disjoint.sig
index 16afa885..025269cf 100644
--- a/src/disjoint.sig
+++ b/src/disjoint.sig
@@ -30,9 +30,14 @@ signature DISJOINT = sig
type env
val empty : env
- val assert : ElabEnv.env -> env -> Elab.con * Elab.con -> env
val enter : env -> env
- val prove : ElabEnv.env -> env -> Elab.con * Elab.con * ErrorMsg.span -> (Elab.con * Elab.con) list
+ type goal = ErrorMsg.span * ElabEnv.env * env * Elab.con * Elab.con
+
+ val assert : ElabEnv.env -> env -> Elab.con * Elab.con -> env * goal list
+
+ val prove : ElabEnv.env -> env -> Elab.con * Elab.con * ErrorMsg.span -> goal list
+
+ val hnormCon : ElabEnv.env * env -> Elab.con -> Elab.con * goal list
end
diff --git a/src/disjoint.sml b/src/disjoint.sml
index 6c66fdd8..6bd7e0c9 100644
--- a/src/disjoint.sml
+++ b/src/disjoint.sml
@@ -109,6 +109,8 @@ structure PM = BinaryMapFn(PK)
type env = PS.set PM.map
+type goal = ErrorMsg.span * ElabEnv.env * env * Elab.con * Elab.con
+
val empty = PM.empty
fun nameToRow (c, loc) =
@@ -128,32 +130,62 @@ datatype piece' =
Piece of piece
| Unknown of con
-fun decomposeRow env c =
+fun pieceEnter p =
+ case p of
+ NameR n => NameR (n + 1)
+ | RowR n => RowR (n + 1)
+ | _ => p
+
+fun enter denv =
+ PM.foldli (fn (p, pset, denv') =>
+ PM.insert (denv', pieceEnter p, PS.map pieceEnter pset))
+ PM.empty denv
+
+fun prove1 denv (p1, p2) =
+ case (p1, p2) of
+ (NameC s1, NameC s2) => s1 <> s2
+ | _ =>
+ case PM.find (denv, p1) of
+ NONE => false
+ | SOME pset => PS.member (pset, p2)
+
+fun decomposeRow (env, denv) c =
let
- fun decomposeName (c, acc) =
- case #1 (hnormCon env c) of
- CName s => Piece (NameC s) :: acc
- | CRel n => Piece (NameR n) :: acc
- | CNamed n => Piece (NameN n) :: acc
- | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
- | _ => Unknown c :: acc
-
- fun decomposeRow (c, acc) =
- case #1 (hnormCon env c) of
- CRecord (_, xcs) => foldl (fn ((x, _), acc) => decomposeName (x, acc)) acc xcs
- | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, acc))
- | CRel n => Piece (RowR n) :: acc
- | CNamed n => Piece (RowN n) :: acc
- | CModProj (m1, ms, x) => Piece (RowM (m1, ms, x)) :: acc
- | _ => Unknown c :: acc
+ fun decomposeName (c, (acc, gs)) =
+ let
+ val (cAll as (c, _), gs') = hnormCon (env, denv) c
+
+ val acc = case c of
+ CName s => Piece (NameC s) :: acc
+ | CRel n => Piece (NameR n) :: acc
+ | CNamed n => Piece (NameN n) :: acc
+ | CModProj (m1, ms, x) => Piece (NameM (m1, ms, x)) :: acc
+ | _ => Unknown cAll :: acc
+ in
+ (acc, gs' @ gs)
+ end
+
+ fun decomposeRow (c, (acc, gs)) =
+ let
+ val (cAll as (c, _), gs') = hnormCon (env, denv) c
+ val gs = gs' @ gs
+ in
+ case c of
+ CRecord (_, xcs) => foldl (fn ((x, _), acc_gs) => decomposeName (x, acc_gs)) (acc, gs) xcs
+ | CConcat (c1, c2) => decomposeRow (c1, decomposeRow (c2, (acc, gs)))
+ | CRel n => (Piece (RowR n) :: acc, gs)
+ | CNamed n => (Piece (RowN n) :: acc, gs)
+ | CModProj (m1, ms, x) => (Piece (RowM (m1, ms, x)) :: acc, gs)
+ | _ => (Unknown cAll :: acc, gs)
+ end
in
- decomposeRow (c, [])
+ decomposeRow (c, ([], []))
end
-fun assert env denv (c1, c2) =
+and assert env denv (c1, c2) =
let
- val ps1 = decomposeRow env c1
- val ps2 = decomposeRow env c2
+ val (ps1, gs1) = decomposeRow (env, denv) c1
+ val (ps2, gs2) = decomposeRow (env, denv) c2
val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
val ps1 = unUnknown ps1
@@ -167,6 +199,9 @@ fun assert env denv (c1, c2) =
fun assertPiece ps (p, denv) =
let
val pset = Option.getOpt (PM.find (denv, p), PS.empty)
+ val ps = case p of
+ NameC _ => List.filter (fn NameC _ => false | _ => true) ps
+ | _ => ps
val pset = PS.addList (pset, ps)
in
PM.insert (denv, p, pset)
@@ -174,38 +209,19 @@ fun assert env denv (c1, c2) =
val denv = foldl (assertPiece ps2) denv ps1
in
- foldl (assertPiece ps1) denv ps2
+ (foldl (assertPiece ps1) denv ps2, gs1 @ gs2)
end
-fun pieceEnter p =
- case p of
- NameR n => NameR (n + 1)
- | RowR n => RowR (n + 1)
- | _ => p
-
-fun enter denv =
- PM.foldli (fn (p, pset, denv') =>
- PM.insert (denv', pieceEnter p, PS.map pieceEnter pset))
- PM.empty denv
-
-fun prove1 denv (p1, p2) =
- case (p1, p2) of
- (NameC s1, NameC s2) => s1 <> s2
- | _ =>
- case PM.find (denv, p1) of
- NONE => false
- | SOME pset => PS.member (pset, p2)
-
-fun prove env denv (c1, c2, loc) =
+and prove env denv (c1, c2, loc) =
let
- val ps1 = decomposeRow env c1
- val ps2 = decomposeRow env c2
+ val (ps1, gs1) = decomposeRow (env, denv) c1
+ val (ps2, gs2) = decomposeRow (env, denv) c2
val hasUnknown = List.exists (fn Unknown _ => true | _ => false)
val unUnknown = List.mapPartial (fn Unknown _ => NONE | Piece p => SOME p)
in
if hasUnknown ps1 orelse hasUnknown ps2 then
- [(c1, c2)]
+ [(loc, env, denv, c1, c2)]
else
let
val ps1 = unUnknown ps1
@@ -222,9 +238,26 @@ fun prove env denv (c1, c2, loc) =
if prove1 denv (p1, p2) then
rem
else
- (pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
- [] ps1
+ (loc, env, denv, pieceToRow (p1, loc), pieceToRow (p2, loc)) :: rem) rem ps2)
+ (gs1 @ gs2) ps1
end
end
+and hnormCon (env, denv) c =
+ let
+ val cAll as (c, loc) = ElabOps.hnormCon env c
+
+ fun doDisj (c1, c2, c) =
+ let
+ val (c, gs) = hnormCon (env, denv) c
+ in
+ (c, prove env denv (c1, c2, loc) @ gs)
+ end
+ in
+ case c of
+ CDisjoint cs => doDisj cs
+ | TDisjoint cs => doDisj cs
+ | _ => (cAll, [])
+ end
+
end
diff --git a/src/elaborate.sml b/src/elaborate.sml
index 216d483f..7a8c06a8 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -251,13 +251,13 @@ fun elabCon (env, denv) (c, loc) =
val ku1 = kunif loc
val ku2 = kunif loc
- val denv' = D.assert env denv (c1', c2')
- val (c', k, gs3) = elabCon (env, denv') c
+ val (denv', gs3) = D.assert env denv (c1', c2')
+ val (c', k, gs4) = elabCon (env, denv') c
in
checkKind env c1' k1 (L'.KRecord ku1, loc);
checkKind env c2' k2 (L'.KRecord ku2, loc);
- ((L'.TDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3)
+ ((L'.TDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3 @ gs4)
end
| L.TRecord c =>
let
@@ -330,13 +330,13 @@ fun elabCon (env, denv) (c, loc) =
val ku1 = kunif loc
val ku2 = kunif loc
- val denv' = D.assert env denv (c1', c2')
- val (c', k, gs3) = elabCon (env, denv') c
+ val (denv', gs3) = D.assert env denv (c1', c2')
+ val (c', k, gs4) = elabCon (env, denv') c
in
checkKind env c1' k1 (L'.KRecord ku1, loc);
checkKind env c2' k2 (L'.KRecord ku2, loc);
- ((L'.CDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3)
+ ((L'.CDisjoint (c1', c2', c'), loc), k, gs1 @ gs2 @ gs3 @ gs4)
end
| L.CName s =>
@@ -369,8 +369,7 @@ fun elabCon (env, denv) (c, loc) =
let
val r2 = (L'.CRecord (k, [xc']), loc)
in
- map (fn cs => (loc, env, denv, cs)) (D.prove env denv (r1, r2, loc))
- @ ds
+ D.prove env denv (r1, r2, loc) @ ds
end)
ds rest
in
@@ -389,7 +388,7 @@ fun elabCon (env, denv) (c, loc) =
checkKind env c1' k1 k;
checkKind env c2' k2 k;
((L'.CConcat (c1', c2'), loc), k,
- map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1', c2', loc)) @ gs1 @ gs2)
+ D.prove env denv (c1', c2', loc) @ gs1 @ gs2)
end
| L.CFold =>
let
@@ -545,23 +544,7 @@ fun kindof env (c, loc) =
| L'.CError => kerror
| L'.CUnif (_, k, _, _) => k
-fun hnormCon (env, denv) c =
- let
- val cAll as (c, loc) = ElabOps.hnormCon env c
-
- fun doDisj (c1, c2, c) =
- let
- val (c, gs) = hnormCon (env, denv) c
- in
- (c,
- map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1, c2, loc)) @ gs)
- end
- in
- case c of
- L'.CDisjoint cs => doDisj cs
- | L'.TDisjoint cs => doDisj cs
- | _ => (cAll, [])
- end
+val hnormCon = D.hnormCon
fun unifyRecordCons (env, denv) (c1, c2) =
let
@@ -703,9 +686,9 @@ and unifyCons' (env, denv) c1 c2 =
let
val (c1, gs1) = hnormCon (env, denv) c1
val (c2, gs2) = hnormCon (env, denv) c2
+ val gs3 = unifyCons'' (env, denv) c1 c2
in
- unifyCons'' (env, denv) c1 c2;
- gs1 @ gs2
+ gs1 @ gs2 @ gs3
end
and unifyCons'' (env, denv) (c1All as (c1, _)) (c2All as (c2, _)) =
@@ -1040,13 +1023,13 @@ fun elabExp (env, denv) (e, loc) =
val ku1 = kunif loc
val ku2 = kunif loc
- val denv' = D.assert env denv (c1', c2')
- val (e', t, gs3) = elabExp (env, denv') e
+ val (denv', gs3) = D.assert env denv (c1', c2')
+ val (e', t, gs4) = elabExp (env, denv') e
in
checkKind env c1' k1 (L'.KRecord ku1, loc);
checkKind env c2' k2 (L'.KRecord ku2, loc);
- (e', (L'.TDisjoint (c1', c2', t), loc), gs1 @ gs2 @ gs3)
+ (e', (L'.TDisjoint (c1', c2', t), loc), gs1 @ gs2 @ gs3 @ gs4)
end
| L.ERecord xes =>
@@ -1075,8 +1058,7 @@ fun elabExp (env, denv) (e, loc) =
val xc' = (x', t')
val r2 = (L'.CRecord (k, [xc']), loc)
in
- map (fn cs => (loc, env, denv, cs)) (D.prove env denv (r1, r2, loc))
- @ gs
+ D.prove env denv (r1, r2, loc) @ gs
end)
gs rest
in
@@ -1100,9 +1082,7 @@ fun elabExp (env, denv) (e, loc) =
val gs3 =
checkCon (env, denv) e' et
(L'.TRecord (L'.CConcat (first, rest), loc), loc)
- val gs4 =
- map (fn cs => (loc, env, denv, cs))
- (D.prove env denv (first, rest, loc))
+ val gs4 = D.prove env denv (first, rest, loc)
in
((L'.EField (e', c', {field = ft, rest = rest}), loc), ft, gs1 @ gs2 @ gs3 @ gs4)
end
@@ -1287,12 +1267,12 @@ fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
val (c1', k1, gs1) = elabCon (env, denv) c1
val (c2', k2, gs2) = elabCon (env, denv) c2
- val denv = D.assert env denv (c1', c2')
+ val (denv, gs3) = D.assert env denv (c1', c2')
in
checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
- ([(L'.SgiConstraint (c1', c2'), loc)], (env, denv, gs1 @ gs2))
+ ([(L'.SgiConstraint (c1', c2'), loc)], (env, denv, gs1 @ gs2 @ gs3))
end
and elabSgn (env, denv) (sgn, loc) =
@@ -1484,7 +1464,16 @@ fun dopenConstraints (loc, env, denv) {str, strs} =
val denv = case cso of
NONE => (strError env (UnboundStr (loc, str));
denv)
- | SOME cs => foldl (fn ((c1, c2), denv) => D.assert env denv (c1, c2)) denv cs
+ | SOME cs => foldl (fn ((c1, c2), denv) =>
+ let
+ val (denv, gs) = D.assert env denv (c1, c2)
+ in
+ case gs of
+ [] => ()
+ | _ => raise Fail "dopenConstraints: Sub-constraints remain";
+
+ denv
+ end) denv cs
in
denv
end
@@ -1500,7 +1489,10 @@ fun sgiOfDecl (d, loc) =
fun sgiBindsD (env, denv) (sgi, _) =
case sgi of
- L'.SgiConstraint (c1, c2) => D.assert env denv (c1, c2)
+ L'.SgiConstraint (c1, c2) =>
+ (case D.assert env denv (c1, c2) of
+ (denv, []) => denv
+ | _ => raise Fail "sgiBindsD: Sub-constraints remain")
| _ => denv
fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
@@ -1634,7 +1626,15 @@ fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
case sgi1 of
L'.SgiConstraint (c1, d1) =>
if consEq (env, denv) (c1, c2) andalso consEq (env, denv) (d1, d2) then
- SOME (env, D.assert env denv (c2, d2))
+ let
+ val (denv, gs) = D.assert env denv (c2, d2)
+ in
+ case gs of
+ [] => ()
+ | _ => raise Fail "subSgn: Sub-constraints remain";
+
+ SOME (env, denv)
+ end
else
NONE
| _ => NONE)
@@ -1793,14 +1793,14 @@ fun elabDecl ((d, loc), (env, denv, gs)) =
let
val (c1', k1, gs1) = elabCon (env, denv) c1
val (c2', k2, gs2) = elabCon (env, denv) c2
- val gs3 = map (fn cs => (loc, env, denv, cs)) (D.prove env denv (c1', c2', loc))
+ val gs3 = D.prove env denv (c1', c2', loc)
- val denv' = D.assert env denv (c1', c2')
+ val (denv', gs4) = D.assert env denv (c1', c2')
in
checkKind env c1' k1 (L'.KRecord (kunif loc), loc);
checkKind env c2' k2 (L'.KRecord (kunif loc), loc);
- ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3))
+ ([(L'.DConstraint (c1', c2'), loc)], (env, denv', gs1 @ gs2 @ gs3 @ gs4))
end
| L.DOpenConstraints (m, ms) =>
@@ -1982,13 +1982,15 @@ fun elabFile basis env file =
if ErrorMsg.anyErrors () then
()
else
- app (fn (loc, env, denv, (c1, c2)) =>
+ app (fn (loc, env, denv, c1, c2) =>
case D.prove env denv (c1, c2, loc) of
[] => ()
| _ =>
(ErrorMsg.errorAt loc "Couldn't prove field name disjointness";
eprefaces' [("Con 1", p_con env c1),
- ("Con 2", p_con env c2)])) gs;
+ ("Con 2", p_con env c2),
+ ("Hnormed 1", p_con env (ElabOps.hnormCon env c1)),
+ ("Hnormed 2", p_con env (ElabOps.hnormCon env c2))])) gs;
(L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan) :: ds @ file
end
diff --git a/tests/cfold_disj.lac b/tests/cfold_disj.lac
new file mode 100644
index 00000000..e0a19484
--- /dev/null
+++ b/tests/cfold_disj.lac
@@ -0,0 +1,5 @@
+con id = fold (fn nm => fn t :: Type => fn acc => [nm] ~ acc => [nm = t] ++ acc) []
+
+con idT = id [D = int, E = float]
+
+val idV = fn x : $idT => x.E