diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-10 01:53:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 01:56:38 -0700 |
commit | 7f85c95f71b01f711c366942a7cd911b0743b72c (patch) | |
tree | ad92abbdf8a4c5736e323a75c3c153dadf3ce5c6 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | c59bb780ebd1674ab34dd96d193c71698682ed4d (diff) |
Implementation of arg_min.
PiperOrigin-RevId: 203908601
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 5c32a39035..bc439a2feb 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1230,10 +1230,11 @@ tensorflow::Status ConvertGatherOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertArgMaxOperator( +template <typename Op, const char* op_name> +tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "ArgMax"); + CHECK_EQ(node.op(), op_name); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1242,7 +1243,7 @@ tensorflow::Status ConvertArgMaxOperator( : DT_INT64; CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32); CHECK(output_type == DT_INT64 || output_type == DT_INT32); - auto* op = new ArgMaxOperator; + auto* op = new Op; op->output_data_type = ConvertDataType(output_type); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1833,12 +1834,16 @@ using ConverterType = tensorflow::Status (*)( Model* model); using ConverterMapType = std::unordered_map<std::string, ConverterType>; +constexpr char kArgMax[] = "ArgMax"; +constexpr char kArgMin[] = "ArgMin"; + ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map<std::string, ConverterType>({ {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, - {"ArgMax", ConvertArgMaxOperator}, + {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, + {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, |