diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/export_tensorflow.cc | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 99ccfaea64..f5157149af 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1492,6 +1492,37 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op, shape->add_dim()->set_size(2); } +void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("PadV2"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor. + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); + for (int i = 0; i < src_op.left_padding.size(); ++i) { + tensor->add_int_val(src_op.left_padding[i]); + tensor->add_int_val(src_op.right_padding[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.left_padding.size()); + shape->add_dim()->set_size(2); +} + void CreateSliceInput(const string& input_name, const std::vector<int>& values, GraphDef* tensorflow_graph) { auto* params_op = tensorflow_graph->add_node(); @@ -1643,6 +1674,19 @@ void ConvertTensorFlowMaximumOperator(const Model& model, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Select"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + *sub_op->add_input() = src_op.inputs[2]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { auto* topk_op = tensorflow_graph->add_node(); @@ -1671,6 +1715,19 @@ void ConvertRandomUniformOperator(const Model& model, (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2); } +void ConvertComparisonOperator(const Model& model, const Operator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* comparison_op = tensorflow_graph->add_node(); + comparison_op->set_op(op_name); + comparison_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *comparison_op->add_input() = src_op.inputs[0]; + *comparison_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*comparison_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1795,6 +1852,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kPad) { ConvertPadOperator(model, static_cast<const PadOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kPadV2) { + ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kStridedSlice) { ConvertStridedSliceOperator( model, static_cast<const StridedSliceOperator&>(src_op), @@ -1859,6 +1919,17 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast<const RandomUniformOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreater) { + ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLess) { + ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kSelect) { + ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } |