diff options
Diffstat (limited to 'tensorflow/core/kernels/ctc_decoder_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/ctc_decoder_ops.cc | 34 |
1 files changed, 22 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc index 96bdb6a241..8cadeac68d 100644 --- a/tensorflow/core/kernels/ctc_decoder_ops.cc +++ b/tensorflow/core/kernels/ctc_decoder_ops.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -213,20 +214,29 @@ class CTCGreedyDecoderOp : public OpKernel { // Perform best path decoding std::vector<std::vector<std::vector<int> > > sequences(batch_size); - for (int b = 0; b < batch_size; ++b) { - sequences[b].resize(1); - auto& sequence = sequences[b][0]; - int prev_indices = -1; - for (int t = 0; t < seq_len_t(b); ++t) { - int max_class_indices; - log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices); - if (max_class_indices != blank_index && - !(merge_repeated_ && max_class_indices == prev_indices)) { - sequence.push_back(max_class_indices); + auto decode = [&](const int64 begin, const int64 end) { + for (int b = begin; b < end; ++b) { + sequences[b].resize(1); + auto &sequence = sequences[b][0]; + int prev_indices = -1; + for (int t = 0; t < seq_len_t(b); ++t) { + int max_class_indices; + log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices); + if (max_class_indices != blank_index && + !(merge_repeated_ && max_class_indices == prev_indices)) { + sequence.push_back(max_class_indices); + } + prev_indices = max_class_indices; } - prev_indices = max_class_indices; } - } + }; + + const int64 kCostPerUnit = 50 * max_time * num_classes; + const int64 total = batch_size; + const DeviceBase::CpuWorkerThreads& worker_threads = + *ctx->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, total, + kCostPerUnit, decode); OP_REQUIRES_OK( ctx, decode_helper_.StoreAllDecodedSequences( |