aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-13 08:43:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 08:48:11 -0700
commit4503f464628d4ba6f01e0e5b2aa9ff829763982b (patch)
tree596eb2e7f4310a1cd70d5d1bdf5251a95fbf09cb
parent1c241e5ba7fa7068f9cf8f925638b170db57c438 (diff)
Avoid cache thrashing in CTC beam search
Replace the logic that identifies the top-K choices with a cache-friendly alternative. PiperOrigin-RevId: 172101068
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h83
1 files changed, 59 insertions, 24 deletions
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index f1773bcd95..372f25a143 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -102,6 +102,11 @@ class CTCBeamSearchDecoder : public CTCDecoder {
template <typename Vector>
void Step(const Vector& log_input_t);
+ template <typename Vector>
+ float GetTopK(const int K, const Vector& input,
+ std::vector<float>* top_k_logits,
+ std::vector<int>* top_k_indices);
+
// Retrieve the beam scorer instance used during decoding.
BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
@@ -204,29 +209,57 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
+float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
+ const int K, const Vector& input, std::vector<float>* top_k_logits,
+ std::vector<int>* top_k_indices) {
+ // Keep the K largest logits (descending) by insertion: input is read once,
+ // worst case O(n*K). The last class (blank, per the return below) is skipped.
+ CHECK_EQ(num_classes_, input.size());
+ top_k_logits->clear();
+ top_k_indices->clear();
+ top_k_logits->resize(K, -INFINITY);
+ top_k_indices->resize(K, -1);
+ for (int j = 0; j < num_classes_ - 1; ++j) {
+ const float logit = input(j);
+ if (logit > (*top_k_logits)[K - 1]) {
+ int k = K - 1;
+ while (k > 0 && logit > (*top_k_logits)[k - 1]) {
+ (*top_k_logits)[k] = (*top_k_logits)[k - 1];
+ (*top_k_indices)[k] = (*top_k_indices)[k - 1];
+ k--;
+ }
+ (*top_k_logits)[k] = logit;
+ (*top_k_indices)[k] = j;
+ }
+ }
+ // The overall max is the best label logit (index 0) or the blank logit.
+ return std::max((*top_k_logits)[0], input(num_classes_ - 1));
+}
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+template <typename Vector>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const Vector& raw_input) {
- Eigen::ArrayXf input = raw_input;
- // Remove the max for stability when performing log-prob calculations.
- input -= input.maxCoeff();
-
- // Minimum allowed input value for label selection:
- float label_selection_input_min = -std::numeric_limits<float>::infinity();
- if (label_selection_size_ > 0 && label_selection_size_ < input.size()) {
- std::vector<float> input_copy(input.data(), input.data() + input.size());
- std::nth_element(input_copy.begin(),
- input_copy.begin() + label_selection_size_ - 1,
- input_copy.end(), [](float a, float b) { return a > b; });
- label_selection_input_min = input_copy[label_selection_size_ - 1];
- }
- if (label_selection_margin_ >= 0) {
- // max element is 0, per normalization above
- label_selection_input_min =
- std::max(label_selection_input_min, -label_selection_margin_);
+ std::vector<float> top_k_logits;
+ std::vector<int> top_k_indices;
+ const bool top_k =
+ (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
+ // Number of character classes to consider in each step.
+ const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
+ // Get max coefficient and remove it from raw_input later.
+ float max_coeff;
+ if (top_k) {
+ max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
+ &top_k_indices);
+ } else {
+ max_coeff = raw_input.maxCoeff();
}
+ const float label_selection_input_min =
+ (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
+ : -std::numeric_limits<float>::infinity();
// Extract the beams sorted in decreasing new probability
- CHECK_EQ(num_classes_, input.size());
+ CHECK_EQ(num_classes_, raw_input.size());
std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
leaves_.Reset();
@@ -252,10 +285,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += input(b->label);
+ b->newp.label += raw_input(b->label) - max_coeff;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + input(blank_index_);
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -285,13 +318,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
continue;
}
- for (int ind = 0; ind < num_classes_ - 1; ind++) {
+ for (int ind = 0; ind < max_classes; ind++) {
+ const int label = top_k ? top_k_indices[ind] : ind;
+ const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
- if (input(ind) < label_selection_input_min) {
+ if (logit < label_selection_input_min) {
continue;
}
- BeamEntry& c = b->GetChild(ind);
+ BeamEntry& c = b->GetChild(label);
if (!c.Active()) {
// Pblank(l=abcd @ t=6) = 0
c.newp.blank = kLogZero;
@@ -301,7 +336,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = input(c.label) +
+ c.newp.label = logit - max_coeff +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;