aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/ctc_decoder_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/ctc_decoder_ops.cc')
-rw-r--r--tensorflow/core/kernels/ctc_decoder_ops.cc34
1 files changed, 22 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc
index 96bdb6a241..8cadeac68d 100644
--- a/tensorflow/core/kernels/ctc_decoder_ops.cc
+++ b/tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@@ -213,20 +214,29 @@ class CTCGreedyDecoderOp : public OpKernel {
// Perform best path decoding
std::vector<std::vector<std::vector<int> > > sequences(batch_size);
- for (int b = 0; b < batch_size; ++b) {
- sequences[b].resize(1);
- auto& sequence = sequences[b][0];
- int prev_indices = -1;
- for (int t = 0; t < seq_len_t(b); ++t) {
- int max_class_indices;
- log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
- if (max_class_indices != blank_index &&
- !(merge_repeated_ && max_class_indices == prev_indices)) {
- sequence.push_back(max_class_indices);
+ auto decode = [&](const int64 begin, const int64 end) {
+ for (int b = begin; b < end; ++b) {
+ sequences[b].resize(1);
+ auto &sequence = sequences[b][0];
+ int prev_indices = -1;
+ for (int t = 0; t < seq_len_t(b); ++t) {
+ int max_class_indices;
+ log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
+ if (max_class_indices != blank_index &&
+ !(merge_repeated_ && max_class_indices == prev_indices)) {
+ sequence.push_back(max_class_indices);
+ }
+ prev_indices = max_class_indices;
}
- prev_indices = max_class_indices;
}
- }
+ };
+
+ const int64 kCostPerUnit = 50 * max_time * num_classes;
+ const int64 total = batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *ctx->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, total,
+ kCostPerUnit, decode);
OP_REQUIRES_OK(
ctx, decode_helper_.StoreAllDecodedSequences(