author     Frank Perbet <fkp@google.com>  2018-03-16 09:02:48 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-16 09:07:58 -0700
commit     4b5cb6c49934b6c62fcfc4bf710a30dcce2568d3 (patch)
tree       1a7e1bab6ffe80866bf8a751c195e008ca926d3b /tensorflow/contrib/graph_editor
parent     09d3d2dcc1b1b9ee7282b37bc4e0f212c577f6a2 (diff)
Make the graph_editor C-API friendly: always construct ops with their inputs.
PiperOrigin-RevId: 189346024
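
The user-visible consequence is a new handler signature: custom op handlers now receive the already-transformed input tensors and must forward them, because the copied `tf.Operation` is constructed with its inputs instead of having them attached afterwards. A minimal sketch of a conforming handler, based on the test change below (the handler name is illustrative):

    from tensorflow.contrib import graph_editor as ge

    def my_transform_op_handler(info, op, new_inputs):
      # `new_inputs` holds the transformed inputs of `op`; forwarding them
      # lets copy_op_handler construct the copied op with its inputs.
      return ge.transform.copy_op_handler(info, op, new_inputs)

Handlers written against the old two-argument signature must be updated; `copy_op_handler` also raises a `TypeError` when `new_inputs` is a boolean, presumably to catch callers that still pass the old `copy_shape` flag positionally.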
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--  tensorflow/contrib/graph_editor/tests/transform_test.py  32
-rw-r--r--  tensorflow/contrib/graph_editor/transform.py              66
2 files changed, 74 insertions, 24 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index ca00394388..2a1b78042d 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -23,6 +23,7 @@ from tensorflow.contrib import graph_editor as ge
 from tensorflow.contrib.graph_editor.tests import match
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -84,9 +85,9 @@ class TransformTest(test.TestCase):
   def test_transform(self):
     transformer = ge.Transformer()
 
-    def my_transform_op_handler(info, op):
+    def my_transform_op_handler(info, op, new_inputs):
       add_noise = op.name.startswith("Add")
-      op_, op_outputs_ = ge.transform.copy_op_handler(info, op)
+      op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs)
       if not add_noise:
         return op_, op_outputs_
       # add some noise to op
@@ -201,15 +202,36 @@ class TransformTest(test.TestCase):
         get_operation_by_name("res/grad/mul1_grad/Mul_1"))
 
     # Make sure _original_ops are as expected.
-    self.assertEquals(original_mul1_grad._original_op.name, u"mul1")
-    self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1")
-    self.assertNotEquals(res.name, g.name)
+    self.assertEqual(original_mul1_grad._original_op.name, u"mul1")
+    self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1")
+    self.assertNotEqual(res.name, g.name)
     with session.Session() as sess:
       sess.run(variables.global_variables_initializer())
       g_val, res_val = sess.run([g, res])
       self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
       self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
 
+  def test_graph_while_loop(self):
+    graph = ops.Graph()
+    with graph.as_default():
+      max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple())
+      index_start = constant_op.constant(1)
+      sum_start = constant_op.constant(0)
+      _, result = control_flow_ops.while_loop(
+          cond=lambda i, unused_s: i <= max_index,
+          body=lambda i, s: (i + 1, s + i),
+          loop_vars=[index_start, sum_start])
+    copied_graph = ops.Graph()
+    _, copy_info = ge.copy(
+        graph, dst_graph=copied_graph, dst_scope="imported")
+    copied_result = copy_info.transformed(result)
+    copied_max_index = copy_info.transformed(max_index)
+    with copied_graph.as_default():
+      with session.Session() as sess:
+        n = 10
+        sum_val = sess.run(copied_result, feed_dict={copied_max_index: n})
+        self.assertEqual(sum_val, 55)
+
 
 if __name__ == "__main__":
   test.main()
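
The new `test_graph_while_loop` exercises the change end to end: a while loop is a cyclic subgraph, so copying it needs the temporary-placeholder machinery added to `transform.py` below. Condensed to its essentials, the copy-and-run pattern the test demonstrates looks like this (a sketch using the same public API; the loop computes 1 + 2 + ... + n, hence 55 for n = 10):

    import tensorflow as tf
    from tensorflow.contrib import graph_editor as ge

    graph = tf.Graph()
    with graph.as_default():
      n = tf.placeholder(dtype=tf.int32, shape=())
      # The back edges of the while loop make this subgraph cyclic.
      _, total = tf.while_loop(
          cond=lambda i, unused_s: i <= n,
          body=lambda i, s: (i + 1, s + i),
          loop_vars=[tf.constant(1), tf.constant(0)])

    copied_graph = tf.Graph()
    _, copy_info = ge.copy(graph, dst_graph=copied_graph, dst_scope="imported")

    with copied_graph.as_default(), tf.Session() as sess:
      print(sess.run(copy_info.transformed(total),
                     feed_dict={copy_info.transformed(n): 10}))  # prints 55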
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 14ac529665..03c9afe813 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.graph_editor import select
 from tensorflow.contrib.graph_editor import subgraph
 from tensorflow.contrib.graph_editor import util
 from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
 
 
@@ -129,20 +130,26 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
       return None
 
 
-def copy_op_handler(info, op, copy_shape=True):
+def copy_op_handler(info, op, new_inputs, copy_shape=True):
   """Copy a `tf.Operation`.
 
   Args:
     info: Transform._TmpInfo instance.
     op: the `tf.Operation` to be copied.
+    new_inputs: The new inputs for this op.
     copy_shape: also copy the shape of the tensor.
 
   Returns:
     A `(op, op_outputs)` tuple containing the transformed op and its outputs.
   """
+  # The `new_inputs` argument was added to this function. For compatibility
+  # reasons, raise an error if `new_inputs` is a boolean.
+  if isinstance(new_inputs, bool):
+    raise TypeError("the `new_inputs` argument must be an iterable.")
+
   # pylint: disable=protected-access
   # Clone the node def:
-  node_def_ = deepcopy(op._node_def)
+  node_def_ = deepcopy(op.node_def)
 
   # Transform name:
   name_ = info.new_name(op.name)
@@ -155,10 +162,10 @@ def copy_op_handler(info, op, copy_shape=True):
 
   # Make a copy of the op_def too.
   # It's unique to every _type_ of Operation.
-  op_def_ = deepcopy(op._op_def)
+  op_def_ = deepcopy(op.op_def)
 
   # Initialize a new Operation instance
-  op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
+  op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_,
                          [], input_types_, None, op_def_)
 
   # copy the shape over
@@ -170,6 +177,7 @@ def copy_op_handler(info, op, copy_shape=True):
   # attribute to exist, we will create a dummy original_op first and then
   # later finalise it with the actual original_op when all the ops have
   # been copied.
+  # TODO(fkp): Stop worrying about _original_op and remove this code?
   if op._original_op:
     op_._original_op = op._original_op
 
@@ -328,6 +336,14 @@ class _TmpInfo(object):
                             for key in self.graph.get_all_collection_keys())
     self.cyclic_ops = []
     self.transform_original_op_handler = transform_op_if_inside_handler
+    # The graph is transformed op by op, in the same order the original ops
+    # were created. However, this is sometimes not possible due to cycles
+    # (e.g. while loops). So when the transformer creates a new op whose
+    # inputs do not exist yet, temporary placeholders are created and stored
+    # in this `tmp_cyclic_ts` container. During a second pass,
+    # those temporary tensors are replaced by the proper transformed tensors
+    # (see the function `_finalize_cycle`).
+    self.tmp_cyclic_ts = []
 
   def new_name(self, name):
     """Compute a destination name from a source name.
@@ -428,10 +444,10 @@ class Transformer(object):
 
     # Create temporary info used during this transform call
     info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
-    info.transform_original_op_handler = self.transform_original_op_handler
 
     self._copy_ops(info)
-    self._connect_ops(info)
+    self._finalize_cycle(info)
+    self._connect_control_inputs(info)
 
     # Compute information about the transformation
     res_info = TransformerInfo(info)
@@ -440,10 +456,12 @@ class Transformer(object):
 
   def _copy_ops(self, info):
     """Copy ops without connecting them."""
-    for op in info.sgv.ops:
+    sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id)  # pylint: disable=protected-access
+    for op in sorted_ops:
       logging.debug("Copying op: %s", op.name)
+      new_inputs = [self._transformed_t(info, t) for t in op.inputs]
       # TODO(fkp): return a subgraph?
-      op_, op_outputs_ = self.transform_op_handler(info, op)
+      op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs)
       if op is op_:
         raise ValueError("In-place transformation not allowed.")
 
@@ -456,27 +474,31 @@ class Transformer(object):
         info.transformed_ts[op_output] = op_output_
         self.assign_collections_handler(info, op_output, op_output_)
 
-  def _connect_ops(self, info):
+  def _finalize_cycle(self, info):
+    for t, tmp_t_ in info.tmp_cyclic_ts:
+      if t not in info.transformed_ts:
+        raise ValueError("The tensor {} should be transformed by now.".format(
+            t.name))
+      op_ = tmp_t_.consumers()[0]
+      t_ = info.transformed_ts[t]
+      op_._update_input(list(op_.inputs).index(tmp_t_), t_)  # pylint: disable=protected-access
+
+  def _connect_control_inputs(self, info):
     """Connect the previously copied ops."""
     for op in info.sgv.ops:
-      logging.debug("Finalizing op: %s", op.name)
+      logging.debug("Connecting control inputs of op: %s", op.name)
       op_ = info.transformed_ops[op]
 
-      # pylint: disable=protected-access
-      if op_.inputs:
-        raise ValueError("The newly transformed op should not have "
-                         "any inputs yet: {}".format(op_.name))
-      inputs_ = [self._transformed_t(info, t) for t in op.inputs]
-      for t in inputs_:
-        op_._add_input(t)
-
       # Finalize original op.
+      # TODO(fkp): Stop worrying about _original_op and remove this code?
+      # pylint: disable=protected-access
       if op._original_op:
-        original_op = info.transform_original_op_handler(info, op._original_op)
+        original_op = self.transform_original_op_handler(info, op._original_op)
         if original_op is None:
           logging.debug("Could not find original op for: %s", op_.name)
         else:
           op_._original_op = original_op
+      # pylint: enable=protected-access
 
       # Finalize control inputs:
       control_inputs_ = [self.transform_control_input_handler(info, ci)
@@ -528,6 +550,12 @@ class Transformer(object):
   def _transformed_t(self, info, t):
     """Return the transformed tensor of `t`."""
     if t not in info.transformed_ts:
+      if t.op in info.ops:
+        with info.graph_.as_default():
+          tmp_t_ = array_ops.placeholder(
+              shape=t.shape, dtype=t.dtype, name="ge_tmp")
+        info.tmp_cyclic_ts.append((t, tmp_t_))
+        return tmp_t_
       # If op is not in the subgraph.
       if t in info.sgv_inputs_set:
         # t is an input of the subgraph.
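
Taken together, `_copy_ops`, `_transformed_t` and `_finalize_cycle` form a two-pass scheme: ops are copied in creation order with their inputs already transformed; whenever an input's producer has not been copied yet (which can only happen on a cycle), a fresh placeholder stands in for it; a second pass then rewires each placeholder's consumer to the real transformed tensor. A standalone sketch of that rewiring step under the same assumptions (it uses the private `Operation._update_input`, exactly as `_finalize_cycle` does; all names are illustrative):

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
      # Pass 1: the real input does not exist yet, so the consumer is
      # constructed against a temporary placeholder.
      tmp_t_ = tf.placeholder(dtype=tf.float32, shape=(), name="ge_tmp")
      consumer = tf.identity(tmp_t_, name="consumer")

      # The producer of the real tensor is only created afterwards.
      t_ = tf.constant(3.0, name="real")

      # Pass 2: swap the placeholder for the real tensor, as
      # `_finalize_cycle` does for every entry in `tmp_cyclic_ts`.
      op_ = tmp_t_.consumers()[0]
      op_._update_input(list(op_.inputs).index(tmp_t_), t_)  # pylint: disable=protected-access

    with tf.Session(graph=graph) as sess:
      print(sess.run(consumer))  # 3.0: the value now flows from `t_`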