diff options
-rw-r--r-- | tensorflow/contrib/lite/toco/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 24 |
2 files changed, 19 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index c88079717d..7243e584e9 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -11,6 +11,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", + "tf_copts", ) tf_proto_library_cc( @@ -305,7 +306,7 @@ cc_library( "tensorflow_util.h", "toco_tooling.h", ], - copts = select({ + copts = tf_copts() + select({ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], "//conditions:default": [], }), diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index d8d331f3d4..b7fffbce22 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1223,11 +1223,10 @@ tensorflow::Status ConvertGatherOperator( return tensorflow::Status::OK(); } -template <typename Op, const char* op_name> +template <typename Op> tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), op_name); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1245,6 +1244,20 @@ tensorflow::Status ConvertArgMinMaxOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMax"); + return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model); +} + +tensorflow::Status ConvertArgMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMin"); + return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model); +} + tensorflow::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1899,17 +1912,14 @@ using ConverterType = tensorflow::Status (*)( Model* model); using ConverterMapType = std::unordered_map<std::string, ConverterType>; -constexpr char kArgMax[] = "ArgMax"; -constexpr char kArgMin[] = "ArgMin"; - ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map<std::string, ConverterType>({ {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, {"Any", ConvertAnyOperator}, - {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, - {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, + {"ArgMax", ConvertArgMaxOperator}, + {"ArgMin", ConvertArgMinOperator}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, |