aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar Frank Perbet <fkp@google.com>2018-03-16 10:48:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-16 10:53:00 -0700
commit7bc1f803ea53f06677bc1f96cba59d1c751fc09a (patch)
tree330162ba29a4fa980cd4fa6f36256468dfc84841 /tensorflow/contrib/graph_editor
parent4091e498ba8dedc8f4ad5952dfe1262e735e7f42 (diff)
Automated g4 rollback of changelist 189346024
PiperOrigin-RevId: 189361083
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py32
-rw-r--r--tensorflow/contrib/graph_editor/transform.py66
2 files changed, 24 insertions, 74 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 2a1b78042d..ca00394388 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -23,7 +23,6 @@ from tensorflow.contrib import graph_editor as ge
from tensorflow.contrib.graph_editor.tests import match
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -85,9 +84,9 @@ class TransformTest(test.TestCase):
def test_transform(self):
transformer = ge.Transformer()
- def my_transform_op_handler(info, op, new_inputs):
+ def my_transform_op_handler(info, op):
add_noise = op.name.startswith("Add")
- op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs)
+ op_, op_outputs_ = ge.transform.copy_op_handler(info, op)
if not add_noise:
return op_, op_outputs_
# add some noise to op
@@ -202,36 +201,15 @@ class TransformTest(test.TestCase):
get_operation_by_name("res/grad/mul1_grad/Mul_1"))
# Make sure _original_ops are as expected.
- self.assertEqual(original_mul1_grad._original_op.name, u"mul1")
- self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1")
- self.assertNotEqual(res.name, g.name)
+ self.assertEquals(original_mul1_grad._original_op.name, u"mul1")
+ self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1")
+ self.assertNotEquals(res.name, g.name)
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
g_val, res_val = sess.run([g, res])
self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
- def test_graph_while_loop(self):
- graph = ops.Graph()
- with graph.as_default():
- max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple())
- index_start = constant_op.constant(1)
- sum_start = constant_op.constant(0)
- _, result = control_flow_ops.while_loop(
- cond=lambda i, unused_s: i <= max_index,
- body=lambda i, s: (i + 1, s + i),
- loop_vars=[index_start, sum_start])
- copied_graph = ops.Graph()
- _, copy_info = ge.copy(
- graph, dst_graph=copied_graph, dst_scope="imported")
- copied_result = copy_info.transformed(result)
- copied_max_index = copy_info.transformed(max_index)
- with copied_graph.as_default():
- with session.Session() as sess:
- n = 10
- sum_val = sess.run(copied_result, feed_dict={copied_max_index: n})
- self.assertEqual(sum_val, 55)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 03c9afe813..14ac529665 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -130,26 +129,20 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
return None
-def copy_op_handler(info, op, new_inputs, copy_shape=True):
+def copy_op_handler(info, op, copy_shape=True):
"""Copy a `tf.Operation`.
Args:
info: Transform._TmpInfo instance.
op: the `tf.Operation` to be copied.
- new_inputs: The new inputs for this op.
copy_shape: also copy the shape of the tensor
Returns:
A `(op, op_outputs)` tuple containing the transformed op and its outputs.
"""
- # The `new_inputs` was added to this function. For compatibility reason,
- # let's raise an error if `new_inputs` is a boolean.
- if isinstance(new_inputs, bool):
- raise TypeError("the `new_inputs` argument must be an iterable.")
-
# pylint: disable=protected-access
# Clone the node def:
- node_def_ = deepcopy(op.node_def)
+ node_def_ = deepcopy(op._node_def)
# Transform name:
name_ = info.new_name(op.name)
@@ -162,10 +155,10 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
# Make a copy of the op_def too.
# Its unique to every _type_ of Operation.
- op_def_ = deepcopy(op.op_def)
+ op_def_ = deepcopy(op._op_def)
# Initialize a new Operation instance
- op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_,
+ op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
[], input_types_, None, op_def_)
# copy the shape over
@@ -177,7 +170,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
# attribute to exist, we will create a dummy original_op first and then
# later finalise it with the actual original_op when all the ops have
# been copied.
- # TODO(fkp): Stop worrying about _original_op and remove this code?
if op._original_op:
op_._original_op = op._original_op
@@ -336,14 +328,6 @@ class _TmpInfo(object):
for key in self.graph.get_all_collection_keys())
self.cyclic_ops = []
self.transform_original_op_handler = transform_op_if_inside_handler
- # The graph is transformed op by op, in the same order the original ops
- # where created. However, this is sometimes not possible due to cycles
- # (e.g. while loops). So when the transformer creates a new op whose
- # inputs do not exist yet, temporary placeholders are created and stored
- # in this `tmp_cyclic_ts` container. During a second pass,
- # those temporary tensors are replaced by the proper transformed tensors
- # (see the function `_finalize_cycle`).
- self.tmp_cyclic_ts = []
def new_name(self, name):
"""Compute a destination name from a source name.
@@ -444,10 +428,10 @@ class Transformer(object):
# Create temporary info used during this transform call
info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
+ info.transform_original_op_handler = self.transform_original_op_handler
self._copy_ops(info)
- self._finalize_cycle(info)
- self._connect_control_inputs(info)
+ self._connect_ops(info)
# Compute information about the transformation
res_info = TransformerInfo(info)
@@ -456,12 +440,10 @@ class Transformer(object):
def _copy_ops(self, info):
"""Copy ops without connecting them."""
- sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access
- for op in sorted_ops:
+ for op in info.sgv.ops:
logging.debug("Copying op: %s", op.name)
- new_inputs = [self._transformed_t(info, t) for t in op.inputs]
# TODO(fkp): return a subgraph?
- op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs)
+ op_, op_outputs_ = self.transform_op_handler(info, op)
if op is op_:
raise ValueError("In-place transformation not allowed.")
@@ -474,31 +456,27 @@ class Transformer(object):
info.transformed_ts[op_output] = op_output_
self.assign_collections_handler(info, op_output, op_output_)
- def _finalize_cycle(self, info):
- for t, tmp_t_ in info.tmp_cyclic_ts:
- if t not in info.transformed_ts:
- raise ValueError("The tensor {} should be transformed by now.".format(
- t.name))
- op_ = tmp_t_.consumers()[0]
- t_ = info.transformed_ts[t]
- op_._update_input(list(op_.inputs).index(tmp_t_), t_) # pylint: disable=protected-access
-
- def _connect_control_inputs(self, info):
+ def _connect_ops(self, info):
"""Connect the previously copied ops."""
for op in info.sgv.ops:
- logging.debug("Connecting control inputs of op: %s", op.name)
+ logging.debug("Finalizing op: %s", op.name)
op_ = info.transformed_ops[op]
- # Finalize original op.
- # TODO(fkp): Stop worrying about _original_op and remove this code?
# pylint: disable=protected-access
+ if op_.inputs:
+ raise ValueError("The newly transformed op should not have "
+ "any inputs yet: {}".format(op_.name))
+ inputs_ = [self._transformed_t(info, t) for t in op.inputs]
+ for t in inputs_:
+ op_._add_input(t)
+
+ # Finalize original op.
if op._original_op:
- original_op = self.transform_original_op_handler(info, op._original_op)
+ original_op = info.transform_original_op_handler(info, op._original_op)
if original_op is None:
logging.debug("Could not find original op for: %s", op_.name)
else:
op_._original_op = original_op
- # pylint: enable=protected-access
# Finalize control inputs:
control_inputs_ = [self.transform_control_input_handler(info, ci)
@@ -550,12 +528,6 @@ class Transformer(object):
def _transformed_t(self, info, t):
"""Return tre transformed tensor of `t`."""
if t not in info.transformed_ts:
- if t.op in info.ops:
- with info.graph_.as_default():
- tmp_t_ = array_ops.placeholder(
- shape=t.shape, dtype=t.dtype, name="ge_tmp")
- info.tmp_cyclic_ts.append((t, tmp_t_))
- return tmp_t_
# If op is not in the subgraph.
if t in info.sgv_inputs_set:
# t is an input of the subgraph.