diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-02 20:11:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 20:15:21 -0700 |
commit | eecadedbaae7b938e8a80dfb60c52679bcbf7196 (patch) | |
tree | 7f41ed6a9a3be126f5026a809593736d9bd10340 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | 2d3819668d8c3ab99cd09a769ffb7b76e453fd8f (diff) |
Implementation of ctc beam search decoder op in custom op fashion.
PiperOrigin-RevId: 207210333
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 9a3db5c888..9a404c2606 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1854,6 +1854,34 @@ tensorflow::Status ConvertOneHotOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertCTCBeamSearchDecoderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + + auto* op = new CTCBeamSearchDecoderOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + + op->beam_width = + HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1; + op->top_paths = + HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1; + op->merge_repeated = HasAttr(node, "merge_repeated") + ? GetBoolAttr(node, "merge_repeated") + : true; + + // There are top_paths + 1 outputs. + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 0; i < op->top_paths; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i + 1)); + } + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -1888,6 +1916,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Const", ConvertConstOperator}, {"Conv2D", ConvertConvOperator}, {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, {"DepthToSpace", ConvertDepthToSpaceOperator}, {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, {"Div", ConvertSimpleOperator<DivOperator, 2>}, |