diff options
author | 2018-09-04 23:53:37 -0700 | |
---|---|---|
committer | 2018-09-04 23:58:14 -0700 | |
commit | 6b89e9ffc991e0683cecd7a62e04cdf4a8c88356 (patch) | |
tree | 56f8423adb06c11a144e0adfa35bbeed00e859c2 /tensorflow/contrib/lite/experimental | |
parent | 7e6885cceb0e0117efb6ef0298770868dcf71436 (diff) |
PR #21187: Added a normalization term to ctc_beam_search_decoder for tflite
PiperOrigin-RevId: 211586062
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
-rw-r--r-- | tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h | 18 | ||||
-rw-r--r-- | tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc | 13 |
2 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h index c658e43092..7c5099235a 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( } else { max_coeff = raw_input.maxCoeff(); } + + // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))). + float logsumexp = 0.0; + for (int j = 0; j < raw_input.size(); ++j) { + logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff); + } + logsumexp = Eigen::numext::log(logsumexp); + // Final normalization offset to get correct log probabilities. + float norm_offset = max_coeff + logsumexp; + const float label_selection_input_min = (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) : -std::numeric_limits<float>::infinity(); @@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( beam_scorer_->GetStateExpansionScore(b->state, previous)); } // Plabel(l=abc @ t=6) *= P(c @ 6) - b->newp.label += raw_input(b->label) - max_coeff; + b->newp.label += raw_input(b->label) - norm_offset; } // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) - b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset; // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) b->newp.total = LogSumExp(b->newp.blank, b->newp.label); @@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( const float logit = top_k ? top_k_logits[ind] : raw_input(ind); // Perform label selection: if input for this label looks very // unpromising, never evaluate it with a scorer. + // We may compare logits instead of log probabilities, + // since the difference is the same in both cases. if (logit < label_selection_input_min) { continue; } @@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; - c.newp.label = logit - max_coeff + + c.newp.label = logit - norm_offset + beam_scorer_->GetStateExpansionScore(c.state, previous); // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) c.newp.total = c.newp.label; diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 32458305c4..aa42b495bd 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.32134813}))); + ElementsAreArray(ArrayFloatNear({-0.357094}))); } TEST(CTCBeamSearchTest, MultiBatchTest) { @@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) { EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); // Check log probabilities output. - EXPECT_THAT( - m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958}))); } TEST(CTCBeamSearchTest, MultiPathsTest) { @@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) { EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear( - {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); + ElementsAreArray( + ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357}))); } TEST(CTCBeamSearchTest, NonEqualSequencesTest) { @@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); + ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553}))); } } // namespace |