aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99ccfaea64..9e899cf977 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1492,6 +1492,37 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op,
shape->add_dim()->set_size(2);
}
+void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* new_op = tensorflow_graph->add_node();
+ new_op->set_op("PadV2");
+ new_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *new_op->add_input() = src_op.inputs[0];
+ *new_op->add_input() = src_op.inputs[1];
+ *new_op->add_input() = src_op.inputs[2];
+
+ const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+
+ // Create the params tensor.
+ auto* params_op = tensorflow_graph->add_node();
+ params_op->set_op("Const");
+ params_op->set_name(src_op.inputs[1]);
+ (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
+ auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_INT32);
+
+ CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
+ for (int i = 0; i < src_op.left_padding.size(); ++i) {
+ tensor->add_int_val(src_op.left_padding[i]);
+ tensor->add_int_val(src_op.right_padding[i]);
+ }
+ auto* shape = tensor->mutable_tensor_shape();
+ shape->add_dim()->set_size(src_op.left_padding.size());
+ shape->add_dim()->set_size(2);
+}
+
void CreateSliceInput(const string& input_name, const std::vector<int>& values,
GraphDef* tensorflow_graph) {
auto* params_op = tensorflow_graph->add_node();
@@ -1795,6 +1826,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kPad) {
ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPadV2) {
+ ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kStridedSlice) {
ConvertStridedSliceOperator(
model, static_cast<const StridedSliceOperator&>(src_op),