diff options
author | 2018-07-10 01:53:40 -0700 | |
---|---|---|
committer | 2018-07-10 01:56:38 -0700 | |
commit | 7f85c95f71b01f711c366942a7cd911b0743b72c (patch) | |
tree | ad92abbdf8a4c5736e323a75c3c153dadf3ce5c6 /tensorflow | |
parent | c59bb780ebd1674ab34dd96d193c71698682ed4d (diff) |
Implementation of arg_min.
PiperOrigin-RevId: 203908601
Diffstat (limited to 'tensorflow')
21 files changed, 369 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 90be7c5b57..b735d08b4b 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -195,7 +195,7 @@ def json_to_tflite(name, src, out): def generated_test_models(): return [ "add", - "arg_max", + "arg_min_max", "avg_pool", "batch_to_space_nd", "concat", diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index cda889bf50..31c6cdfedb 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -250,6 +250,10 @@ typedef struct { } TfLiteArgMaxParams; typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { TfLitePadding padding; int stride_width; int stride_height; diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index a44e918230..30f1be2991 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -104,6 +104,7 @@ typedef enum { kTfLiteBuiltinRsqrt = 76, kTfLiteBuiltinShape = 77, kTfLiteBuiltinPow = 78, + kTfLiteBuiltinArgMin = 79, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index dcd17bbeab..6ef669a040 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -790,6 +790,30 @@ Outputs { } ``` +**ARG_MAX** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of maximum values. +} +``` + +**ARG_MIN** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of minimum values. 
+} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc index 2e2ec94fab..4f30d09030 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc @@ -177,6 +177,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { return kTfLiteOk; } +TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, false); +} + TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) { return Eval(context, node, true); } @@ -189,6 +193,12 @@ TfLiteRegistration* Register_ARG_MAX() { return &r; } +TfLiteRegistration* Register_ARG_MIN() { + static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, + arg_min_max::ArgMinEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc index 31b15fe19a..90e5fdc532 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc @@ -24,16 +24,13 @@ namespace { using ::testing::ElementsAreArray; template <typename T> -class ArgMaxOpModel : public SingleOpModel { +class ArgBaseOpModel : public SingleOpModel { public: - ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type, - TensorType output_type, TensorType index_output_type) { + ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) { input_ = AddInput(input_type); axis_ = AddInput(TensorType_INT32); output_ = AddOutput(output_type); - SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, - CreateArgMaxOptions(builder_, index_output_type).Union()); - BuildInterpreter({input_shape, {1, 1, 
1, 1}}); } int input() { return input_; } @@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel { std::vector<T> GetOutput() { return ExtractVector<T>(output_); } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; int axis_; int output_; }; +template <typename T> +class ArgMaxOpModel : public ArgBaseOpModel<T> { + public: + ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel<T>(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel<T>::SetBuiltinOp( + BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, + CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type) + .Union()); + ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + +template <typename T> +class ArgMinOpModel : public ArgBaseOpModel<T> { + public: + ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel<T>(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel<T>::SetBuiltinOp( + BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions, + CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type) + .Union()); + ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + TEST(ArgMaxOpTest, GetMaxArgFloat) { ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32, TensorType_INT32, TensorType_INT32); @@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); } +TEST(ArgMinOpTest, GetMinArgFloat) { + ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32, + TensorType_INT32, TensorType_INT32); + model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3}); + // Currently only support the last dimension. 
+ model.PopulateTensor<int>(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgInt) { + ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor<int>(model.input(), {1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor<int>(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgMulDimensions) { + ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor<int>(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + +TEST(ArgMinOpTest, GetMinArgOutput64) { + ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64, + TensorType_INT64); + model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. 
+ model.PopulateTensor<int>(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 0ca08cd8f3..df88a66d19 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -82,6 +82,7 @@ TfLiteRegistration* Register_PRELU(); TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_ARG_MIN(); TfLiteRegistration* Register_GREATER(); TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); @@ -167,6 +168,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN()); AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 0bc717bfca..dbbc505906 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -663,6 +663,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_ARG_MIN: { + auto* params = MallocPOD<TfLiteArgMinParams>(); + if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { + ConvertTensorType(schema_params->output_type(), ¶ms->output_type, + error_reporter); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_TRANSPOSE_CONV: { 
TfLiteTransposeConvParams* params = MallocPOD<TfLiteTransposeConvParams>(); diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 85279ba48e..b58466702c 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -590,6 +590,7 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_MAXIMUM: case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: + case tflite::BuiltinOperator_ARG_MIN: case tflite::BuiltinOperator_GREATER: case tflite::BuiltinOperator_GREATER_EQUAL: case tflite::BuiltinOperator_LESS: diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 5e6467f676..097b88b1c1 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -160,6 +160,7 @@ enum BuiltinOperator : byte { RSQRT = 76, SHAPE = 77, POW = 78, + ARG_MIN = 79, } // Options for the builtin operators. @@ -220,6 +221,7 @@ union BuiltinOptions { NotEqualOptions, ShapeOptions, PowOptions, + ArgMinOptions, } enum Padding : byte { SAME, VALID } @@ -469,6 +471,10 @@ table ArgMaxOptions { output_type : TensorType; } +table ArgMinOptions { + output_type : TensorType; +} + table GreaterOptions { } diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index fe0ff9a7a5..af923bae11 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -157,6 +157,9 @@ struct TileOptionsT; struct ArgMaxOptions; struct ArgMaxOptionsT; +struct ArgMinOptions; +struct ArgMinOptionsT; + struct GreaterOptions; struct GreaterOptionsT; @@ -343,11 +346,12 @@ enum BuiltinOperator { BuiltinOperator_RSQRT = 76, BuiltinOperator_SHAPE = 77, BuiltinOperator_POW = 78, + BuiltinOperator_ARG_MIN = 79, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_POW + 
BuiltinOperator_MAX = BuiltinOperator_ARG_MIN }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[79] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -426,7 +430,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { BuiltinOperator_SQRT, BuiltinOperator_RSQRT, BuiltinOperator_SHAPE, - BuiltinOperator_POW + BuiltinOperator_POW, + BuiltinOperator_ARG_MIN }; return values; } @@ -512,6 +517,7 @@ inline const char **EnumNamesBuiltinOperator() { "RSQRT", "SHAPE", "POW", + "ARG_MIN", nullptr }; return names; @@ -580,11 +586,12 @@ enum BuiltinOptions { BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_ShapeOptions = 55, BuiltinOptions_PowOptions = 56, + BuiltinOptions_ArgMinOptions = 57, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_PowOptions + BuiltinOptions_MAX = BuiltinOptions_ArgMinOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[58] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -642,7 +649,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { BuiltinOptions_EqualOptions, BuiltinOptions_NotEqualOptions, BuiltinOptions_ShapeOptions, - BuiltinOptions_PowOptions + BuiltinOptions_PowOptions, + BuiltinOptions_ArgMinOptions }; return values; } @@ -706,6 +714,7 @@ inline const char **EnumNamesBuiltinOptions() { "NotEqualOptions", "ShapeOptions", "PowOptions", + "ArgMinOptions", nullptr }; return names; @@ -944,6 +953,10 @@ template<> struct BuiltinOptionsTraits<PowOptions> { static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; }; +template<> struct BuiltinOptionsTraits<ArgMinOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1423,6 +1436,14 @@ struct BuiltinOptionsUnion { 
return type == BuiltinOptions_PowOptions ? reinterpret_cast<const PowOptionsT *>(value) : nullptr; } + ArgMinOptionsT *AsArgMinOptions() { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast<ArgMinOptionsT *>(value) : nullptr; + } + const ArgMinOptionsT *AsArgMinOptions() const { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast<const ArgMinOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4486,6 +4507,60 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions( flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ArgMinOptionsT : public flatbuffers::NativeTable { + typedef ArgMinOptions TableType; + TensorType output_type; + ArgMinOptionsT() + : output_type(TensorType_FLOAT32) { + } +}; + +struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArgMinOptionsT NativeTableType; + enum { + VT_OUTPUT_TYPE = 4 + }; + TensorType output_type() const { + return static_cast<TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE) && + verifier.EndTable(); + } + ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<ArgMinOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ArgMinOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_output_type(TensorType output_type) { + fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, 
static_cast<int8_t>(output_type), 0); + } + explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &); + flatbuffers::Offset<ArgMinOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ArgMinOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType output_type = TensorType_FLOAT32) { + ArgMinOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct GreaterOptionsT : public flatbuffers::NativeTable { typedef GreaterOptions TableType; GreaterOptionsT() { @@ -5413,6 +5488,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const PowOptions *builtin_options_as_PowOptions() const { return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr; } + const ArgMinOptions *builtin_options_as_ArgMinOptions() const { + return builtin_options_type() == BuiltinOptions_ArgMinOptions ? 
static_cast<const ArgMinOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -5668,6 +5746,10 @@ template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() c return builtin_options_as_PowOptions(); } +template<> inline const ArgMinOptions *Operator::builtin_options_as<ArgMinOptions>() const { + return builtin_options_as_ArgMinOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7333,6 +7415,32 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB _output_type); } +inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ArgMinOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = output_type(); _o->output_type = _e; }; +} + +inline flatbuffers::Offset<ArgMinOptions> ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateArgMinOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _output_type = _o->output_type; + return tflite::CreateArgMinOptions( + _fbb, + _output_type); +} + inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new GreaterOptionsT(); UnPackTo(_o, 
_resolver); @@ -8083,6 +8191,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const PowOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast<const ArgMinOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -8325,6 +8437,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const PowOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast<const ArgMinOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8555,6 +8671,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const PowOptionsT *>(value); return CreatePowOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast<const ArgMinOptionsT *>(value); + return CreateArgMinOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8785,6 +8905,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value)); break; } + case BuiltinOptions_ArgMinOptions: { + value = new ArgMinOptionsT(*reinterpret_cast<ArgMinOptionsT *>(u.value)); + break; + } default: break; } @@ -9072,6 +9196,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast<ArgMinOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 35e0328e9b..1093bd2cbe 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2224,7 +2224,7 @@ 
def make_topk_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_arg_max_tests(zip_path): +def make_arg_min_max_tests(zip_path): """Make a set of tests to do arg_max.""" test_parameters = [{ @@ -2232,6 +2232,7 @@ def make_arg_max_tests(zip_path): "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], "output_type": [tf.int32, tf.int64], "axis_is_last_dim": [True, False], + "is_arg_max": [True], }] def build_graph(parameters): @@ -2244,7 +2245,10 @@ def make_arg_max_tests(zip_path): axis = len(parameters["input_shape"]) - 1 else: axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0)) - out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + if parameters["is_arg_max"]: + out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + else: + out = tf.arg_min(input_value, axis, output_type=parameters["output_type"]) return [input_value], [out] def build_inputs(parameters, sess, inputs, outputs): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index c4e20312d8..5bc6b53416 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -97,11 +97,12 @@ std::map<string, string> kBrokenTests = { {R"(^\/gather.*axis=1)", "76910444"}, // No support for arbitrary dimensions in ArgMax. 
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + "77546240"}, + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, }; // Allows test data to be unzipped into a temporary directory and makes diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6be6b25f93..a08cdbfba6 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1135,6 +1135,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node(); + argmin_op->set_op("ArgMin"); + argmin_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *argmin_op->add_input() = src_op.inputs[0]; + *argmin_op->add_input() = src_op.inputs[1]; + (*argmin_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*argmin_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*argmin_op->mutable_attr())["output_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { @@ -1964,6 +1980,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op), tensorflow_graph); + } else if 
(src_op.type == OperatorType::kArgMin) { + ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kTopK_V2) { ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op), tensorflow_graph); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 00ab7cbaa9..670bcf64e7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -100,6 +100,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; break; } + case OperatorType::kArgMin: { + // Data type of the ArgMin op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmin_op = static_cast<ArgMinOperator*>(op); + model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type; + break; + } case OperatorType::kRange: { auto* range_op = static_cast<RangeOperator*>(op); // Output type of the Range op can be set via an attribute diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 8eb0423283..4f95c57451 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1404,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { } } -void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { +template <typename Op> +void ProcessArgMinMaxOperator(Model* model, Op* op) { CHECK_EQ(op->inputs.size(), 2); const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. 
@@ -1696,7 +1697,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast<StridedSliceOperator*>(op)); break; case OperatorType::kArgMax: - ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op)); + ProcessArgMinMaxOperator<ArgMaxOperator>( + model, static_cast<ArgMaxOperator*>(op)); + break; + case OperatorType::kArgMin: + ProcessArgMinMaxOperator<ArgMinOperator>( + model, static_cast<ArgMinOperator*>(op)); break; case OperatorType::kUnsupported: break; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 5c32a39035..bc439a2feb 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1230,10 +1230,11 @@ tensorflow::Status ConvertGatherOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertArgMaxOperator( +template <typename Op, const char* op_name> +tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "ArgMax"); + CHECK_EQ(node.op(), op_name); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? 
GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1242,7 +1243,7 @@ tensorflow::Status ConvertArgMaxOperator( : DT_INT64; CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32); CHECK(output_type == DT_INT64 || output_type == DT_INT32); - auto* op = new ArgMaxOperator; + auto* op = new Op; op->output_data_type = ConvertDataType(output_type); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1833,12 +1834,16 @@ using ConverterType = tensorflow::Status (*)( Model* model); using ConverterMapType = std::unordered_map<std::string, ConverterType>; +constexpr char kArgMax[] = "ArgMax"; +constexpr char kArgMin[] = "ArgMin"; + ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map<std::string, ConverterType>({ {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, - {"ArgMax", ConvertArgMaxOperator}, + {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, + {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 3a1d243f87..8660464fdb 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -140,6 +140,7 @@ enum class OperatorType : uint8 { kEqual, kNotEqual, kPow, + kArgMin, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1528,6 +1529,17 @@ struct ArgMaxOperator : Operator { ArrayDataType output_data_type = ArrayDataType::kInt64; }; +// ArgMin operator. It returns the index of the minimum value along axis. 
+// +// Inputs: +// inputs[0]: required: the input tensor +// +// TensorFlow equivalent: ArgMin +struct ArgMinOperator : Operator { + ArgMinOperator() : Operator(OperatorType::kArgMin) {} + ArrayDataType output_data_type = ArrayDataType::kInt64; +}; + // ResizeBilinear operator. It resizes input images with bilinear interpolation. // It does not support align_corners at the moment. // diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7e55ae92bd..19e2fb0fa1 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -885,6 +885,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions, int GetVersion(const Operator& op) const override { return 1; } }; +class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions, + ::tflite::BuiltinOptions_ArgMinOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateArgMinOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.output_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TransposeConv : public BuiltinOperator<TransposeConvOperator, ::tflite::TransposeConvOptions, @@ -1175,6 +1194,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back( + new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin)); + ops.emplace_back( new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, 
OperatorType::kExpandDims)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 8b6808d3c7..ff2d35b1f5 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -416,6 +416,13 @@ TEST_F(OperatorTest, BuiltinArgMax) { EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); } +TEST_F(OperatorTest, BuiltinArgMin) { + ArgMinOperator op; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ARG_MIN", OperatorType::kArgMin), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinTransposeConv) { TransposeConvOperator op; op.stride_width = 123; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 8abdb014e4..4ec74e351f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -387,6 +387,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Mean) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(ArgMin) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) |