diff options
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 19 |
2 files changed, 42 insertions, 6 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index a4105645c6..ab5776b9dd 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -181,6 +182,34 @@ class TransformTest(test.TestCase): self.assertEqual(res[0].name, "b:0") self.assertEqual(res[1].name, "add_1:0") + def test_graph_replace_gradients(self): + ops.reset_default_graph() + w = variables.Variable(0.0, name="w") + y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2") + g = gradients_impl.gradients(y, w, name="grad")[0] + + # Extract the operations. + replacement_ts = {w.value(): g} + original_mul1_grad = (ops.get_default_graph(). + get_operation_by_name("grad/mul1_grad/mul_1")) + + # Should not raise exception. + res = ge.graph_replace(g, replacement_ts, dst_scope="res") + + # Extract the operations after graph_replace. + result_mul1_grad = (ops.get_default_graph(). + get_operation_by_name("res/grad/mul1_grad/mul_1")) + + # Make sure _original_ops are as expected. 
+ self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + g_val, res_val = sess.run([g, res]) + self.assertNear(g_val, 0.0, ERROR_TOLERANCE) + self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 2234400fdc..14ac529665 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -166,13 +166,12 @@ def copy_op_handler(info, op, copy_shape=True): for t, t_ in zip(op.outputs, op_.outputs): t_.set_shape(t.get_shape()) - # Finalize original op. + # Original op cannot be finalized here yet. Because some ops require this + # attribute to exist, we will create a dummy original_op first and then + # later finalize it with the actual original_op when all the ops have + # been copied. if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) - if original_op is None: - logging.debug("Could not find original op of: %s", op_.name) - else: - op_._original_op = original_op + op_._original_op = op._original_op # Add op to the graph info.graph_._add_op(op_) @@ -471,6 +470,14 @@ class Transformer(object): for t in inputs_: op_._add_input(t) + # Finalize original op. + if op._original_op: + original_op = info.transform_original_op_handler(info, op._original_op) + if original_op is None: + logging.debug("Could not find original op for: %s", op_.name) + else: + op_._original_op = original_op + # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) for ci in op.control_inputs] |