diff options
author | 2018-06-04 14:48:32 -0700 | |
---|---|---|
committer | 2018-06-04 14:52:29 -0700 | |
commit | 06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf (patch) | |
tree | dda36259be82ece8c16e0084d2a133b2fedcbaa6 | |
parent | 6b2a088fb263af2428ca672a62088646a7f54219 (diff) |
Fixes a cleanup bug in BatchFunction op.
PiperOrigin-RevId: 199198413
-rw-r--r-- | tensorflow/contrib/batching/python/ops/batch_ops_test.py | 28 | ||||
-rw-r--r-- | tensorflow/core/kernels/batch_kernels.cc | 37 |
2 files changed, 48 insertions, 17 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 68e8a88ca0..ea8339334f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -24,6 +24,7 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl @@ -208,7 +209,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOp(self): - """Tests that the batch_func works.""" + """Tests that the batch_function op works.""" with self.test_session() as sess: @function.Defun(dtypes.int32) @@ -237,7 +238,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOpWithCapturedInput(self): - """Tests that batch_func with timeout.""" + """Tests that batch_function op works with captured input.""" with self.test_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -270,6 +271,29 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. 
+ num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index c0eef229ce..35ddda0ec0 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -523,21 +523,28 @@ class BatchResource : public ResourceBase { const auto& captured_inputs = batch->task(batch->num_tasks() - 1).captured_inputs; args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); - flib->Run(opts, fhandle_, args, &combined_outputs, - [&](const Status& run_status) { - if (!run_status.ok()) { - return; - } - const auto split_status = - SplitOutputTensors(combined_outputs, batch.get()); - // We do the cleanup here as an optimization, so that it runs in - // the underlying TF inter-op threadpool. Running it in the - // threadpool, let's the ensuing ops be scheduled faster, - // because the executor will add them to the front of the - // threadpool's task queue rather than the end. - cleanup_fn(split_status); - done.Notify(); - }); + + // Releases the cleanup method here, because the callback of the function + // library runtime will handle it now. + finally.release(); + flib->Run( + opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) { + Status final_status; + auto run_finally = gtl::MakeCleanup([&]() { + // We do the cleanup here as an optimization, so that it runs in + // the underlying TF inter-op threadpool. 
Running it in the + // threadpool, lets the ensuing ops be scheduled faster, + // because the executor will add them to the front of the + // threadpool's task queue rather than the end. + cleanup_fn(final_status); + done.Notify(); + }); + final_status = run_status; + if (!final_status.ok()) { + return; + } + final_status = SplitOutputTensors(combined_outputs, batch.get()); + }); // By waiting for the notification we are ensuring that this thread isn't // used for processing other batches, which gives the batches time to // coalesce upstream. So overall the number of batches going through the