aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-11 14:04:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 14:07:52 -0700
commite5201672aa664cf39725f4a52b9774d2bae43ba3 (patch)
tree4ba887d26dd40b9566cfa4e26a9e4c713ce27480 /tensorflow/contrib/graph_editor
parenteed6828acf19260279b38a7fbaf79141c813f795 (diff)
Adds a nodedef_fn parameter to copy_op_handler, allowing customization by mutating
NodeDef before creating the copied operation. PiperOrigin-RevId: 192505209
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py29
-rw-r--r--tensorflow/contrib/graph_editor/transform.py11
2 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 2603de6407..97f38c923f 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -18,9 +18,11 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
import numpy as np
from tensorflow.contrib import graph_editor as ge
from tensorflow.contrib.graph_editor.tests import match
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -42,6 +44,7 @@ class TransformTest(test.TestCase):
self.graph = ops.Graph()
with self.graph.as_default():
c0 = constant_op.constant(1.0, shape=[10], name="Const")
+ c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo"))
c1 = constant_op.constant(1.0, shape=[10], name="Const")
c2 = constant_op.constant(1.0, shape=[10], name="Const")
i = constant_op.constant(1.0, shape=[10], name="Input")
@@ -112,6 +115,32 @@ class TransformTest(test.TestCase):
top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
self.assertTrue(matcher2(top))
+ def test_transform_nodedef_fn(self):
+ transformer = ge.Transformer()
+
+ def nodedef_fn(node_def):
+ if "_foo" in node_def.attr:
+ del node_def.attr["_foo"]
+ node_def.attr["_bar"].s = b"bar"
+ return node_def
+
+ my_copy_op_handler = functools.partial(
+ ge.transform.copy_op_handler, nodedef_fn=nodedef_fn)
+ transformer.transform_op_handler = my_copy_op_handler
+
+ graph = ops.Graph()
+ transformer(self.graph, graph, "", "")
+
+ c0_before = self.graph.get_operation_by_name("Const")
+ c0_after = graph.get_operation_by_name("Const")
+ self.assertEquals(c0_before.get_attr("_foo"), b"foo")
+ with self.assertRaises(ValueError):
+ c0_after.get_attr("_foo")
+
+ all_ops = graph.get_operations()
+ for op in all_ops:
+ self.assertEquals(op.get_attr("_bar"), b"bar")
+
def test_copy_with_input_replacements(self):
with self.graph.as_default():
ten = constant_op.constant(10.0, shape=[10], name="Input")
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index d8a48387a7..a320a3f232 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -129,7 +129,7 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
return None
-def copy_op_handler(info, op, new_inputs, copy_shape=True):
+def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
"""Copy a `tf.Operation`.
Args:
@@ -137,6 +137,11 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
op: the `tf.Operation` to be copied.
new_inputs: The new inputs for this op.
copy_shape: also copy the shape of the tensor
+ nodedef_fn: If provided, a function that will be run on the NodeDef
+ and should return a mutated NodeDef before a new Operation is created.
+ This is useful as certain features cannot be set on the Operation and
+ must be modified in NodeDef.
+
Returns:
A `(op, op_outputs)` tuple containing the transformed op and its outputs.
"""
@@ -155,6 +160,10 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
name_ = info.graph_.unique_name(name_)
node_def_.name = name_
+ # Mutate NodeDef if requested:
+ if nodedef_fn is not None:
+ node_def_ = nodedef_fn(node_def_)
+
# Copy the other inputs needed for initialization
output_types_ = op._output_types[:]
input_types_ = op._input_types[:]