aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-09 13:21:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 13:24:25 -0700
commit7aa1f1f57fa6851529c471da78b2e91e0aaab5c3 (patch)
treea3ea20f9014a9d3c5f3664e1a708b4acbee186ff /tensorflow/contrib/graph_editor
parent32e0db3e7a085ff2473a53b9401686544b4442aa (diff)
Adds a within_ops_fn parameter to get_forward_walk_ops and get_backward_walk_ops
that allows setting a condition on ops that are within or not within. Also adds tests for these methods that were missing. PiperOrigin-RevId: 192176693
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/select.py26
-rw-r--r--tensorflow/contrib/graph_editor/tests/select_test.py155
2 files changed, 172 insertions, 9 deletions
diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py
index 3ea6ff4d61..d700e6e1a7 100644
--- a/tensorflow/contrib/graph_editor/select.py
+++ b/tensorflow/contrib/graph_editor/select.py
@@ -383,6 +383,7 @@ def get_within_boundary_ops(ops,
def get_forward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_outputs=None):
"""Do a forward graph walk and return all the visited ops.
@@ -395,6 +396,9 @@ def get_forward_walk_ops(seed_ops,
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_outputs: a `util.ControlOutputs` instance or None.
If not `None`, it will be used while walking the graph forward.
@@ -423,7 +427,8 @@ def get_forward_walk_ops(seed_ops,
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
@@ -450,6 +455,7 @@ def get_forward_walk_ops(seed_ops,
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_inputs=False):
"""Do a backward graph walk and return all the visited ops.
@@ -462,6 +468,9 @@ def get_backward_walk_ops(seed_ops,
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_inputs: if True, control inputs will be used while moving backward.
Returns:
@@ -488,7 +497,8 @@ def get_backward_walk_ops(seed_ops,
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
@@ -516,6 +526,7 @@ def get_walks_intersection_ops(forward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
@@ -535,6 +546,9 @@ def get_walks_intersection_ops(forward_seed_ops,
within_ops: an iterable of tf.Operation within which the search is
restricted. If within_ops is None, the search is performed within
the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
@@ -555,11 +569,13 @@ def get_walks_intersection_ops(forward_seed_ops,
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return [op for op in forward_ops if op in backward_ops]
@@ -569,6 +585,7 @@ def get_walks_union_ops(forward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
@@ -587,6 +604,9 @@ def get_walks_union_ops(forward_seed_ops,
resulting set.
within_ops: restrict the search within those operations. If within_ops is
None, the search is done within the whole graph.
+ within_ops_fn: if provided, a function on ops that should return True iff
+ the op is within the graph traversal. This can be used along within_ops,
+ in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
@@ -607,11 +627,13 @@ def get_walks_union_ops(forward_seed_ops,
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return util.concatenate_unique(forward_ops, backward_ops)
diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py
index 82f999637d..d12c6d3cbd 100644
--- a/tensorflow/contrib/graph_editor/tests/select_test.py
+++ b/tensorflow/contrib/graph_editor/tests/select_test.py
@@ -77,12 +77,10 @@ class SelectTest(test.TestCase):
"""Test for ge.get_ops_ios."""
control_outputs = ge.util.ControlOutputs(self.graph)
self.assertEqual(
- len(ge.get_ops_ios(
- self.h.op, control_ios=control_outputs)), 3)
+ len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3)
self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2)
self.assertEqual(
- len(ge.get_ops_ios(
- self.c.op, control_ios=control_outputs)), 6)
+ len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6)
self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5)
def test_compute_boundary_ts_0(self):
@@ -135,16 +133,49 @@ class SelectTest(test.TestCase):
ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
self.assertEqual(len(ops), 2)
+ ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.a.op in ops)
+ self.assertTrue(self.c.op in ops)
+ self.assertTrue(self.f.op in ops)
+
+ within_ops = [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops=within_ops)
+ self.assertEqual(len(ops), 0)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 0)
+
def test_get_walks_union(self):
"""Test for ge.get_walks_union_ops."""
ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
self.assertEqual(len(ops), 6)
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 8)
+
+ within_ops = [self.a.op, self.c.op, self.d.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops=within_ops)
+ self.assertEqual(len(ops), 4)
+ self.assertTrue(self.b.op not in ops)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.b.op not in ops)
+ self.assertTrue(self.d.op not in ops)
+
def test_select_ops(self):
parameters = (
(("^foo/",), 7),
(("^foo/bar/",), 4),
- (("^foo/bar/", "a"), 5),)
+ (("^foo/bar/", "a"), 5),
+ )
for param, length in parameters:
ops = ge.select_ops(*param, graph=self.graph)
self.assertEqual(len(ops), length)
@@ -152,7 +183,8 @@ class SelectTest(test.TestCase):
def test_select_ts(self):
parameters = (
(".*:0", 8),
- (r".*/bar/\w+:0", 4),)
+ (r".*/bar/\w+:0", 4),
+ )
for regex, length in parameters:
ts = ge.select_ts(regex, graph=self.graph)
self.assertEqual(len(ts), length)
@@ -160,12 +192,121 @@ class SelectTest(test.TestCase):
def test_select_ops_and_ts(self):
parameters = (
(("^foo/.*",), 7, 0),
- (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
+ (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),
+ )
for param, l0, l1 in parameters:
ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
self.assertEqual(len(ops), l0)
self.assertEqual(len(ts), l1)
+ def test_forward_walk_ops(self):
+ seed_ops = [self.a.op, self.d.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.e.op.
+ within_ops_fn = lambda op: op not in (self.e.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # No b.op since it's an independent source node.
+ # No g.op from within_ops.
+ # No e.op from within_ops_fn.
+ # No h.op from stop_at_ts and within_ops.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(
+ set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op]))
+
+ # Also no a.op and d.op when inclusive=False
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.f.op]))
+
+ # Not using within_ops_fn adds e.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op]))
+
+ # Not using stop_at_ts adds back h.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops, inclusive=False, within_ops=within_ops)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op]))
+
+ # Starting just form a (the tensor, not op) omits a, b, d.
+ ops = ge.select.get_forward_walk_ops([self.a], inclusive=True)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op,
+ self.h.op]))
+
+ def test_backward_walk_ops(self):
+ seed_ops = [self.h.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.c.op.
+ within_ops_fn = lambda op: op not in (self.c.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # Backward walk only includes h since we stop at f and g is not within.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.h.op]))
+
+ # If we do inclusive=False, the result is empty.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set())
+
+ # Removing stop_at_fs adds f.op, d.op.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
+
+ # Not using within_ops_fn adds back ops for a, b, c.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops, inclusive=True, within_ops=within_ops)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
+ ]))
+
+ # Vanially backward search via self.h.op includes everything excpet e.op.
+ ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
+ self.h.op
+ ]))
+
if __name__ == "__main__":
test.main()