diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-05 16:12:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 16:16:55 -0700 |
commit | 6ce8af21574ce71f94a8a06bde876d2f7bf690e5 (patch) | |
tree | 6bbecf98c3fcba968123c95d2ebdc7f66d8decb1 /tensorflow/contrib/data | |
parent | b744cc00e1522d50463e2b681beae39cbb6f4d16 (diff) |
[tf.data] Surface errors correctly in MapDefunOp by using different CancellationManagers for each run of the function.
PiperOrigin-RevId: 211717580
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 73cde40305..091eb5ce37 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -130,6 +130,22 @@ class MapDefunTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(result) + def testMapDefunCancelledCorrectly(self): + + @function.Defun(dtypes.int64) + def defun(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + c = array_ops.tile( + array_ops.expand_dims( + constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0), + [100, 1]) + map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(map_defun_op) + if __name__ == "__main__": test.main() |