summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/elab_env.sml130
-rw-r--r--tests/type_class.ur13
2 files changed, 111 insertions, 32 deletions
diff --git a/src/elab_env.sml b/src/elab_env.sml
index 1768ce7d..9f64a8c2 100644
--- a/src/elab_env.sml
+++ b/src/elab_env.sml
@@ -233,11 +233,13 @@ end
structure KM = BinaryMapFn(KK)
-type class = ((class_name * class_key) list * exp) KM.map
-val empty_class = KM.empty
+type class = {ground : ((class_name * class_key) list * exp) KM.map,
+ inclusions : exp CM.map}
+val empty_class = {ground = KM.empty,
+ inclusions = CM.empty}
fun printClasses cs = (print "Classes:\n";
- CM.appi (fn (cn, km) =>
+ CM.appi (fn (cn, {ground = km, ...} : class) =>
(print (cn2s cn ^ ":");
KM.appi (fn (ck, _) => print (" " ^ ckn2s ck)) km;
print "\n")) cs)
@@ -361,9 +363,10 @@ fun pushCRel (env : env) x k =
constructors = #constructors env,
classes = CM.map (fn class =>
- KM.foldli (fn (ck, e, km) =>
- KM.insert (km, liftClassKey ck, e))
- KM.empty class)
+ {ground = KM.foldli (fn (ck, e, km) =>
+ KM.insert (km, liftClassKey ck, e))
+ KM.empty (#ground class),
+ inclusions = #inclusions class})
(#classes env),
renameE = SM.map (fn Rel' (n, c) => Rel' (n, lift c)
@@ -482,7 +485,7 @@ fun pushClass (env : env) n =
datatypes = #datatypes env,
constructors = #constructors env,
- classes = CM.insert (#classes env, ClNamed n, KM.empty),
+ classes = CM.insert (#classes env, ClNamed n, empty_class),
renameE = #renameE env,
relE = #relE env,
@@ -565,12 +568,36 @@ fun resolveClass (env : env) c =
| SOME class =>
let
val loc = #2 c
-
+
+ fun tryIncs () =
+ let
+ fun tryIncs fs =
+ case fs of
+ [] => NONE
+ | (f', e') :: fs =>
+ case doPair (f', x) of
+ NONE => tryIncs fs
+ | SOME e =>
+ let
+ val e' = (ECApp (e', class_key_out loc x), loc)
+ val e' = (EApp (e', e), loc)
+ in
+ SOME e'
+ end
+ in
+ tryIncs (CM.listItemsi (#inclusions class))
+ end
+
fun tryRules (k, args) =
let
val len = length args
+
+ fun tryNext () =
+ case k of
+ CkApp (k1, k2) => tryRules (k1, k2 :: args)
+ | _ => tryIncs ()
in
- case KM.find (class, (k, length args)) of
+ case KM.find (#ground class, (k, length args)) of
SOME (cs, e) =>
let
val es = map (fn (cn, ck) =>
@@ -585,7 +612,7 @@ fun resolveClass (env : env) c =
end) cs
in
if List.exists (not o Option.isSome) es then
- NONE
+ tryNext ()
else
let
val e = foldl (fn (arg, e) => (ECApp (e, class_key_out loc arg), loc))
@@ -596,10 +623,7 @@ fun resolveClass (env : env) c =
SOME e
end
end
- | NONE =>
- case k of
- CkApp (k1, k2) => tryRules (k1, k2 :: args)
- | _ => NONE
+ | NONE => tryNext ()
end
in
tryRules (x, [])
@@ -615,7 +639,9 @@ fun pushERel (env : env) x t =
val renameE = SM.map (fn Rel' (n, t) => Rel' (n+1, t)
| x => x) (#renameE env)
- val classes = CM.map (KM.map (fn (ps, e) => (ps, liftExp e))) (#classes env)
+ val classes = CM.map (fn class =>
+ {ground = KM.map (fn (ps, e) => (ps, liftExp e)) (#ground class),
+ inclusions = #inclusions class}) (#classes env)
val classes = case class_pair_in t of
NONE => classes
| SOME (f, x) =>
@@ -623,7 +649,8 @@ fun pushERel (env : env) x t =
NONE => classes
| SOME class =>
let
- val class = KM.insert (class, (x, 0), ([], (ERel 0, #2 t)))
+ val class = {ground = KM.insert (#ground class, (x, 0), ([], (ERel 0, #2 t))),
+ inclusions = #inclusions class}
in
CM.insert (classes, f, class)
end
@@ -655,6 +682,10 @@ fun lookupERel (env : env) n =
(List.nth (#relE env, n))
handle Subscript => raise UnboundRel n
+datatype rule =
+ Normal of int * (class_name * class_key) list * class_key
+ | Inclusion of class_name
+
fun rule_in c =
let
fun quantifiers (c, nvars) =
@@ -675,7 +706,7 @@ fun rule_in c =
let
fun dearg (ck, i) =
if i >= nvars then
- SOME (nvars, hyps, (cn, ck))
+ SOME (cn, Normal (nvars, hyps, ck))
else case ck of
CkApp (ck, CkRel i') =>
if i' = i then
@@ -690,7 +721,13 @@ fun rule_in c =
clauses (c, [])
end
in
- quantifiers (c, 0)
+ case #1 c of
+ TCFun (_, _, _, (TFun ((CApp (f1, (CRel 0, _)), _),
+ (CApp (f2, (CRel 0, _)), _)), _)) =>
+ (case (class_name_in f1, class_name_in f2) of
+ (SOME f1, SOME f2) => SOME (f2, Inclusion f1)
+ | _ => NONE)
+ | _ => quantifiers (c, 0)
end
fun pushENamedAs (env : env) x n t =
@@ -698,12 +735,21 @@ fun pushENamedAs (env : env) x n t =
val classes = #classes env
val classes = case rule_in t of
NONE => classes
- | SOME (nvars, hyps, (f, x)) =>
+ | SOME (f, rule) =>
case CM.find (classes, f) of
NONE => classes
| SOME class =>
let
- val class = KM.insert (class, (x, nvars), (hyps, (ENamed n, #2 t)))
+ val e = (ENamed n, #2 t)
+
+ val class =
+ case rule of
+ Normal (nvars, hyps, x) =>
+ {ground = KM.insert (#ground class, (x, nvars), (hyps, e)),
+ inclusions = #inclusions class}
+ | Inclusion f' =>
+ {ground = #ground class,
+ inclusions = CM.insert (#inclusions class, f', e)}
in
CM.insert (classes, f, class)
end
@@ -1023,12 +1069,10 @@ fun enrichClasses env classes (m1, ms) sgn =
| SgiVal (x, n, c) =>
(case rule_in c of
NONE => default ()
- | SOME (nvars, hyps, (cn, a)) =>
+ | SOME (cn, rule) =>
let
+ val globalizeN = sgnS_class_name (m1, ms, fmap)
val globalize = sgnS_class_key (m1, ms, fmap)
- val ck = globalize a
- val hyps = map (fn (n, k) => (sgnS_class_name (m1, ms, fmap) n,
- globalize k)) hyps
fun unravel c =
case c of
@@ -1055,10 +1099,22 @@ fun enrichClasses env classes (m1, ms) sgn =
NONE => classes
| SOME class =>
let
- val class = KM.insert (class, (ck, nvars),
- (hyps,
- (EModProj (m1, ms, x),
- #2 sgn)))
+ val e = (EModProj (m1, ms, x),
+ #2 sgn)
+
+ val class =
+ case rule of
+ Normal (nvars, hyps, a) =>
+ {ground =
+ KM.insert (#ground class, (globalize a, nvars),
+ (map (fn (n, k) =>
+ (globalizeN n,
+ globalize k)) hyps, e)),
+ inclusions = #inclusions class}
+ | Inclusion f' =>
+ {ground = #ground class,
+ inclusions = CM.insert (#inclusions class,
+ globalizeN f', e)}
in
CM.insert (classes, cn, class)
end
@@ -1077,9 +1133,21 @@ fun enrichClasses env classes (m1, ms) sgn =
NONE => classes
| SOME class =>
let
- val class = KM.insert (class, (ck, nvars),
- (hyps,
- (EModProj (m1, ms, x), #2 sgn)))
+ val e = (EModProj (m1, ms, x), #2 sgn)
+
+ val class =
+ case rule of
+ Normal (nvars, hyps, a) =>
+ {ground =
+ KM.insert (#ground class, (globalize a, nvars),
+ (map (fn (n, k) =>
+ (globalizeN n,
+ globalize k)) hyps, e)),
+ inclusions = #inclusions class}
+ | Inclusion f' =>
+ {ground = #ground class,
+ inclusions = CM.insert (#inclusions class,
+ globalizeN f', e)}
in
CM.insert (classes, cn, class)
end
diff --git a/tests/type_class.ur b/tests/type_class.ur
index 42cbe82f..a41ccdc8 100644
--- a/tests/type_class.ur
+++ b/tests/type_class.ur
@@ -9,6 +9,11 @@ structure M : sig
val option_default : t ::: Type -> default t -> default (option t)
val pair_default : a ::: Type -> b ::: Type -> default a -> default b -> default (pair a b)
+
+ class awesome
+ val awesome_default : t ::: Type -> awesome t -> default t
+
+ val float_awesome : awesome float
end = struct
class default t = t
fun get (t ::: Type) (x : t) = x
@@ -18,6 +23,11 @@ end = struct
fun option_default (t ::: Type) (x : t) = Some x
fun pair_default (a ::: Type) (b ::: Type) (x : a) (y : b) = Pair (x, y)
+
+ class awesome t = t
+ fun awesome_default (t ::: Type) (x : t) = x
+
+ val float_awesome = 1.23
end
open M
@@ -27,6 +37,7 @@ val hi : string = default
val zero : int = default
val some_zero : option int = default
val hi_zero : pair string int = default
+val ott : float = default
fun frob (t ::: Type) (_ : default t) : t = default
val hi_again : string = frob
@@ -44,5 +55,5 @@ fun show_pair (a ::: Type) (b ::: Type) (_ : show a) (_ : show b) : show (pair a
Pair (y, z) => "(" ^ show y ^ "," ^ show z ^ ")")
fun main () : transaction page = return <xml><body>
- {[hi_again]}, {[zero_again]}, {[some_zero]}, {[hi_zero]}
+ {[hi_again]}, {[zero_again]}, {[some_zero]}, {[hi_zero]}, {[ott]}
</body></xml>