aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/BUILD3
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc24
2 files changed, 19 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index c88079717d..7243e584e9 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -11,6 +11,7 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
+ "tf_copts",
)
tf_proto_library_cc(
@@ -305,7 +306,7 @@ cc_library(
"tensorflow_util.h",
"toco_tooling.h",
],
- copts = select({
+ copts = tf_copts() + select({
"//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"],
"//conditions:default": [],
}),
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index d8d331f3d4..b7fffbce22 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1223,11 +1223,10 @@ tensorflow::Status ConvertGatherOperator(
return tensorflow::Status::OK();
}
-template <typename Op, const char* op_name>
+template <typename Op>
tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1245,6 +1244,20 @@ tensorflow::Status ConvertArgMinMaxOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertArgMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "ArgMax");
+ return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
+}
+
+tensorflow::Status ConvertArgMinOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "ArgMin");
+ return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
+}
+
tensorflow::Status ConvertResizeBilinearOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1899,17 +1912,14 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
-constexpr char kArgMax[] = "ArgMax";
-constexpr char kArgMin[] = "ArgMin";
-
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
{"Any", ConvertAnyOperator},
- {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
- {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
+ {"ArgMax", ConvertArgMaxOperator},
+ {"ArgMin", ConvertArgMinOperator},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},