diff options
author | Frank Perbet <fkp@google.com> | 2018-03-16 09:02:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-16 09:07:58 -0700 |
commit | 4b5cb6c49934b6c62fcfc4bf710a30dcce2568d3 (patch) | |
tree | 1a7e1bab6ffe80866bf8a751c195e008ca926d3b /tensorflow/contrib/graph_editor | |
parent | 09d3d2dcc1b1b9ee7282b37bc4e0f212c577f6a2 (diff) |
Make the graph_editor C-API friendly: always construct ops with their inputs.
PiperOrigin-RevId: 189346024
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 32 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 66 |
2 files changed, 74 insertions, 24 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ca00394388..2a1b78042d 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -84,9 +85,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op): + def my_transform_op_handler(info, op, new_inputs): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -201,15 +202,36 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEquals(original_mul1_grad._original_op.name, u"mul1") - self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEquals(res.name, g.name) + self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + def test_graph_while_loop(self): + graph = ops.Graph() + with graph.as_default(): + max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) + index_start = constant_op.constant(1) + sum_start = constant_op.constant(0) + _, result = control_flow_ops.while_loop( + cond=lambda i, unused_s: i <= max_index, + body=lambda i, s: (i + 1, s + i), + loop_vars=[index_start, sum_start]) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_max_index = copy_info.transformed(max_index) + with copied_graph.as_default(): + with session.Session() as sess: + n = 10 + sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) + self.assertEqual(sum_val, 55) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 14ac529665..03c9afe813 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -30,6 +30,7 @@ from tensorflow.contrib.graph_editor import select from tensorflow.contrib.graph_editor import subgraph from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -129,20 +130,26 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. + new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ + # The `new_inputs` was added to this function. For compatibility reason, + # let's raise an error if `new_inputs` is a boolean. + if isinstance(new_inputs, bool): + raise TypeError("the `new_inputs` argument must be an iterable.") + # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op._node_def) + node_def_ = deepcopy(op.node_def) # Transform name: name_ = info.new_name(op.name) @@ -155,10 +162,10 @@ def copy_op_handler(info, op, copy_shape=True): # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op._op_def) + op_def_ = deepcopy(op.op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -170,6 +177,7 @@ def copy_op_handler(info, op, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. + # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op @@ -328,6 +336,14 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler + # The graph is transformed op by op, in the same order the original ops + # where created. However, this is sometimes not possible due to cycles + # (e.g. while loops). So when the transformer creates a new op whose + # inputs do not exist yet, temporary placeholders are created and stored + # in this `tmp_cyclic_ts` container. During a second pass, + # those temporary tensors are replaced by the proper transformed tensors + # (see the function `_finalize_cycle`). + self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -428,10 +444,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) - info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._connect_ops(info) + self._finalize_cycle(info) + self._connect_control_inputs(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -440,10 +456,12 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - for op in info.sgv.ops: + sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access + for op in sorted_ops: logging.debug("Copying op: %s", op.name) + new_inputs = [self._transformed_t(info, t) for t in op.inputs] # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op) + op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -456,27 +474,31 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _connect_ops(self, info): + def _finalize_cycle(self, info): + for t, tmp_t_ in info.tmp_cyclic_ts: + if t not in info.transformed_ts: + raise ValueError("The tensor {} should be transformed by now.".format( + t.name)) + op_ = tmp_t_.consumers()[0] + t_ = info.transformed_ts[t] + op_._update_input(list(op_.inputs).index(tmp_t_), t_) # pylint: disable=protected-access + + def _connect_control_inputs(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Finalizing op: %s", op.name) + logging.debug("Connecting control inputs of op: %s", op.name) op_ = info.transformed_ops[op] - # pylint: disable=protected-access - if op_.inputs: - raise ValueError("The newly transformed op should not have " - "any inputs yet: {}".format(op_.name)) - inputs_ = [self._transformed_t(info, t) for t in op.inputs] - for t in inputs_: - op_._add_input(t) - # Finalize original op. + # TODO(fkp): Stop worrying about _original_op and remove this code? + # pylint: disable=protected-access if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) + original_op = self.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op + # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -528,6 +550,12 @@ class Transformer(object): def _transformed_t(self, info, t): """Return tre transformed tensor of `t`.""" if t not in info.transformed_ts: + if t.op in info.ops: + with info.graph_.as_default(): + tmp_t_ = array_ops.placeholder( + shape=t.shape, dtype=t.dtype, name="ge_tmp") + info.tmp_cyclic_ts.append((t, tmp_t_)) + return tmp_t_ # If op is not in the subgraph. if t in info.sgv_inputs_set: # t is an input of the subgraph. |