aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
authorGravatar Vinu Rajashekhar <vinuraja@google.com>2018-06-04 14:48:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 14:52:29 -0700
commit06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (patch)
treedda36259be82ece8c16e0084d2a133b2fedcbaa6 /tensorflow/contrib/batching
parent6b2a088fb263af2428ca672a62088646a7f54219 (diff)
Fixes a cleanup bug in BatchFunction op.
PiperOrigin-RevId: 199198413
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py28
1 files changed, 26 insertions, 2 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 68e8a88ca0..ea8339334f 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -24,6 +24,7 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
+from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
@@ -208,7 +209,7 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
- """Tests that the batch_func works."""
+ """Tests that the batch_function op works."""
with self.test_session() as sess:
@function.Defun(dtypes.int32)
@@ -237,7 +238,7 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
- """Tests that batch_func with timeout."""
+ """Tests that batch_function op works with captured input."""
with self.test_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -270,6 +271,29 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBatchFunctionOpWithInputError(self):
+ """Tests that batch_function op works with error in the inputs."""
+ with self.test_session() as sess:
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def computation(in0, in1):
+ return in0 + in1
+
+ result = gen_batch_ops.batch_function(
+ [inp], # computation actually expects 2 inputs.
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ batching_queue="",
+ f=computation,
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ with self.assertRaisesRegexp(InvalidArgumentError,
+ ".*2 arguments.*but 1.*"):
+ sess.run([result], feed_dict={inp: [2]})
+
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
with self.test_session() as sess: