aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-05 16:12:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 16:16:55 -0700
commit6ce8af21574ce71f94a8a06bde876d2f7bf690e5 (patch)
tree6bbecf98c3fcba968123c95d2ebdc7f66d8decb1 /tensorflow/contrib/data
parentb744cc00e1522d50463e2b681beae39cbb6f4d16 (diff)
[tf.data] Surface errors correctly in MapDefunOp by using different CancellationManagers for each run of the function.
PiperOrigin-RevId: 211717580
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 73cde40305..091eb5ce37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -130,6 +130,22 @@ class MapDefunTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result)
+ def testMapDefunCancelledCorrectly(self):
+
+ @function.Defun(dtypes.int64)
+ def defun(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ c = array_ops.tile(
+ array_ops.expand_dims(
+ constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+ [100, 1])
+ map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(map_defun_op)
+
if __name__ == "__main__":
test.main()