diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-11 14:04:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-11 14:07:52 -0700 |
commit | e5201672aa664cf39725f4a52b9774d2bae43ba3 (patch) | |
tree | 4ba887d26dd40b9566cfa4e26a9e4c713ce27480 /tensorflow/contrib/graph_editor | |
parent | eed6828acf19260279b38a7fbaf79141c813f795 (diff) |
Adds a nodedef_fn parameter to copy_op_handler, allowing customization by mutating
NodeDef before creating the copied operation.
PiperOrigin-RevId: 192505209
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 11 |
2 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 2603de6407..97f38c923f 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -18,9 +18,11 @@ from __future__ import division from __future__ import print_function import collections +import functools import numpy as np from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -42,6 +44,7 @@ class TransformTest(test.TestCase): self.graph = ops.Graph() with self.graph.as_default(): c0 = constant_op.constant(1.0, shape=[10], name="Const") + c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo")) c1 = constant_op.constant(1.0, shape=[10], name="Const") c2 = constant_op.constant(1.0, shape=[10], name="Const") i = constant_op.constant(1.0, shape=[10], name="Input") @@ -112,6 +115,32 @@ class TransformTest(test.TestCase): top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) + def test_transform_nodedef_fn(self): + transformer = ge.Transformer() + + def nodedef_fn(node_def): + if "_foo" in node_def.attr: + del node_def.attr["_foo"] + node_def.attr["_bar"].s = b"bar" + return node_def + + my_copy_op_handler = functools.partial( + ge.transform.copy_op_handler, nodedef_fn=nodedef_fn) + transformer.transform_op_handler = my_copy_op_handler + + graph = ops.Graph() + transformer(self.graph, graph, "", "") + + c0_before = self.graph.get_operation_by_name("Const") + c0_after = graph.get_operation_by_name("Const") + self.assertEquals(c0_before.get_attr("_foo"), b"foo") + with self.assertRaises(ValueError): + c0_after.get_attr("_foo") + + all_ops = graph.get_operations() + for op in all_ops: + self.assertEquals(op.get_attr("_bar"), b"bar") + def test_copy_with_input_replacements(self): with self.graph.as_default(): ten = constant_op.constant(10.0, shape=[10], name="Input") diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index d8a48387a7..a320a3f232 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,7 +129,7 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, new_inputs, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): """Copy a `tf.Operation`. Args: @@ -137,6 +137,11 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True): op: the `tf.Operation` to be copied. new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor + nodedef_fn: If provided, a function that will be run on the NodeDef + and should return a mutated NodeDef before a new Operation is created. + This is useful as certain features cannot be set on the Operation and + must be modified in NodeDef. + Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ @@ -155,6 +160,10 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True): name_ = info.graph_.unique_name(name_) node_def_.name = name_ + # Mutate NodeDef if requested: + if nodedef_fn is not None: + node_def_ = nodedef_fn(node_def_) + # Copy the other inputs needed for initialization output_types_ = op._output_types[:] input_types_ = op._input_types[:] |