diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/export_tensorflow.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 378212cb74..8b41865985 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1940,6 +1940,21 @@ void ConvertLogicalOrOperator(const Model& model, (*logical_or_op->mutable_attr())["T"].set_type(data_type); } +void ConvertCTCBeamSearchDecoderOperator( + const Model& model, const CTCBeamSearchDecoderOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op(op_name); + op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *op->add_input() = src_op.inputs[i]; + } + (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width); + (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths); + (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2194,6 +2209,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertLogicalOrOperator(model, static_cast<const LogicalOrOperator&>(src_op), "LogicalOr", tensorflow_graph); + } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { + ConvertCTCBeamSearchDecoderOperator( + model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op), + "CTCBeamSearchDecoder", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } |