aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/gradients_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/gradients_test.py')
-rw-r--r--tensorflow/python/ops/gradients_test.py90
1 files changed, 90 insertions, 0 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index d70cd088c9..d02fcf4ee2 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -437,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
grad_func=grad_func, python_grad_func=self._PythonGradient)
f.add_to_graph(ops.Graph())
+ def testGradientWrtCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(x, 2.0, name="y")
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testGradientOfCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Foo():
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedResourceVariable(self):
+ with ops.Graph().as_default():
+ var = resource_variable_ops.ResourceVariable(1.0, name="var")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(var, 2.0, name="y")
+ g = gradients_impl.gradients(y, var)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedNested(self):
+ with ops.Graph().as_default():
+ x1 = constant_op.constant(1.0, name="x1")
+ x2 = constant_op.constant(2.0, name="x2")
+ x3 = math_ops.multiply(x1, x2, name="x3")
+
+ @function.Defun()
+ def Outer():
+ outer1 = array_ops.identity(x1, name="outer1")
+
+ @function.Defun()
+ def Inner():
+ inner1 = array_ops.identity(outer1, name="inner1")
+ inner2 = array_ops.identity(x2, name="inner2")
+ inner3 = array_ops.identity(x3, name="inner3")
+ return gradients_impl.gradients([inner1, inner2, inner3, x1],
+ [x1, x2])
+
+ return Inner()
+
+ x1_grad, x2_grad = Outer()
+ with self.test_session() as sess:
+ # 1.0 + None + 2.0 + 1.0 = 4.0
+ self.assertEqual(sess.run(x1_grad), 4.0)
+ # None + 1.0 + 1.0 + None = 2.0
+ self.assertEqual(sess.run(x2_grad), 2.0)
+
+ def testCapturedFromFunction(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Outer():
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Inner():
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, y)
+ return g[0]
+
+ return Inner()
+
+ z_grad = Outer()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(z_grad), 3.0)
+
class StopGradientTest(test_util.TensorFlowTestCase):