author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-09-04 23:53:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-09-04 23:58:14 -0700
commit     6b89e9ffc991e0683cecd7a62e04cdf4a8c88356 (patch)
tree       56f8423adb06c11a144e0adfa35bbeed00e859c2 /tensorflow/contrib/lite/experimental
parent     7e6885cceb0e0117efb6ef0298770868dcf71436 (diff)
PR #21187: Added a normalization term to ctc_beam_search_decoder for tflite
PiperOrigin-RevId: 211586062
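
Before this change, each per-timestep score was offset only by max_coeff, so the decoder emitted unnormalized scores rather than log probabilities. The patch adds the remaining part of the softmax denominator, log(sum(exp(logit[j] - max_coeff))), so that subtracting norm_offset = max_coeff + logsumexp performs a numerically stable log-softmax. A minimal standalone sketch of that normalization (plain C++ with a hypothetical LogSoftmax helper, not code from the patch):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable log-softmax: log_prob[j] = logit[j] - norm_offset,
// where norm_offset = max_coeff + log(sum_k exp(logit[k] - max_coeff)).
std::vector<float> LogSoftmax(const std::vector<float>& logits) {
  const float max_coeff = *std::max_element(logits.begin(), logits.end());
  float logsumexp = 0.0f;
  for (float logit : logits) logsumexp += std::exp(logit - max_coeff);
  const float norm_offset = max_coeff + std::log(logsumexp);
  std::vector<float> log_probs;
  log_probs.reserve(logits.size());
  // Each entry is <= 0, and exp() of the entries sums to 1.
  for (float logit : logits) log_probs.push_back(logit - norm_offset);
  return log_probs;
}

Subtracting norm_offset instead of max_coeff is what turns newp.label and newp.blank below into proper log probabilities, which is also why the expected values in ctc_beam_search_decoder_test.cc change from positive scores to negative log probabilities.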
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
-rw-r--r--  tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h               | 18
-rw-r--r--  tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc | 13
2 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
index c658e43092..7c5099235a 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+ // We may compare logits instead of log probabilities,
+ // since the difference is the same in both cases.
if (logit < label_selection_input_min) {
continue;
}
@@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
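
For reference, the same offset can be expressed directly on an Eigen array; this is a hypothetical restatement of the new code path (not code from the patch), assuming the per-timestep logits are held in an Eigen::ArrayXf analogous to raw_input:

#include <Eigen/Core>
#include <cmath>

// Hypothetical helper mirroring the offset computed in Step():
// max_coeff + log(sum(exp(logit - max_coeff))).
float NormOffset(const Eigen::ArrayXf& raw_input) {
  const float max_coeff = raw_input.maxCoeff();
  const float logsumexp = std::log((raw_input - max_coeff).exp().sum());
  return max_coeff + logsumexp;
}

// Sanity check: the normalized values exponentiate back to a distribution,
// i.e. (raw_input - NormOffset(raw_input)).exp().sum() evaluates to ~1.0f.

This also explains the comment about label selection above: comparing raw logits against label_selection_input_min gives the same result as comparing log probabilities, because both sides would shift by the same norm_offset.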
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index 32458305c4..aa42b495bd 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.32134813})));
+ ElementsAreArray(ArrayFloatNear({-0.357094})));
}
TEST(CTCBeamSearchTest, MultiBatchTest) {
@@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) {
EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
// Check log probabilities output.
- EXPECT_THAT(
- m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
TEST(CTCBeamSearchTest, MultiPathsTest) {
@@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) {
EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear(
- {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+ ElementsAreArray(
+ ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
@@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+ ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
} // namespace