diff options
author | 2018-08-02 20:11:50 -0700 | |
---|---|---|
committer | 2018-08-02 20:15:21 -0700 | |
commit | eecadedbaae7b938e8a80dfb60c52679bcbf7196 (patch) | |
tree | 7f41ed6a9a3be126f5026a809593736d9bd10340 | |
parent | 2d3819668d8c3ab99cd09a769ffb7b76e453fd8f (diff) |
Implementation of ctc beam search decoder op in custom op fashion.
PiperOrigin-RevId: 207210333
22 files changed, 1865 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD new file mode 100644 index 0000000000..9c06c4ebd9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/BUILD @@ -0,0 +1,84 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# ctc support classes imported directly from TensorFlow. +cc_library( + name = "ctc_utils", + hdrs = [ + "ctc_beam_entry.h", + "ctc_beam_scorer.h", + "ctc_beam_search.h", + "ctc_decoder.h", + "ctc_loss_util.h", + ], + deps = [ + ":top_n", + "//tensorflow/contrib/lite/kernels/internal:types", + "//third_party/eigen3", + ], +) + +# top_n support classes imported directly from TensorFlow. +cc_library( + name = "top_n", + hdrs = [ + "top_n.h", + ], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +cc_library( + name = "experimental_ops", + srcs = [ + "ctc_beam_search_decoder.cc", + ], + # Suppress warnings that are introduced by Eigen Tensor. 
+ copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":ctc_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "ctc_beam_search_decoder_test", + size = "small", + srcs = ["ctc_beam_search_decoder_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":experimental_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h new file mode 100644 index 0000000000..a60ff2a1c5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_entry.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ + +#include <algorithm> +#include <memory> +#include <unordered_map> +#include <vector> + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The ctc_beam_search namespace holds several classes meant to be accessed only +// in case of extending the CTCBeamSearch decoder to allow custom scoring +// functions. +// +// BeamEntry is exposed through template arguments BeamScorer and BeamComparer +// of CTCBeamSearch (ctc_beam_search.h). +namespace ctc_beam_search { + +struct EmptyBeamState {}; + +struct BeamProbability { + BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} + void Reset() { + total = kLogZero; + blank = kLogZero; + label = kLogZero; + } + float total; + float blank; + float label; +}; + +template <class CTCBeamState> +class BeamRoot; + +template <class CTCBeamState = EmptyBeamState> +struct BeamEntry { + // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method. 
+ friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry( + BeamEntry<CTCBeamState>* p, int l); + inline bool Active() const { return newp.total != kLogZero; } + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the BeamEntry<CTCBeamState>*. + if (entry.second) { + child_entry = beam_root->AddEntry(this, ind); + } + return *child_entry; + } + std::vector<int> LabelSeq(bool merge_repeated) const { + std::vector<int> labels; + int prev_label = -1; + const BeamEntry* c = this; + while (c->parent != nullptr) { // Checking c->parent to skip root leaf. + if (!merge_repeated || c->label != prev_label) { + labels.push_back(c->label); + } + prev_label = c->label; + c = c->parent; + } + std::reverse(labels.begin(), labels.end()); + return labels; + } + + BeamEntry<CTCBeamState>* parent; + int label; + // All instances of child BeamEntry are owned by *beam_root. + std::unordered_map<int, BeamEntry<CTCBeamState>*> children; + BeamProbability oldp; + BeamProbability newp; + CTCBeamState state; + + private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot<CTCBeamState>::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot<CTCBeamState>* beam_root; + + BeamEntry(const BeamEntry&) = delete; + void operator=(const BeamEntry&) = delete; +}; + +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. 
+template <class CTCBeamState = EmptyBeamState> +class BeamRoot { + public: + BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) { + auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; } + + private: + BeamEntry<CTCBeamState>* root_entry_ = nullptr; + std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_; +}; + +// BeamComparer is the default beam comparer provided in CTCBeamSearch. +template <class CTCBeamState = EmptyBeamState> +class BeamComparer { + public: + virtual ~BeamComparer() {} + virtual bool inline operator()(const BeamEntry<CTCBeamState>* a, + const BeamEntry<CTCBeamState>* b) const { + return a->newp.total > b->newp.total; + } +}; + +} // namespace ctc_beam_search + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h new file mode 100644 index 0000000000..ec60e26257 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Collection of scoring classes that can be extended and provided to the +// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a +// language model). +// +// To build a custom scorer extend and implement the pure virtual methods from +// BeamScorerInterface. The default CTC decoding behavior is implemented +// through BaseBeamScorer. + +// Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ + +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// Base implementation of a beam scorer used by default by the decoder that can +// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex +// scoring is required. Its main purpose is to provide a thin layer for +// integrating language model scoring easily. +template <typename CTCBeamState> +class BaseBeamScorer { + public: + virtual ~BaseBeamScorer() {} + // State initialization. + virtual void InitializeState(CTCBeamState* root) const {} + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + virtual void ExpandState(const CTCBeamState& from_state, int from_label, + CTCBeamState* to_state, int to_label) const {} + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. 
+ virtual void ExpandStateEnd(CTCBeamState* state) const {} + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + virtual float GetStateExpansionScore(const CTCBeamState& state, + float previous_score) const { + return previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + virtual float GetStateEndExpansionScore(const CTCBeamState& state) const { + return 0; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h new file mode 100644 index 0000000000..c658e43092 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -0,0 +1,420 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_search.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ + +#include <algorithm> +#include <cmath> +#include <limits> +#include <memory> +#include <vector> + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/contrib/lite/experimental/kernels/top_n.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +template <typename CTCBeamState = ctc_beam_search::EmptyBeamState, + typename CTCBeamComparer = + ctc_beam_search::BeamComparer<CTCBeamState>> +class CTCBeamSearchDecoder : public CTCDecoder { + // Beam Search + // + // Example (GravesTh Fig. 7.5): + // a - + // P = [ 0.3 0.7 ] t = 0 + // [ 0.4 0.6 ] t = 1 + // + // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 + // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 + // + // In this case, Best Path decoding is suboptimal. + // + // For Beam Search, we use the following main recurrence relations: + // + // Relation 1: + // ---------------------------------------------------------- Eq. 1 + // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) + // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) + // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and + // updated recursively in the beam entry. 
+ // + // Relation 2: + // ---------------------------------------------------------- Eq. 2 + // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) + // for ? in a, b, d, ..., (not including c or the blank index), + // and the recurrence starts from the beam entry for P(l=abc @ t=2). + // + // For this case, the length of the new sequence equals t+1 (t + // starts at 0). This special case can be calculated as: + // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) + // but we calculate it recursively for speed purposes. + typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry; + typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot; + typedef ctc_beam_search::BeamProbability BeamProbability; + + public: + typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer; + + // The beam search decoder is constructed specifying the beam_width (number of + // candidates to keep at each decoding timestep) and a beam scorer (used for + // custom scoring, for example enabling the use of a language model). + // The ownership of the scorer remains with the caller. The default + // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the + // standard beam search. + CTCBeamSearchDecoder(int num_classes, int beam_width, + BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1, + bool merge_repeated = false) + : CTCDecoder(num_classes, batch_size, merge_repeated), + beam_width_(beam_width), + leaves_(beam_width), + beam_scorer_(scorer) { + Reset(); + } + + ~CTCBeamSearchDecoder() override {} + + // Run the hibernating beam search algorithm on the given input. + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector<CTCDecoder::Input>& input, + std::vector<CTCDecoder::Output>* output, + CTCDecoder::ScoreOutput* scores) override; + + // Calculate the next step of the beam search and update the internal state. 
+ template <typename Vector> + void Step(const Vector& log_input_t); + + template <typename Vector> + float GetTopK(const int K, const Vector& input, + std::vector<float>* top_k_logits, + std::vector<int>* top_k_indices); + + // Retrieve the beam scorer instance used during decoding. + BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; } + + // Set label selection parameters for faster decoding. + // See comments for label_selection_size_ and label_selection_margin_. + void SetLabelSelectionParameters(int label_selection_size, + float label_selection_margin) { + label_selection_size_ = label_selection_size; + label_selection_margin_ = label_selection_margin; + } + + // Reset the beam search + void Reset(); + + // Extract the top n paths at current time step + bool TopPaths(int n, std::vector<std::vector<int>>* paths, + std::vector<float>* log_probs, bool merge_repeated) const; + + private: + int beam_width_; + + // Label selection is designed to avoid possibly very expensive scorer calls, + // by pruning the hypotheses based on the input alone. + // Label selection size controls how many items in each beam are passed + // through to the beam scorer. Only items with top N input scores are + // considered. + // Label selection margin controls the difference between minimal input score + // (versus the best scoring label) for an item to be passed to the beam + // scorer. This margin is expressed in terms of log-probability. + // Default is to do no label selection. + // For more detail: https://research.google.com/pubs/pub44823.html + int label_selection_size_ = 0; // zero means unlimited + float label_selection_margin_ = -1; // -1 means unlimited. 
+ + gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_; + std::unique_ptr<BeamRoot> beam_root_; + BaseBeamScorer<CTCBeamState>* beam_scorer_; + + CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete; + void operator=(const CTCBeamSearchDecoder&) = delete; +}; + +template <typename CTCBeamState, typename CTCBeamComparer> +bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode( + const CTCDecoder::SequenceLength& seq_len, + const std::vector<CTCDecoder::Input>& input, + std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) { + // Storage for top paths. + std::vector<std::vector<int>> beams; + std::vector<float> beam_log_probabilities; + int top_n = output->size(); + if (std::any_of(output->begin(), output->end(), + [this](const CTCDecoder::Output& output) -> bool { + return output.size() < this->batch_size_; + })) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() < top_n) { + return false; + } + + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + Reset(); + + for (int t = 0; t < seq_len_b; ++t) { + // Pass log-probabilities for this example + time. + Step(input[t].row(b)); + } // for (int t... + + // O(n * log(n)) + std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract()); + leaves_.Reset(); + for (int i = 0; i < branches->size(); ++i) { + BeamEntry* entry = (*branches)[i]; + beam_scorer_->ExpandStateEnd(&entry->state); + entry->newp.total += + beam_scorer_->GetStateEndExpansionScore(entry->state); + leaves_.push(entry); + } + + bool status = + TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_); + if (!status) { + return status; + } + + TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size()); + TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size()); + + for (int i = 0; i < top_n; ++i) { + // Copy output to the correct beam + batch + (*output)[i][b].swap(beams[i]); + (*scores)(b, i) = -beam_log_probabilities[i]; + } + } // for (int b... 
+ return true; +} + +template <typename CTCBeamState, typename CTCBeamComparer> +template <typename Vector> +float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK( + const int K, const Vector& input, std::vector<float>* top_k_logits, + std::vector<int>* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + TFLITE_DCHECK_EQ(num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < num_classes_ - 1; ++j) { + const float logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(num_classes_ - 1)); +} + +template <typename CTCBeamState, typename CTCBeamComparer> +template <typename Vector> +void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( + const Vector& raw_input) { + std::vector<float> top_k_logits; + std::vector<int> top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. + float max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); + } + const float label_selection_input_min = + (label_selection_margin_ >= 0) ? 
(max_coeff - label_selection_margin_) + : -std::numeric_limits<float>::infinity(); + + // Extract the beams sorted in decreasing new probability + TFLITE_DCHECK_EQ(num_classes_, raw_input.size()); + + std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract()); + leaves_.Reset(); + + for (BeamEntry* b : *branches) { + // P(.. @ t) becomes the new P(.. @ t-1) + b->oldp = b->newp; + } + + for (BeamEntry* b : *branches) { + if (b->parent != nullptr) { // if not the root + if (b->parent->Active()) { + // If last two sequence characters are identical: + // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) + // + Pblank(l=ac @ t=5)) + // else: + // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) + // + P(l=ab @ t=5)) + float previous = (b->label == b->parent->label) ? b->parent->oldp.blank + : b->parent->oldp.total; + b->newp.label = + LogSumExp(b->newp.label, + beam_scorer_->GetStateExpansionScore(b->state, previous)); + } + // Plabel(l=abc @ t=6) *= P(c @ 6) + b->newp.label += raw_input(b->label) - max_coeff; + } + // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) + b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) + b->newp.total = LogSumExp(b->newp.blank, b->newp.label); + + // Push the entry back to the top paths list. + // Note, this will always fill leaves back up in sorted order. + leaves_.push(b); + } + + // we need to resort branches in descending oldp order. + + // branches is in descending oldp order because it was + // originally in descending newp order and we copied newp to oldp. + + // Grow new leaves + for (BeamEntry* b : *branches) { + // A new leaf (represented by its BeamProbability) is a candidate + // iff its total probability is nonzero and either the beam list + // isn't full, or the lowest probability entry in the beam has a + // lower probability than the leaf. 
+ auto is_candidate = [this](const BeamProbability& prob) { + return (prob.total > kLogZero && + (leaves_.size() < beam_width_ || + prob.total > leaves_.peek_bottom()->newp.total)); + }; + + if (!is_candidate(b->oldp)) { + continue; + } + + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const float logit = top_k ? top_k_logits[ind] : raw_input(ind); + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (logit < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(label); + if (!c.Active()) { + // Pblank(l=abcd @ t=6) = 0 + c.newp.blank = kLogZero; + // If new child label is identical to beam label: + // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) + // Otherwise: + // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) + beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); + float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; + c.newp.label = logit - max_coeff + + beam_scorer_->GetStateExpansionScore(c.state, previous); + // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) + c.newp.total = c.newp.label; + + if (is_candidate(c.newp)) { + // Before adding the new node to the beam, check if the beam + // is already at maximum width. + if (leaves_.size() == beam_width_) { + // Bottom is no longer in the beam search. Reset + // its probability; signal it's no longer in the beam search. + BeamEntry* bottom = leaves_.peek_bottom(); + bottom->newp.Reset(); + } + leaves_.push(&c); + } else { + // Deactivate child. + c.oldp.Reset(); + c.newp.Reset(); + } + } + } + } // for (BeamEntry* b... +} + +template <typename CTCBeamState, typename CTCBeamComparer> +void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() { + leaves_.Reset(); + + // This beam root, and all of its children, will be in memory until + // the next reset. 
+ beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) + + // Add the root as the initial leaf. + leaves_.push(beam_root_->RootEntry()); + + // Call initialize state on the root object. + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); +} + +template <typename CTCBeamState, typename CTCBeamComparer> +bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths( + int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs, + bool merge_repeated) const { + TFLITE_DCHECK(paths); + TFLITE_DCHECK(log_probs); + paths->clear(); + log_probs->clear(); + if (n > beam_width_) { + return false; + } + if (n > leaves_.size()) { + return false; + } + + gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n); + + // O(beam_width_ * log(n)), space complexity is O(n) + for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { + top_branches.push(*it); + } + // O(n * log(n)) + std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract()); + + for (int i = 0; i < n; ++i) { + BeamEntry* e((*branches)[i]); + paths->push_back(e->LabelSeq(merge_repeated)); + log_probs->push_back(e->newp.total); + } + return true; +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc new file mode 100644 index 0000000000..834d1ebd66 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace experimental { +namespace ctc_beam_search_decoder { + +constexpr int kInputsTensor = 0; +constexpr int kSequenceLengthTensor = 1; + +typedef struct { + int beam_width; + int top_paths; + bool merge_repeated; +} CTCBeamSearchDecoderParams; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_CHECK(buffer != nullptr); + const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams; + option->beam_width = m["beam_width"].AsInt32(); + option->top_paths = m["top_paths"].AsInt32(); + option->merge_repeated = m["merge_repeated"].AsBool(); + + return option; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const CTCBeamSearchDecoderParams* option = + 
reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data); + const int top_paths = option->top_paths; + TF_LITE_ENSURE(context, option->beam_width >= top_paths); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + // The outputs should be top_paths * 3 + 1. + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1); + + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3); + // TensorFlow only supports float. + TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32); + const int batch_size = SizeOfDimension(inputs, 1); + + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1); + TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size); + // TensorFlow only supports int32. + TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32); + + // Resize decoded outputs. + // Do not resize indices & values cause we don't know the values yet. + for (int i = 0; i < top_paths; ++i) { + TfLiteTensor* indices = GetOutput(context, node, i); + SetTensorToDynamic(indices); + TfLiteTensor* values = GetOutput(context, node, i + top_paths); + SetTensorToDynamic(values); + TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths); + SetTensorToDynamic(output_shape); + } + + // Resize log probability outputs. 
+ TfLiteTensor* log_probability_output = + GetOutput(context, node, top_paths * 3); + TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2); + log_probability_output_shape_array->data[0] = batch_size; + log_probability_output_shape_array->data[1] = top_paths; + return context->ResizeTensor(context, log_probability_output, + log_probability_output_shape_array); +} + +TfLiteStatus Resize(TfLiteContext* context, + std::initializer_list<int32_t> output_shape, + TfLiteTensor* output) { + const int dimensions = output_shape.size(); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions); + int i = 0; + for (const int v : output_shape) { + output_shape_array->data[i++] = v; + } + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus StoreAllDecodedSequences( + TfLiteContext* context, + const std::vector<std::vector<std::vector<int>>>& sequences, + TfLiteNode* node, int top_paths) { + const int32_t batch_size = sequences.size(); + std::vector<int32_t> num_entries(top_paths, 0); + + // Calculate num_entries per path + for (const auto& batch_s : sequences) { + TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths); + for (int p = 0; p < top_paths; ++p) { + num_entries[p] += batch_s[p].size(); + } + } + + for (int p = 0; p < top_paths; ++p) { + const int32_t p_num = num_entries[p]; + + // Resize the decoded outputs. 
+ TfLiteTensor* indices = GetOutput(context, node, p); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices)); + + TfLiteTensor* values = GetOutput(context, node, p + top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values)); + + TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape)); + + int32_t max_decoded = 0; + int32_t offset = 0; + + int32_t* indices_data = GetTensorData<int32_t>(indices); + int32_t* values_data = GetTensorData<int32_t>(values); + int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape); + for (int b = 0; b < batch_size; ++b) { + auto& p_batch = sequences[b][p]; + int32_t num_decoded = p_batch.size(); + max_decoded = std::max(max_decoded, num_decoded); + + std::copy_n(p_batch.begin(), num_decoded, values_data + offset); + for (int32_t t = 0; t < num_decoded; ++t, ++offset) { + indices_data[offset * 2] = b; + indices_data[offset * 2 + 1] = t; + } + } + + decoded_shape_data[0] = batch_size; + decoded_shape_data[1] = max_decoded; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + const CTCBeamSearchDecoderParams* option = + reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data); + + const int max_time = SizeOfDimension(inputs, 0); + const int batch_size = SizeOfDimension(inputs, 1); + const int num_classes = SizeOfDimension(inputs, 2); + + const int beam_width = option->beam_width; + const int top_paths = option->top_paths; + const bool merge_repeated = option->merge_repeated; + + // Validate sequence length is less or equal than max time. 
+ for (int i = 0; i < batch_size; ++i) { + TF_LITE_ENSURE(context, + max_time >= GetTensorData<int32_t>(sequence_length)[i]); + } + + // The following logic is implemented like + // tensorflow/core/kernels/ctc_decoder_ops.cc + std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t; + + for (std::size_t t = 0; t < max_time; ++t) { + input_list_t.emplace_back( + GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size, + num_classes); + } + + ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer + beam_scorer; + ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search( + num_classes, beam_width, &beam_scorer, 1 /* batch_size */, + merge_repeated); + + // Allocate temporary memory for holding chip operation data. + float* input_chip_t_data = + static_cast<float*>(malloc(num_classes * sizeof(float))); + Eigen::array<Eigen::DenseIndex, 1> dims; + dims[0] = num_classes; + optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims); + + std::vector<std::vector<std::vector<int>>> best_paths(batch_size); + std::vector<float> log_probs; + + TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths); + float* log_probabilities_output = GetTensorData<float>(log_probabilities); + + // Assumption: the blank index is num_classes - 1 + for (int b = 0; b < batch_size; ++b) { + auto& best_paths_b = best_paths[b]; + best_paths_b.resize(top_paths); + for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) { + input_chip_t = input_list_t[t].chip(b, 0); + auto input_bi = + Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes); + beam_search.Step(input_bi); + } + TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b, + &log_probs, merge_repeated)); + beam_search.Reset(); + + // Fill in log_probabilities output. 
+ for (int bp = 0; bp < top_paths; ++bp) { + log_probabilities_output[b * top_paths + bp] = log_probs[bp]; + } + } + + free(input_chip_t_data); + return StoreAllDecodedSequences(context, best_paths, node, top_paths); +} + +} // namespace ctc_beam_search_decoder + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() { + static TfLiteRegistration r = { + ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free, + ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval}; + return &r; +} + +} // namespace experimental +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc new file mode 100644 index 0000000000..9d1e6a562f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <functional> +#include <memory> +#include <vector> + +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace experimental { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class CTCBeamSearchDecoderOpModel : public SingleOpModel { + public: + CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape, + std::initializer_list<int> sequence_length_shape, + int beam_width, int top_paths, + bool merge_repeated) { + inputs_ = AddInput(TensorType_FLOAT32); + sequence_length_ = AddInput(TensorType_INT32); + + for (int i = 0; i < top_paths * 3; ++i) { + outputs_.push_back(AddOutput(TensorType_INT32)); + } + outputs_.push_back(AddOutput(TensorType_FLOAT32)); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("beam_width", beam_width); + fbb.Int("top_paths", top_paths); + fbb.Bool("merge_repeated", merge_repeated); + }); + fbb.Finish(); + SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(), + Register_CTC_BEAM_SEARCH_DECODER); + BuildInterpreter({input_shape, sequence_length_shape}); + } + + int inputs() { return inputs_; } + + int sequence_length() { return sequence_length_; } + + std::vector<std::vector<int>> GetDecodedOutpus() { + std::vector<std::vector<int>> outputs; + for (int i = 0; i < outputs_.size() - 1; ++i) { + outputs.push_back(ExtractVector<int>(outputs_[i])); + } + return outputs; + } + + std::vector<float> GetLogProbabilitiesOutput() { + return ExtractVector<float>(outputs_[outputs_.size() - 1]); + } + + 
std::vector<std::vector<int>> GetOutputShapes() { + std::vector<std::vector<int>> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int inputs_; + int sequence_length_; + std::vector<int> outputs_; +}; + +TEST(CTCBeamSearchTest, SimpleTest) { + CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true); + m.PopulateTensor<float>(m.inputs(), + {-0.50922557, -1.35512652, -2.55445064, -1.58419356}); + m.PopulateTensor<int>(m.sequence_length(), {2}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(1, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(1)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(1, 1)); + + // Check decoded outputs. + const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.32134813}))); +} + +TEST(CTCBeamSearchTest, MultiBatchTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true); + m.PopulateTensor<float>( + m.inputs(), + {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399, + -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619, + -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325, + -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719, + -0.2065197, -0.18725838, -1.42770405, -0.86051965, -1.61642301, + -2.07275114, -0.9201845}); + m.PopulateTensor<int>(m.sequence_length(), {3, 3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. 
+ const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(4)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); + // Check log probabilities output. + EXPECT_THAT( + m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); +} + +TEST(CTCBeamSearchTest, MultiPathsTest) { + CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true); + m.PopulateTensor<float>( + m.inputs(), + {-2.206851, -0.09542714, -0.2393415, -3.81866197, -0.27241158, + -0.20371124, -0.68236623, -1.1397166, -0.17422639, -1.85224048, + -0.9406037, -0.32544678, -0.21846784, -0.38377237, -0.33498676, + -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412, + -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225, + -0.23920634, -1.2370502, -4.98751576, -3.12995717, -0.43129368}); + m.PopulateTensor<int>(m.sequence_length(), {3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 7); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[2], ElementsAre(4)); + EXPECT_THAT(output_shapes[3], ElementsAre(3)); + EXPECT_THAT(output_shapes[4], ElementsAre(2)); + EXPECT_THAT(output_shapes[5], ElementsAre(2)); + EXPECT_THAT(output_shapes[6], ElementsAre(2, 2)); + + // Check decoded outputs. 
+ const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 6); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0)); + EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0)); + EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2)); + EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear( + {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); +} + +TEST(CTCBeamSearchTest, NonEqualSequencesTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true); + m.PopulateTensor<float>( + m.inputs(), + {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756, + -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127, + -0.48920721, -1.42635476, -1.3462478, -0.02565498, -0.30179568, + -0.6491698, -0.55017719, -2.92291466, -0.92522973, -0.47592022, + -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612, + -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203, + -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522, + -0.49786342}); + m.PopulateTensor<int>(m.sequence_length(), {1, 2, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. 
+ const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); +} + +} // namespace +} // namespace experimental +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h new file mode 100644 index 0000000000..596ad4a5f7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_decoder.h +// TODO(b/111524997): Remove this file. 
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+// The CTCDecoder is an abstract interface to be implemented when providing a
+// decoding method on the timestep output of a RNN trained with CTC loss.
+//
+// The two types of decoding available are:
+//   - greedy path, through the CTCGreedyDecoder
+//   - beam search, through the CTCBeamSearchDecoder
+class CTCDecoder {
+ public:
+  typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
+  typedef Eigen::Map<const Eigen::MatrixXf> Input;
+  typedef std::vector<std::vector<int>> Output;
+  typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
+
+  // The blank label is fixed to the last class index (num_classes - 1).
+  CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
+      : num_classes_(num_classes),
+        blank_index_(num_classes - 1),
+        batch_size_(batch_size),
+        merge_repeated_(merge_repeated) {}
+
+  virtual ~CTCDecoder() {}
+
+  // Dimensionality of the input/output is expected to be:
+  //  - seq_len[b] - b = 0 to batch_size_
+  //  - input[t].rows(b) - t = 0 to timesteps; b = 0 to batch_size_
+  //  - output.size() specifies the number of beams to be returned.
+  //  - scores(b, i) - b = 0 to batch_size; i = 0 to output.size()
+  virtual bool Decode(const SequenceLength& seq_len,
+                      const std::vector<Input>& input,
+                      std::vector<Output>* output, ScoreOutput* scores) = 0;
+
+  int batch_size() { return batch_size_; }
+  int num_classes() { return num_classes_; }
+
+ protected:
+  int num_classes_;
+  int blank_index_;
+  int batch_size_;
+  bool merge_repeated_;
+};
+
+// CTCGreedyDecoder is an implementation of the simple best path decoding
+// algorithm, selecting the most likely class at each timestep.
+class CTCGreedyDecoder : public CTCDecoder {
+ public:
+  CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
+      : CTCDecoder(num_classes, batch_size, merge_repeated) {}
+
+  // Writes the best path for each batch entry into beam 0 of `output` and the
+  // accumulated path score into column 0 of `scores`. The caller must have
+  // pre-sized (*output)[0] to batch_size_ entries and `scores` to at least
+  // batch_size_ rows x 1 column; otherwise Decode() returns false.
+  bool Decode(const CTCDecoder::SequenceLength& seq_len,
+              const std::vector<CTCDecoder::Input>& input,
+              std::vector<CTCDecoder::Output>* output,
+              CTCDecoder::ScoreOutput* scores) override {
+    if (output->empty() || (*output)[0].size() < batch_size_) {
+      return false;
+    }
+    if (scores->rows() < batch_size_ || scores->cols() == 0) {
+      return false;
+    }
+    // For each batch entry, identify the transitions
+    for (int b = 0; b < batch_size_; ++b) {
+      int seq_len_b = seq_len[b];
+      // Only writing to beam 0
+      std::vector<int>& output_b = (*output)[0][b];
+
+      int prev_class_ix = -1;
+      (*scores)(b, 0) = 0;
+      for (int t = 0; t < seq_len_b; ++t) {
+        auto row = input[t].row(b);
+        int max_class_ix;
+        // Accumulate the negated per-timestep maximum activation as the path
+        // score (the negative log-probability of the greedy path when the
+        // inputs are log-probabilities).
+        (*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
+        // Emit the label unless it is the blank, or it repeats the previous
+        // label while merge_repeated_ is set.
+        if (max_class_ix != blank_index_ &&
+            !(merge_repeated_ && max_class_ix == prev_class_ix)) {
+          output_b.push_back(max_class_ix);
+        }
+        prev_class_ix = max_class_ix;
+      }
+    }
+    return true;
+  }
+};
+
+}  // namespace ctc
+}  // namespace experimental
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h
new file mode 100644
index 0000000000..0bae732533
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/util/ctc/ctc_loss_util.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
+
+#include <cmath>
+#include <limits>
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+// Representation of log(0): -infinity, so that exp(kLogZero) == 0.
+const float kLogZero = -std::numeric_limits<float>::infinity();
+
+// Add logarithmic probabilities using:
+// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
+// The two inputs are assumed to be log probabilities.
+// (GravesTh) Eq. 7.18
+inline float LogSumExp(float log_prob_1, float log_prob_2) {
+  // Always have 'b' be the smaller number to avoid the exponential from
+  // blowing up.
+  // Guard the all-kLogZero case explicitly so we never evaluate
+  // (-inf) + log1pf(expf(-inf - -inf)), which would produce NaN.
+  if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) {
+    return kLogZero;
+  } else {
+    return (log_prob_1 > log_prob_2)
+               ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1))
+               : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2));
+  }
+}
+
+}  // namespace ctc
+}  // namespace experimental
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/contrib/lite/experimental/kernels/top_n.h
new file mode 100644
index 0000000000..cd2a2f1c80
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/top_n.h
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +// Copied from tensorflow/core/lib/gtl/top_n.h +// TODO(b/111524997): Remove this file. 
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ + +#include <stddef.h> +#include <algorithm> +#include <functional> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template <class T, class Cmp = std::greater<T> > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector<T>::const_iterator; + + // 'limit' is the maximum number of top results to return. 
+ explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return std::min(elements_.size(), limit_); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. + // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). + std::vector<T> *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. 
This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector<T> *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector<T> *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector<T> *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector<T> *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. 
+ void ExtractUnsortedNondestructive(std::vector<T> *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. + void Reset(); + + private: + template <typename U> + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_: elements_ is an unsorted vector of elements + // pushed so far. + // elements_.size() > limit_: The last element of elements_ is unused; + // the other elements of elements_ are an stl heap whose size is exactly + // limit_. In this case elements_.size() is exactly one greater than + // limit_, but don't use "elements_.size() == limit_ + 1" to check for + // that because you'll get a false positive if limit_ == size_t(-1). 
+ std::vector<T> elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template <class T, class Cmp> +template <typename U> +void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11) + if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { + // Easy case: we just pushed the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. 
+ if (cmp_(v, elements_.front())) { + elements_.back() = std::forward<U>(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11) + } + } +} + +template <class T, class Cmp> +const T &TopN<T, Cmp>::peek_bottom() { + TFLITE_DCHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::Extract() { + auto out = new std::vector<T>; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() { + auto out = new std::vector<T>; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. 
+ out->pop_back(); + } + return out; +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const { + auto out = new std::vector<T>; + ExtractNondestructive(out); + return out; +} + +template <class T, class Cmp> +void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template <class T, class Cmp> +std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const { + auto elements = new std::vector<T>; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template <class T, class Cmp> +void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template <class T, class Cmp> +void TopN<T, Cmp>::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index bb1dcdda6e..ebb2c7a8eb 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -168,6 +168,18 @@ ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data, return ArrayMap<Scalar>(data, rows, cols); } +// Copied from tensorflow/core/framework/tensor_types.h +template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex> +struct TTypes { + // Rank-1 tensor (vector) of scalar type T. 
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>> + UnalignedConstMatrix; +}; + // TODO(b/62193649): this function is only needed as long // as we have the --variable_batch hack. template <typename Scalar, int N> diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 378212cb74..8b41865985 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1940,6 +1940,21 @@ void ConvertLogicalOrOperator(const Model& model, (*logical_or_op->mutable_attr())["T"].set_type(data_type); } +void ConvertCTCBeamSearchDecoderOperator( + const Model& model, const CTCBeamSearchDecoderOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op(op_name); + op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *op->add_input() = src_op.inputs[i]; + } + (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width); + (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths); + (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2194,6 +2209,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertLogicalOrOperator(model, static_cast<const LogicalOrOperator&>(src_op), "LogicalOr", tensorflow_graph); + } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { + ConvertCTCBeamSearchDecoderOperator( + model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op), + "CTCBeamSearchDecoder", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); 
} diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index f033ee013e..c8310161cb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -215,6 +215,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[0]).data_type = on_value_type; break; } + case OperatorType::kCTCBeamSearchDecoder: { + CHECK_EQ(op->inputs.size(), 2); + // All outputs (sparse tensors) are int32s (although tf uses int64s) + // except the last one (log probabilities) is float. + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size - 1; ++i) { + model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32; + } + model->GetArray(op->outputs[output_size - 1]).data_type = + ArrayDataType::kFloat; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 9a3db5c888..9a404c2606 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1854,6 +1854,34 @@ tensorflow::Status ConvertOneHotOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertCTCBeamSearchDecoderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + + auto* op = new CTCBeamSearchDecoderOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + + op->beam_width = + HasAttr(node, "beam_width") ? 
GetIntAttr(node, "beam_width") : 1; + op->top_paths = + HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1; + op->merge_repeated = HasAttr(node, "merge_repeated") + ? GetBoolAttr(node, "merge_repeated") + : true; + + // There are top_paths + 1 outputs. + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 0; i < op->top_paths; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i + 1)); + } + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -1888,6 +1916,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Const", ConvertConstOperator}, {"Conv2D", ConvertConvOperator}, {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, {"DepthToSpace", ConvertDepthToSpaceOperator}, {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, {"Div", ConvertSimpleOperator<DivOperator, 2>}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 7d0dbfcc05..cd263930f5 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -148,6 +148,7 @@ enum class OperatorType : uint8 { kLogicalAnd, kLogicalNot, kLogicalOr, + kCTCBeamSearchDecoder, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -438,6 +439,28 @@ struct ConvOperator : Operator { int dilation_height_factor = 1; }; +// CTCBeamSearchDecoder operator: +// +// Inputs: +// inputs[0]: required: the logits. +// inputs[1]: required: sequence length. +// inputs[2]: optional: beam width. +// inputs[3]: optional: top paths. +// inputs[4]: optional: merge repeated. +// +// Outputs: +// outputs[0]: decoded. +// outputs[1]: log probability. 
+// +// TensorFlow equivalent: CTCBeamSearchDecoder +struct CTCBeamSearchDecoderOperator : Operator { + CTCBeamSearchDecoderOperator() + : Operator(OperatorType::kCTCBeamSearchDecoder) {} + int beam_width; + int top_paths; + bool merge_repeated = true; +}; + // Depthwise-separable convolution operator. // // Inputs: diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 9380168f30..b1cd74794c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1070,6 +1070,27 @@ class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions, int GetVersion(const Operator& op) const override { return 1; } }; +class CTCBeamSearchDecoder + : public CustomOperator<CTCBeamSearchDecoderOperator> { + public: + using CustomOperator::CustomOperator; + + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("beam_width", op.beam_width); + fbb->Int("top_paths", op.top_paths); + fbb->Bool("merge_repeated", op.merge_repeated); + } + + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->beam_width = m["beam_width"].AsInt32(); + op->top_paths = m["top_paths"].AsInt32(); + op->merge_repeated = m["merge_repeated"].AsBool(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1301,6 +1322,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { // Custom Operators. 
ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); + ops.emplace_back(new CTCBeamSearchDecoder( + "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 384f7c118d..12fdbbf214 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -472,6 +472,20 @@ TEST_F(OperatorTest, BuiltinOneHot) { EXPECT_EQ(op.axis, output_toco_op->axis); } +TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { + CTCBeamSearchDecoderOperator op; + op.beam_width = 3; + op.top_paths = 2; + op.merge_repeated = false; + std::unique_ptr<toco::CTCBeamSearchDecoderOperator> output_toco_op = + SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER", + OperatorType::kCTCBeamSearchDecoder), + op); + EXPECT_EQ(op.beam_width, output_toco_op->beam_width); + EXPECT_EQ(op.top_paths, output_toco_op->top_paths); + EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 68155c7329..80df09eb08 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -404,6 +404,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) HANDLE_OPERATORTYPENAME_CASE(LogicalNot) HANDLE_OPERATORTYPENAME_CASE(LogicalOr) + HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 
53087821d7..973e315f09 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -1,3 +1,4 @@ +// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -145,3 +146,4 @@ class BeamComparer { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h) diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index 2579198ece..1a622babe1 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -1,3 +1,4 @@ +// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -73,3 +74,4 @@ class BaseBeamScorer { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_ +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h) diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 709c65fc96..aee647a1b3 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -418,3 +418,4 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths( } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h) diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index b8bab69053..3be36822e5 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -1,3 +1,4 @@ +// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); @@ -112,3 +113,4 @@ class CTCGreedyDecoder : public CTCDecoder { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h) diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h index 50f8f49f1c..36be9e92ef 100644 --- a/tensorflow/core/util/ctc/ctc_loss_util.h +++ b/tensorflow/core/util/ctc/ctc_loss_util.h @@ -1,3 +1,4 @@ +// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -46,3 +47,4 @@ inline float LogSumExp(float log_prob_1, float log_prob_2) { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h) |