aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-21 02:51:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 03:06:58 -0800
commit2145b44339642796dc382153d26b434c2cc18559 (patch)
treec1c09e2e35366bfb9c5b074029f2bed9d8eb1371
parentf7048fc88b44102e745def51a8b2610c4aacb139 (diff)
Fix two bugs int the Graph editor;
- compute_boundary_ts was sometimes adding spurious inputs - sgv.consumers was returning op inside the subgraph Change: 142645987
-rw-r--r--tensorflow/contrib/graph_editor/select.py30
-rw-r--r--tensorflow/contrib/graph_editor/subgraph.py13
-rw-r--r--tensorflow/contrib/graph_editor/tests/edit_test.py4
-rw-r--r--tensorflow/contrib/graph_editor/tests/select_test.py13
4 files changed, 43 insertions, 17 deletions
diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py
index 0e2914cb0d..1125b90a9e 100644
--- a/tensorflow/contrib/graph_editor/select.py
+++ b/tensorflow/contrib/graph_editor/select.py
@@ -272,7 +272,7 @@ def get_ops_ios(ops, control_inputs=False, control_outputs=None,
return res
-def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True):
+def compute_boundary_ts(ops):
"""Compute the tensors at the boundary of a set of ops.
This function looks at all the tensors connected to the given ops (in/out)
@@ -281,17 +281,18 @@ def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True):
2) output tensors: tensors whose consumer operations are not in ops
3) inside tensors: tensors which are neither input nor output tensors.
+ Note that a tensor can be both an inside tensor and an output tensor if it is
+ consumed by operations both outside and inside of `ops`.
+
Args:
ops: an object convertible to a list of tf.Operation.
- ambiguous_ts_are_outputs: a tensor can have consumers both inside and
- outside ops. Such tensors are treated as outside tensor if
- ambiguous_ts_are_outputs is True, otherwise they are treated as
- inside tensor.
Returns:
A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
`outside_input_ts` is a Python list of input tensors;
`outside_output_ts` is a python list of output tensors;
`inside_ts` is a python list of inside tensors.
+ Since a tensor can be both an inside tensor and an output tensor,
+ `outside_output_ts` and `inside_ts` might intersect.
Raises:
TypeError: if ops cannot be converted to a list of tf.Operation.
"""
@@ -301,22 +302,25 @@ def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True):
output_ts_set = frozenset(output_ts)
ops_set = frozenset(ops)
- # fill in inside
+ # Compute inside tensors.
inside_ts = []
+ only_inside_ts = []
for t in input_ts:
- # is also output?
+ # Skip if the input tensor is not also an output tensor.
if t not in output_ts_set:
continue
- # is ambiguous_ts_are_outputs is True, don't add to inside if ambiguous
- if ambiguous_ts_are_outputs:
- consumers = frozenset(t.consumers())
- if consumers - ops_set:
- continue
+ # Mark as "inside".
inside_ts.append(t)
+ # Mark as "only inside" if the tensor is not both inside and output.
+ consumers = frozenset(t.consumers())
+ if consumers - ops_set:
+ continue
+ only_inside_ts.append(t)
inside_ts_set = frozenset(inside_ts)
+ only_inside_ts_set = frozenset(only_inside_ts)
+ outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set]
outside_input_ts = [t for t in input_ts if t not in inside_ts_set]
- outside_output_ts = [t for t in output_ts if t not in inside_ts_set]
return outside_input_ts, outside_output_ts, inside_ts
diff --git a/tensorflow/contrib/graph_editor/subgraph.py b/tensorflow/contrib/graph_editor/subgraph.py
index 00a755c79f..bfeb3ae23a 100644
--- a/tensorflow/contrib/graph_editor/subgraph.py
+++ b/tensorflow/contrib/graph_editor/subgraph.py
@@ -561,10 +561,19 @@ class SubGraphView(object):
return subgraph_id
def consumers(self):
- """Return a Python set of all the consumers of this subgraph view."""
+ """Return a Python set of all the consumers of this subgraph view.
+
+ A consumer of a subgraph view is a tf.Operation which is a consumer
+ of one of the output tensors and is not in the subgraph.
+
+ Returns:
+ A list of `tf.Operation` which are the consumers of this subgraph view.
+ """
+ ops_set = frozenset(self._ops)
res = []
for output in self._output_ts:
- util.concatenate_unique(res, output.consumers())
+ consumers = [op for op in output.consumers() if op not in ops_set]
+ util.concatenate_unique(res, consumers)
return res
diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py
index 968a73c812..a3330beee8 100644
--- a/tensorflow/contrib/graph_editor/tests/edit_test.py
+++ b/tensorflow/contrib/graph_editor/tests/edit_test.py
@@ -49,10 +49,10 @@ class EditTest(tf.test.TestCase):
"""Test for ge.detach."""
sgv = ge.sgv(self.c.op, self.a.op)
control_outputs = ge.util.ControlOutputs(self.graph)
- ge.detach(sgv, control_inputs=control_outputs)
+ ge.detach(sgv, control_ios=control_outputs)
# make sure the detached graph is as expected.
self.assertTrue(ge.matcher("^foo/c$")
- .input_ops("geph__a_0", "geph__b_0")(self.c.op))
+ .input_ops("a", "geph__b_0")(self.c.op))
def test_connect(self):
"""Test for ge.connect."""
diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py
index 7eece5a548..67f3c00896 100644
--- a/tensorflow/contrib/graph_editor/tests/select_test.py
+++ b/tensorflow/contrib/graph_editor/tests/select_test.py
@@ -101,6 +101,19 @@ class SelectTest(tf.test.TestCase):
self.assertEqual(list(output_ts), [self.h])
self.assertEqual(list(inside_ts), [self.g])
+ def test_compute_boundary_ts_2(self):
+ """Test for ge.select.compute_boundary_ts."""
+ graph = tf.Graph()
+ with graph.as_default():
+ a = tf.constant(1, name="a")
+ b = tf.constant(1, name="b")
+ c = tf.add(a, b, name="c")
+ _ = a + c
+ input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts([a.op, c.op])
+ self.assertEqual(list(input_ts), [b])
+ self.assertEqual(list(output_ts), [a, c])
+ self.assertEqual(list(inside_ts), [a])
+
def test_get_within_boundary_ops_0(self):
"""Test for test_get_within_boundary_ops."""
control_outputs = ge.util.ControlOutputs(self.graph)