aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/function_test.py')
-rw-r--r--tensorflow/python/eager/function_test.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 57e545be69..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -286,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)