aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/tensor_array_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-26 12:42:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 12:45:22 -0700
commitf63750645826df65b05cad505546a86f0e347674 (patch)
tree8467d73780d74b0f7ef4c87f8866d3bf0a233254 /tensorflow/compiler/tests/tensor_array_ops_test.py
parent667077cbd2cc86c4a656233a2d5f579aa4caf1f1 (diff)
For tf.gradients(), do not backpropagate through integer tensors.
All integer tensors are now considered constant with respect to all `xs`. This fixes a bug in gradients through tf.while_loop. PiperOrigin-RevId: 194438529
Diffstat (limited to 'tensorflow/compiler/tests/tensor_array_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 7624d6e4b2..f332aa2e9b 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
def testTensorArrayGradientWriteRead(self):
- for dtype in self.numeric_types:
+ for dtype in self.float_types:
+ self._testTensorArrayGradientWriteReadType(dtype)
+ for dtype in self.complex_types:
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):