diff options
author | Frank Perbet <fkp@google.com> | 2018-03-16 10:48:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-16 10:53:00 -0700 |
commit | 7bc1f803ea53f06677bc1f96cba59d1c751fc09a (patch) | |
tree | 330162ba29a4fa980cd4fa6f36256468dfc84841 /tensorflow/contrib/graph_editor | |
parent | 4091e498ba8dedc8f4ad5952dfe1262e735e7f42 (diff) |
Automated g4 rollback of changelist 189346024
PiperOrigin-RevId: 189361083
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 32 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 66 |
2 files changed, 24 insertions, 74 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 2a1b78042d..ca00394388 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -23,7 +23,6 @@ from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -85,9 +84,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op, new_inputs): + def my_transform_op_handler(info, op): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -202,36 +201,15 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEqual(original_mul1_grad._original_op.name, u"mul1") - self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEqual(res.name, g.name) + self.assertEquals(original_mul1_grad._original_op.name, u"mul1") + self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEquals(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) - def test_graph_while_loop(self): - graph = ops.Graph() - with graph.as_default(): - max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) - index_start = constant_op.constant(1) - sum_start = constant_op.constant(0) - _, result = control_flow_ops.while_loop( - cond=lambda i, unused_s: i <= max_index, - body=lambda i, s: (i + 1, s + i), - loop_vars=[index_start, sum_start]) - copied_graph = ops.Graph() - _, copy_info = ge.copy( - graph, dst_graph=copied_graph, dst_scope="imported") - copied_result = copy_info.transformed(result) - copied_max_index = copy_info.transformed(max_index) - with copied_graph.as_default(): - with session.Session() as sess: - n = 10 - sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) - self.assertEqual(sum_val, 55) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 03c9afe813..14ac529665 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -30,7 +30,6 @@ from tensorflow.contrib.graph_editor import select from tensorflow.contrib.graph_editor import subgraph from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -130,26 +129,20 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, new_inputs, copy_shape=True): +def copy_op_handler(info, op, copy_shape=True): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. - new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ - # The `new_inputs` was added to this function. For compatibility reason, - # let's raise an error if `new_inputs` is a boolean. - if isinstance(new_inputs, bool): - raise TypeError("the `new_inputs` argument must be an iterable.") - # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op.node_def) + node_def_ = deepcopy(op._node_def) # Transform name: name_ = info.new_name(op.name) @@ -162,10 +155,10 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True): # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op.op_def) + op_def_ = deepcopy(op._op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -177,7 +170,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. - # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op @@ -336,14 +328,6 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler - # The graph is transformed op by op, in the same order the original ops - # where created. However, this is sometimes not possible due to cycles - # (e.g. while loops). So when the transformer creates a new op whose - # inputs do not exist yet, temporary placeholders are created and stored - # in this `tmp_cyclic_ts` container. During a second pass, - # those temporary tensors are replaced by the proper transformed tensors - # (see the function `_finalize_cycle`). - self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -444,10 +428,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) + info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._finalize_cycle(info) - self._connect_control_inputs(info) + self._connect_ops(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -456,12 +440,10 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access - for op in sorted_ops: + for op in info.sgv.ops: logging.debug("Copying op: %s", op.name) - new_inputs = [self._transformed_t(info, t) for t in op.inputs] # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) + op_, op_outputs_ = self.transform_op_handler(info, op) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -474,31 +456,27 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _finalize_cycle(self, info): - for t, tmp_t_ in info.tmp_cyclic_ts: - if t not in info.transformed_ts: - raise ValueError("The tensor {} should be transformed by now.".format( - t.name)) - op_ = tmp_t_.consumers()[0] - t_ = info.transformed_ts[t] - op_._update_input(list(op_.inputs).index(tmp_t_), t_) # pylint: disable=protected-access - - def _connect_control_inputs(self, info): + def _connect_ops(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Connecting control inputs of op: %s", op.name) + logging.debug("Finalizing op: %s", op.name) op_ = info.transformed_ops[op] - # Finalize original op. - # TODO(fkp): Stop worrying about _original_op and remove this code? # pylint: disable=protected-access + if op_.inputs: + raise ValueError("The newly transformed op should not have " + "any inputs yet: {}".format(op_.name)) + inputs_ = [self._transformed_t(info, t) for t in op.inputs] + for t in inputs_: + op_._add_input(t) + + # Finalize original op. if op._original_op: - original_op = self.transform_original_op_handler(info, op._original_op) + original_op = info.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op - # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -550,12 +528,6 @@ class Transformer(object): def _transformed_t(self, info, t): """Return tre transformed tensor of `t`.""" if t not in info.transformed_ts: - if t.op in info.ops: - with info.graph_.as_default(): - tmp_t_ = array_ops.placeholder( - shape=t.shape, dtype=t.dtype, name="ge_tmp") - info.tmp_cyclic_ts.append((t, tmp_t_)) - return tmp_t_ # If op is not in the subgraph. if t in info.sgv_inputs_set: # t is an input of the subgraph. |