diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-13 08:43:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 08:48:11 -0700 |
commit | 4503f464628d4ba6f01e0e5b2aa9ff829763982b (patch) | |
tree | 596eb2e7f4310a1cd70d5d1bdf5251a95fbf09cb | |
parent | 1c241e5ba7fa7068f9cf8f925638b170db57c438 (diff) |
Avoid cache thrashing in CTC beam search
Change the logic that identifies topK choices with a cache friendly alternative.
PiperOrigin-RevId: 172101068
-rw-r--r-- | tensorflow/core/util/ctc/ctc_beam_search.h | 83 |
1 files changed, 59 insertions, 24 deletions
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index f1773bcd95..372f25a143 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -102,6 +102,11 @@ class CTCBeamSearchDecoder : public CTCDecoder { template <typename Vector> void Step(const Vector& log_input_t); + template <typename Vector> + float GetTopK(const int K, const Vector& input, + std::vector<float>* top_k_logits, + std::vector<int>* top_k_indices); + // Retrieve the beam scorer instance used during decoding. BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; } @@ -204,29 +209,57 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode( template <typename CTCBeamState, typename CTCBeamComparer> template <typename Vector> +float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK( + const int K, const Vector& input, std::vector<float>* top_k_logits, + std::vector<int>* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + CHECK_EQ(num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < num_classes_ - 1; ++j) { + const float logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(num_classes_ - 1)); +} + +template <typename CTCBeamState, typename CTCBeamComparer> +template <typename Vector> void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( const Vector& raw_input) { - Eigen::ArrayXf input = raw_input; - // Remove the max for stability when performing log-prob calculations. - input -= input.maxCoeff(); - - // Minimum allowed input value for label selection: - float label_selection_input_min = -std::numeric_limits<float>::infinity(); - if (label_selection_size_ > 0 && label_selection_size_ < input.size()) { - std::vector<float> input_copy(input.data(), input.data() + input.size()); - std::nth_element(input_copy.begin(), - input_copy.begin() + label_selection_size_ - 1, - input_copy.end(), [](float a, float b) { return a > b; }); - label_selection_input_min = input_copy[label_selection_size_ - 1]; - } - if (label_selection_margin_ >= 0) { - // max element is 0, per normalization above - label_selection_input_min = - std::max(label_selection_input_min, -label_selection_margin_); + std::vector<float> top_k_logits; + std::vector<int> top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. + float max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); } + const float label_selection_input_min = + (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) + : -std::numeric_limits<float>::infinity(); // Extract the beams sorted in decreasing new probability - CHECK_EQ(num_classes_, input.size()); + CHECK_EQ(num_classes_, raw_input.size()); std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract()); leaves_.Reset(); @@ -252,10 +285,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( beam_scorer_->GetStateExpansionScore(b->state, previous)); } // Plabel(l=abc @ t=6) *= P(c @ 6) - b->newp.label += input(b->label); + b->newp.label += raw_input(b->label) - max_coeff; } // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) - b->newp.blank = b->oldp.total + input(blank_index_); + b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) b->newp.total = LogSumExp(b->newp.blank, b->newp.label); @@ -285,13 +318,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( continue; } - for (int ind = 0; ind < num_classes_ - 1; ind++) { + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const float logit = top_k ? top_k_logits[ind] : raw_input(ind); // Perform label selection: if input for this label looks very // unpromising, never evaluate it with a scorer. - if (input(ind) < label_selection_input_min) { + if (logit < label_selection_input_min) { continue; } - BeamEntry& c = b->GetChild(ind); + BeamEntry& c = b->GetChild(label); if (!c.Active()) { // Pblank(l=abcd @ t=6) = 0 c.newp.blank = kLogZero; @@ -301,7 +336,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; - c.newp.label = input(c.label) + + c.newp.label = logit - max_coeff + beam_scorer_->GetStateExpansionScore(c.state, previous); // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) c.newp.total = c.newp.label; |