aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 378212cb74..8b41865985 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1940,6 +1940,21 @@ void ConvertLogicalOrOperator(const Model& model,
(*logical_or_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertCTCBeamSearchDecoderOperator(
+ const Model& model, const CTCBeamSearchDecoderOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op(op_name);
+ op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *op->add_input() = src_op.inputs[i];
+ }
+ (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
+ (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
+ (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2194,6 +2209,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertLogicalOrOperator(model,
static_cast<const LogicalOrOperator&>(src_op),
"LogicalOr", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
+ ConvertCTCBeamSearchDecoderOperator(
+ model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
+ "CTCBeamSearchDecoder", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}