diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_test.py')
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index d70cd088c9..d02fcf4ee2 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -437,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): grad_func=grad_func, python_grad_func=self._PythonGradient) f.add_to_graph(ops.Graph()) + def testGradientWrtCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Foo(): + y = math_ops.multiply(x, 2.0, name="y") + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testGradientOfCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Foo(): + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testCapturedResourceVariable(self): + with ops.Graph().as_default(): + var = resource_variable_ops.ResourceVariable(1.0, name="var") + + @function.Defun() + def Foo(): + y = math_ops.multiply(var, 2.0, name="y") + g = gradients_impl.gradients(y, var) + return g[0] + + f = Foo() + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertEqual(sess.run(f), 2.0) + + def testCapturedNested(self): + with ops.Graph().as_default(): + x1 = constant_op.constant(1.0, name="x1") + x2 = constant_op.constant(2.0, name="x2") + x3 = math_ops.multiply(x1, x2, name="x3") + + @function.Defun() + def Outer(): + outer1 = array_ops.identity(x1, name="outer1") + + @function.Defun() + def Inner(): + inner1 = array_ops.identity(outer1, name="inner1") + inner2 = array_ops.identity(x2, name="inner2") + inner3 = array_ops.identity(x3, name="inner3") + return gradients_impl.gradients([inner1, inner2, inner3, x1], + [x1, x2]) + + return Inner() + + x1_grad, x2_grad = Outer() + with self.test_session() as sess: + # 1.0 + None + 2.0 + 1.0 = 4.0 + self.assertEqual(sess.run(x1_grad), 4.0) + # None + 1.0 + 1.0 + None = 2.0 + self.assertEqual(sess.run(x2_grad), 2.0) + + def testCapturedFromFunction(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Outer(): + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Inner(): + z = math_ops.multiply(y, 3.0, name="z") + g = gradients_impl.gradients(z, y) + return g[0] + + return Inner() + + z_grad = Outer() + with self.test_session() as sess: + self.assertEqual(sess.run(z_grad), 3.0) + class StopGradientTest(test_util.TensorFlowTestCase): |