diff options
author | Vinu Rajashekhar <vinuraja@google.com> | 2018-06-04 14:48:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-04 14:52:29 -0700 |
commit | 06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (patch) | |
tree | dda36259be82ece8c16e0084d2a133b2fedcbaa6 /tensorflow/contrib/batching | |
parent | 6b2a088fb263af2428ca672a62088646a7f54219 (diff) |
Fixes a cleanup bug in BatchFunction op.
PiperOrigin-RevId: 199198413
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r-- | tensorflow/contrib/batching/python/ops/batch_ops_test.py | 28 |
1 files changed, 26 insertions, 2 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 68e8a88ca0..ea8339334f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -24,6 +24,7 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl @@ -208,7 +209,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOp(self): - """Tests that the batch_func works.""" + """Tests that the batch_function op works.""" with self.test_session() as sess: @function.Defun(dtypes.int32) @@ -237,7 +238,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOpWithCapturedInput(self): - """Tests that batch_func with timeout.""" + """Tests that batch_function op works with captured input.""" with self.test_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -270,6 +271,29 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: |