diff options
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r-- | tensorflow/python/eager/function_test.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 57e545be69..e46bde098b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -286,7 +286,23 @@ class FunctionTest(test.TestCase): c = constant_op.constant([[2.]]) f_c = f(c) g, = gradients_impl.gradients(f_c, c) - self.assertAllEqual(sess.run(g), [[1.0]]) + self.assertAllEqual(sess.run(g).values, [[1.0]]) + + def testNoSymGradNestedDefun(self): + + @function.defun + def outer(): + + @function.defun + def f(x): + return array_ops.gather_nd(x, [[0]]) + + c = constant_op.constant([[2.]]) + f_c = f(c) + g, = gradients_impl.gradients(f_c, c) + self.assertTrue(isinstance(g, ops.IndexedSlices)) + + outer() def testNestedInputsGraphFunction(self): matmul = function.defun(math_ops.matmul) |