diff options
author | 2016-12-21 02:51:43 -0800 | |
---|---|---|
committer | 2016-12-21 03:06:58 -0800 | |
commit | 2145b44339642796dc382153d26b434c2cc18559 (patch) | |
tree | c1c09e2e35366bfb9c5b074029f2bed9d8eb1371 | |
parent | f7048fc88b44102e745def51a8b2610c4aacb139 (diff) |
Fix two bugs int the Graph editor;
- compute_boundary_ts was sometimes adding spurious inputs
- sgv.consumers was returning op inside the subgraph
Change: 142645987
-rw-r--r-- | tensorflow/contrib/graph_editor/select.py | 30 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/subgraph.py | 13 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/edit_test.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/select_test.py | 13 |
4 files changed, 43 insertions, 17 deletions
diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index 0e2914cb0d..1125b90a9e 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -272,7 +272,7 @@ def get_ops_ios(ops, control_inputs=False, control_outputs=None, return res -def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True): +def compute_boundary_ts(ops): """Compute the tensors at the boundary of a set of ops. This function looks at all the tensors connected to the given ops (in/out) @@ -281,17 +281,18 @@ def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True): 2) output tensors: tensors whose consumer operations are not in ops 3) inside tensors: tensors which are neither input nor output tensors. + Note that a tensor can be both an inside tensor and an output tensor if it is + consumed by operations both outside and inside of `ops`. + Args: ops: an object convertible to a list of tf.Operation. - ambiguous_ts_are_outputs: a tensor can have consumers both inside and - outside ops. Such tensors are treated as outside tensor if - ambiguous_ts_are_outputs is True, otherwise they are treated as - inside tensor. Returns: A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where: `outside_input_ts` is a Python list of input tensors; `outside_output_ts` is a python list of output tensors; `inside_ts` is a python list of inside tensors. + Since a tensor can be both an inside tensor and an output tensor, + `outside_output_ts` and `inside_ts` might intersect. Raises: TypeError: if ops cannot be converted to a list of tf.Operation. """ @@ -301,22 +302,25 @@ def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True): output_ts_set = frozenset(output_ts) ops_set = frozenset(ops) - # fill in inside + # Compute inside tensors. inside_ts = [] + only_inside_ts = [] for t in input_ts: - # is also output? + # Skip if the input tensor is not also an output tensor. if t not in output_ts_set: continue - # is ambiguous_ts_are_outputs is True, don't add to inside if ambiguous - if ambiguous_ts_are_outputs: - consumers = frozenset(t.consumers()) - if consumers - ops_set: - continue + # Mark as "inside". inside_ts.append(t) + # Mark as "only inside" if the tensor is not both inside and output. + consumers = frozenset(t.consumers()) + if consumers - ops_set: + continue + only_inside_ts.append(t) inside_ts_set = frozenset(inside_ts) + only_inside_ts_set = frozenset(only_inside_ts) + outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set] outside_input_ts = [t for t in input_ts if t not in inside_ts_set] - outside_output_ts = [t for t in output_ts if t not in inside_ts_set] return outside_input_ts, outside_output_ts, inside_ts diff --git a/tensorflow/contrib/graph_editor/subgraph.py b/tensorflow/contrib/graph_editor/subgraph.py index 00a755c79f..bfeb3ae23a 100644 --- a/tensorflow/contrib/graph_editor/subgraph.py +++ b/tensorflow/contrib/graph_editor/subgraph.py @@ -561,10 +561,19 @@ class SubGraphView(object): return subgraph_id def consumers(self): - """Return a Python set of all the consumers of this subgraph view.""" + """Return a Python set of all the consumers of this subgraph view. + + A consumer of a subgraph view is a tf.Operation which is a consumer + of one of the output tensors and is not in the subgraph. + + Returns: + A list of `tf.Operation` which are the consumers of this subgraph view. + """ + ops_set = frozenset(self._ops) res = [] for output in self._output_ts: - util.concatenate_unique(res, output.consumers()) + consumers = [op for op in output.consumers() if op not in ops_set] + util.concatenate_unique(res, consumers) return res diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py index 968a73c812..a3330beee8 100644 --- a/tensorflow/contrib/graph_editor/tests/edit_test.py +++ b/tensorflow/contrib/graph_editor/tests/edit_test.py @@ -49,10 +49,10 @@ class EditTest(tf.test.TestCase): """Test for ge.detach.""" sgv = ge.sgv(self.c.op, self.a.op) control_outputs = ge.util.ControlOutputs(self.graph) - ge.detach(sgv, control_inputs=control_outputs) + ge.detach(sgv, control_ios=control_outputs) # make sure the detached graph is as expected. self.assertTrue(ge.matcher("^foo/c$") - .input_ops("geph__a_0", "geph__b_0")(self.c.op)) + .input_ops("a", "geph__b_0")(self.c.op)) def test_connect(self): """Test for ge.connect.""" diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py index 7eece5a548..67f3c00896 100644 --- a/tensorflow/contrib/graph_editor/tests/select_test.py +++ b/tensorflow/contrib/graph_editor/tests/select_test.py @@ -101,6 +101,19 @@ class SelectTest(tf.test.TestCase): self.assertEqual(list(output_ts), [self.h]) self.assertEqual(list(inside_ts), [self.g]) + def test_compute_boundary_ts_2(self): + """Test for ge.select.compute_boundary_ts.""" + graph = tf.Graph() + with graph.as_default(): + a = tf.constant(1, name="a") + b = tf.constant(1, name="b") + c = tf.add(a, b, name="c") + _ = a + c + input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts([a.op, c.op]) + self.assertEqual(list(input_ts), [b]) + self.assertEqual(list(output_ts), [a, c]) + self.assertEqual(list(inside_ts), [a]) + def test_get_within_boundary_ops_0(self): """Test for test_get_within_boundary_ops.""" control_outputs = ge.util.ControlOutputs(self.graph) |