aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vinu Rajashekhar <vinuraja@google.com>2018-06-04 14:48:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 14:52:29 -0700
commit06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (patch)
treedda36259be82ece8c16e0084d2a133b2fedcbaa6
parent6b2a088fb263af2428ca672a62088646a7f54219 (diff)
Fixes a cleanup bug in BatchFunction op.
PiperOrigin-RevId: 199198413
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py28
-rw-r--r--tensorflow/core/kernels/batch_kernels.cc37
2 files changed, 48 insertions, 17 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 68e8a88ca0..ea8339334f 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -24,6 +24,7 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
+from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
@@ -208,7 +209,7 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
- """Tests that the batch_func works."""
+ """Tests that the batch_function op works."""
with self.test_session() as sess:
@function.Defun(dtypes.int32)
@@ -237,7 +238,7 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
- """Tests that batch_func with timeout."""
+ """Tests that batch_function op works with captured input."""
with self.test_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -270,6 +271,29 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBatchFunctionOpWithInputError(self):
+ """Tests that batch_function op works with error in the inputs."""
+ with self.test_session() as sess:
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def computation(in0, in1):
+ return in0 + in1
+
+ result = gen_batch_ops.batch_function(
+ [inp], # computation actually expects 2 inputs.
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ batching_queue="",
+ f=computation,
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ with self.assertRaisesRegexp(InvalidArgumentError,
+ ".*2 arguments.*but 1.*"):
+ sess.run([result], feed_dict={inp: [2]})
+
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
with self.test_session() as sess:
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index c0eef229ce..35ddda0ec0 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -523,21 +523,28 @@ class BatchResource : public ResourceBase {
const auto& captured_inputs =
batch->task(batch->num_tasks() - 1).captured_inputs;
args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
- flib->Run(opts, fhandle_, args, &combined_outputs,
- [&](const Status& run_status) {
- if (!run_status.ok()) {
- return;
- }
- const auto split_status =
- SplitOutputTensors(combined_outputs, batch.get());
- // We do the cleanup here as an optimization, so that it runs in
- // the underlying TF inter-op threadpool. Running it in the
- // threadpool, let's the ensuing ops be scheduled faster,
- // because the executor will add them to the front of the
- // threadpool's task queue rather than the end.
- cleanup_fn(split_status);
- done.Notify();
- });
+
+ // Releases the cleanup method here, because the callback of the function
+ // library runtime will handle it now.
+ finally.release();
+ flib->Run(
+ opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
+ Status final_status;
+ auto run_finally = gtl::MakeCleanup([&]() {
+ // We do the cleanup here as an optimization, so that it runs in
+ // the underlying TF inter-op threadpool. Running it in the
+ // threadpool, let's the ensuing ops be scheduled faster,
+ // because the executor will add them to the front of the
+ // threadpool's task queue rather than the end.
+ cleanup_fn(final_status);
+ done.Notify();
+ });
+ final_status = run_status;
+ if (!final_status.ok()) {
+ return;
+ }
+ final_status = SplitOutputTensors(combined_outputs, batch.get());
+ });
// By waiting for the notification we are ensuring that this thread isn't
// used for processing other batches, which gives the batches time to
// coalesce upstream. So overall the number of batches going through the