author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-07-10 15:45:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-07-10 15:50:14 -0700
commit     5397130e5fdff8e783cf677a0fdaf2c8f4b5b142 (patch)
tree       745a1b0f93155f0229326371d61d967a979ed3f3 /tensorflow/contrib/graph_editor
parent     0344d9c9cf2f15f64eaabf72a102acd759541850 (diff)
Fix a bug in graph_replace on graphs that contain gradient ops. Added a test to verify the fix.
PiperOrigin-RevId: 161452438
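
For context, the failure mode addressed here can be reproduced with a snippet along the lines of the new test below. This is a minimal sketch, assuming TensorFlow 1.x with tf.contrib.graph_editor available; the variable and scope names are illustrative, not part of the commit:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

w = tf.Variable(0.0, name="w")
y = tf.multiply(tf.multiply(w, w, name="mul1"), w, name="mul2")
g = tf.gradients(y, w, name="grad")[0]

# Copy the subgraph computing g into a new scope, replacing reads of w
# with the gradient tensor itself. Gradient ops carry an _original_op
# back-reference, which is what the copy logic previously mishandled.
res = ge.graph_replace(g, {w.value(): g}, dst_scope="res")

Before this change, such a graph_replace call could fail or leave incorrect _original_op links, because each copied op's original-op reference was resolved while the copy was still in progress; the diff below defers that step until all ops have been copied.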
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--  tensorflow/contrib/graph_editor/tests/transform_test.py  29
-rw-r--r--  tensorflow/contrib/graph_editor/transform.py              19
2 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index a4105645c6..ab5776b9dd 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -181,6 +182,34 @@ class TransformTest(test.TestCase):
     self.assertEqual(res[0].name, "b:0")
     self.assertEqual(res[1].name, "add_1:0")
 
+  def test_graph_replace_gradients(self):
+    ops.reset_default_graph()
+    w = variables.Variable(0.0, name="w")
+    y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
+    g = gradients_impl.gradients(y, w, name="grad")[0]
+
+    # Extract the operations.
+    replacement_ts = {w.value(): g}
+    original_mul1_grad = (ops.get_default_graph().
+                          get_operation_by_name("grad/mul1_grad/mul_1"))
+
+    # Should not raise exception.
+    res = ge.graph_replace(g, replacement_ts, dst_scope="res")
+
+    # Extract the operations after graph_replace.
+    result_mul1_grad = (ops.get_default_graph().
+                        get_operation_by_name("res/grad/mul1_grad/mul_1"))
+
+    # Make sure _original_ops are as expected.
+    self.assertEquals(original_mul1_grad._original_op.name, u"mul1")
+    self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1")
+    self.assertNotEquals(res.name, g.name)
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      g_val, res_val = sess.run([g, res])
+      self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
+      self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 2234400fdc..14ac529665 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -166,13 +166,12 @@ def copy_op_handler(info, op, copy_shape=True):
     for t, t_ in zip(op.outputs, op_.outputs):
       t_.set_shape(t.get_shape())
 
-  # Finalize original op.
+  # The original op cannot be finalized here yet. Because some ops require
+  # this attribute to exist as soon as the copy is created, we set a
+  # temporary original_op first and finalize it with the actual original_op
+  # once all the ops have been copied.
   if op._original_op:
-    original_op = info.transform_original_op_handler(info, op._original_op)
-    if original_op is None:
-      logging.debug("Could not find original op of: %s", op_.name)
-    else:
-      op_._original_op = original_op
+    op_._original_op = op._original_op
 
   # Add op to the graph
   info.graph_._add_op(op_)
@@ -471,6 +470,14 @@ class Transformer(object):
       for t in inputs_:
        op_._add_input(t)
 
+      # Finalize original op.
+      if op._original_op:
+        original_op = info.transform_original_op_handler(info, op._original_op)
+        if original_op is None:
+          logging.debug("Could not find original op for: %s", op_.name)
+        else:
+          op_._original_op = original_op
+
       # Finalize control inputs:
       control_inputs_ = [self.transform_control_input_handler(info, ci)
                          for ci in op.control_inputs]
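
To illustrate the copy-then-finalize pattern the comment in copy_op_handler describes, here is a toy sketch outside of TensorFlow. It is not the graph_editor implementation; the Node class and copy_nodes function are made up for illustration. It only shows why rewiring the _original_op-style back-reference must wait until a complete old-to-new mapping exists:

class Node(object):
  def __init__(self, name, original=None):
    self.name = name
    self.original = original  # plays the role of Operation._original_op


def copy_nodes(nodes):
  # Phase 1: copy every node. Keep the old 'original' pointer for now so
  # the attribute exists on the copy from the start.
  mapping = {n: Node("copy/" + n.name, original=n.original) for n in nodes}
  # Phase 2: finalize. Every node now has a copy, so each copy's
  # 'original' pointer can be rewired to the copied counterpart.
  for old, new in mapping.items():
    if old.original is not None:
      new.original = mapping.get(old.original, old.original)
  return mapping


mul1 = Node("mul1")
grad_mul = Node("grad/mul1_grad/mul_1", original=mul1)
copies = copy_nodes([mul1, grad_mul])
assert copies[grad_mul].original is copies[mul1]

Resolving the back-reference during phase 1, as the old copy_op_handler did, would fail whenever the referenced node had not been copied yet, which is exactly the situation gradient ops create.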