aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99ccfaea64..f5157149af 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1492,6 +1492,37 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op,
shape->add_dim()->set_size(2);
}
+void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("PadV2");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
+ for (int i = 0; i < src_op.left_padding.size(); ++i) {
+ tensor->add_int_val(src_op.left_padding[i]);
+ tensor->add_int_val(src_op.right_padding[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.left_padding.size());
+ shape->add_dim()->set_size(2);
+}
+
void CreateSliceInput(const string& input_name, const std::vector<int>& values,
GraphDef* tensorflow_graph) {
auto* params_op = tensorflow_graph->add_node();
@@ -1643,6 +1674,19 @@ void ConvertTensorFlowMaximumOperator(const Model& model,
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* sub_op = tensorflow_graph->add_node();
+ sub_op->set_op("Select");
+ sub_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *sub_op->add_input() = src_op.inputs[0];
+ *sub_op->add_input() = src_op.inputs[1];
+ *sub_op->add_input() = src_op.inputs[2];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*sub_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
auto* topk_op = tensorflow_graph->add_node();
@@ -1671,6 +1715,19 @@ void ConvertRandomUniformOperator(const Model& model,
(*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
}
+void ConvertComparisonOperator(const Model& model, const Operator& src_op,
+ const char* op_name,
+ GraphDef* tensorflow_graph) {
+ auto* comparison_op = tensorflow_graph->add_node();
+ comparison_op->set_op(op_name);
+ comparison_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *comparison_op->add_input() = src_op.inputs[0];
+ *comparison_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*comparison_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1795,6 +1852,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kPad) {
ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPadV2) {
+ ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kStridedSlice) {
ConvertStridedSliceOperator(
model, static_cast<const StridedSliceOperator&>(src_op),
@@ -1859,6 +1919,17 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreater) {
+ ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
+ ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLess) {
+ ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLessEqual) {
+ ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSelect) {
+ ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}