aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-02 20:11:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 20:15:21 -0700
commiteecadedbaae7b938e8a80dfb60c52679bcbf7196 (patch)
tree7f41ed6a9a3be126f5026a809593736d9bd10340 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent2d3819668d8c3ab99cd09a769ffb7b76e453fd8f (diff)
Implementation of ctc beam search decoder op in custom op fashion.
PiperOrigin-RevId: 207210333
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9a3db5c888..9a404c2606 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1854,6 +1854,34 @@ tensorflow::Status ConvertOneHotOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+
+ auto* op = new CTCBeamSearchDecoderOperator;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+
+ op->beam_width =
+ HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
+ op->top_paths =
+ HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
+ op->merge_repeated = HasAttr(node, "merge_repeated")
+ ? GetBoolAttr(node, "merge_repeated")
+ : true;
+
+ // There are top_paths + 1 outputs.
+ op->outputs.push_back(node.name()); // Implicit :0.
+ for (int i = 0; i < op->top_paths; ++i) {
+ op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
+ }
+ model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -1888,6 +1916,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Const", ConvertConstOperator},
{"Conv2D", ConvertConvOperator},
{"Conv2DBackpropInput", ConvertTransposeConvOperator},
+ {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
{"DepthToSpace", ConvertDepthToSpaceOperator},
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
{"Div", ConvertSimpleOperator<DivOperator, 2>},