summary refs log tree commit diff
path: root/src/union_find_fn.sml
blob: e6f8d9bf17377ff2f6df42ad6c77b8e0d9d225dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
(* UnionFindFn: union-find over keys drawn from an ORD_KEY structure.
 *
 * The signature is opaque and exports only:
 *   empty   - the structure containing no recorded unions;
 *   union   - merge the classes of two keys, returning a new structure;
 *   union'  - the same operation with the arguments arranged as
 *             ((x, y), uf), e.g. for use with a fold;
 *   classes - the recorded equivalence classes, one key list per class.
 *
 * Keys that were never passed to union are not listed by classes. *)
functor UnionFindFn(K : ORD_KEY) :> sig
    type unionFind
    val empty : unionFind
    val union : unionFind * K.ord_key * K.ord_key -> unionFind
    val union' : (K.ord_key * K.ord_key) * unionFind -> unionFind
    val classes : unionFind -> K.ord_key list list
end = struct

structure M = BinaryMapFn(K)
structure S = BinarySetFn(K)

(* What the forest stores for a key: a representative carries the member
 * set of its class, every other key carries a link towards one. *)
datatype entry =
         Set of S.set
       | Pointer of K.ord_key

(* Pair of (forest, class table).  The forest sits in a ref so that lookups
 * can shorten pointer chains in place; the class table maps each
 * representative to the members of its class. *)
type unionFind = entry M.map ref * S.set M.map

val empty : unionFind = (ref M.empty, M.empty)

(* findPair (forest, key) returns (members, representative) for key's class.
 * A key with no forest entry is its own representative in a singleton
 * class.  Every key reached through a Pointer is re-pointed directly at the
 * representative on the way back (path compression). *)
fun findPair (forest, key) =
    let
        fun compress (members, root) =
            (forest := M.insert (!forest, key, Pointer root);
             (members, root))
    in
        case M.find (!forest, key) of
            SOME (Pointer up) => compress (findPair (forest, up))
          | SOME (Set members) => (members, key)
          | NONE => (S.singleton key, key)
    end

(* Members of key's class as a list.  Not exported by the signature. *)
fun find ((forest, _), key) =
    let
        val (members, _) = findPair (forest, key)
    in
        S.listItems members
    end

(* One list of members per entry of the class table. *)
fun classes (_, byRep) = map S.listItems (M.listItems byRep)

(* union (uf, x, y) merges the classes of x and y.  x's representative
 * becomes the representative of the merged class.  The result holds a
 * freshly allocated ref; the argument's ref is only touched by the path
 * compression performed inside findPair. *)
fun union ((forest, byRep), x, y) =
    let
        val (xMembers, xRoot) = findPair (forest, x)
        val (yMembers, yRoot) = findPair (forest, y)
        val merged = S.union (xMembers, yMembers)

        (* Link y's representative under x's first; the second insert then
         * wins when both representatives are the same key. *)
        val linked = M.insert (!forest, yRoot, Pointer xRoot)
        val forest' = M.insert (linked, xRoot, Set merged)

        (* M.remove is only called once the key is known to be present. *)
        val pruned =
            if isSome (M.find (byRep, yRoot))
            then #1 (M.remove (byRep, yRoot))
            else byRep
    in
        (ref forest', M.insert (pruned, xRoot, merged))
    end

fun union' ((x, y), state) = union (state, x, y)

end