aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 01:53:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 01:56:38 -0700
commit7f85c95f71b01f711c366942a7cd911b0743b72c (patch)
treead92abbdf8a4c5736e323a75c3c153dadf3ce5c6 /tensorflow
parentc59bb780ebd1674ab34dd96d193c71698682ed4d (diff)
Implementation of arg_min.
PiperOrigin-RevId: 203908601
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/build_def.bzl2
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h4
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md24
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max_test.cc89
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc9
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs6
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h141
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py8
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc7
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc10
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc13
-rw-r--r--tensorflow/contrib/lite/toco/model.h12
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc21
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc7
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
21 files changed, 369 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 90be7c5b57..b735d08b4b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -195,7 +195,7 @@ def json_to_tflite(name, src, out):
def generated_test_models():
return [
"add",
- "arg_max",
+ "arg_min_max",
"avg_pool",
"batch_to_space_nd",
"concat",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index cda889bf50..31c6cdfedb 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -250,6 +250,10 @@ typedef struct {
} TfLiteArgMaxParams;
typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index a44e918230..30f1be2991 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -104,6 +104,7 @@ typedef enum {
kTfLiteBuiltinRsqrt = 76,
kTfLiteBuiltinShape = 77,
kTfLiteBuiltinPow = 78,
+ kTfLiteBuiltinArgMin = 79,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index dcd17bbeab..6ef669a040 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -790,6 +790,30 @@ Outputs {
}
```
+**ARG_MAX**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of maximum values.
+}
+```
+
+**ARG_MIN**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of minium values.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 2e2ec94fab..4f30d09030 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -177,6 +177,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
return kTfLiteOk;
}
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, false);
+}
+
TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, true);
}
@@ -189,6 +193,12 @@ TfLiteRegistration* Register_ARG_MAX() {
return &r;
}
+TfLiteRegistration* Register_ARG_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMinEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
index 31b15fe19a..90e5fdc532 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
@@ -24,16 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
-class ArgMaxOpModel : public SingleOpModel {
+class ArgBaseOpModel : public SingleOpModel {
public:
- ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
- TensorType output_type, TensorType index_output_type) {
+ ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type) {
input_ = AddInput(input_type);
axis_ = AddInput(TensorType_INT32);
output_ = AddOutput(output_type);
- SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(builder_, index_output_type).Union());
- BuildInterpreter({input_shape, {1, 1, 1, 1}});
}
int input() { return input_; }
@@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel {
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int axis_;
int output_;
};
+template <typename T>
+class ArgMaxOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
+template <typename T>
+class ArgMinOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
@@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
}
+TEST(ArgMinOpTest, GetMinArgFloat) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
+ TensorType_INT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgInt) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgMulDimensions) {
+ ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgOutput64) {
+ ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+ TensorType_INT64);
+ model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0ca08cd8f3..df88a66d19 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -82,6 +82,7 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
@@ -167,6 +168,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 0bc717bfca..dbbc505906 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -663,6 +663,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
MallocPOD<TfLiteTransposeConvParams>();
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 85279ba48e..b58466702c 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -590,6 +590,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_ARG_MIN:
case tflite::BuiltinOperator_GREATER:
case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 5e6467f676..097b88b1c1 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -160,6 +160,7 @@ enum BuiltinOperator : byte {
RSQRT = 76,
SHAPE = 77,
POW = 78,
+ ARG_MIN = 79,
}
// Options for the builtin operators.
@@ -220,6 +221,7 @@ union BuiltinOptions {
NotEqualOptions,
ShapeOptions,
PowOptions,
+ ArgMinOptions,
}
enum Padding : byte { SAME, VALID }
@@ -469,6 +471,10 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table ArgMinOptions {
+ output_type : TensorType;
+}
+
table GreaterOptions {
}
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index fe0ff9a7a5..af923bae11 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -157,6 +157,9 @@ struct TileOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct ArgMinOptions;
+struct ArgMinOptionsT;
+
struct GreaterOptions;
struct GreaterOptionsT;
@@ -343,11 +346,12 @@ enum BuiltinOperator {
BuiltinOperator_RSQRT = 76,
BuiltinOperator_SHAPE = 77,
BuiltinOperator_POW = 78,
+ BuiltinOperator_ARG_MIN = 79,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_POW
+ BuiltinOperator_MAX = BuiltinOperator_ARG_MIN
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[79] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -426,7 +430,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
BuiltinOperator_SQRT,
BuiltinOperator_RSQRT,
BuiltinOperator_SHAPE,
- BuiltinOperator_POW
+ BuiltinOperator_POW,
+ BuiltinOperator_ARG_MIN
};
return values;
}
@@ -512,6 +517,7 @@ inline const char **EnumNamesBuiltinOperator() {
"RSQRT",
"SHAPE",
"POW",
+ "ARG_MIN",
nullptr
};
return names;
@@ -580,11 +586,12 @@ enum BuiltinOptions {
BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_ShapeOptions = 55,
BuiltinOptions_PowOptions = 56,
+ BuiltinOptions_ArgMinOptions = 57,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_PowOptions
+ BuiltinOptions_MAX = BuiltinOptions_ArgMinOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[58] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -642,7 +649,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
BuiltinOptions_EqualOptions,
BuiltinOptions_NotEqualOptions,
BuiltinOptions_ShapeOptions,
- BuiltinOptions_PowOptions
+ BuiltinOptions_PowOptions,
+ BuiltinOptions_ArgMinOptions
};
return values;
}
@@ -706,6 +714,7 @@ inline const char **EnumNamesBuiltinOptions() {
"NotEqualOptions",
"ShapeOptions",
"PowOptions",
+ "ArgMinOptions",
nullptr
};
return names;
@@ -944,6 +953,10 @@ template<> struct BuiltinOptionsTraits<PowOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
};
+template<> struct BuiltinOptionsTraits<ArgMinOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1423,6 +1436,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_PowOptions ?
reinterpret_cast<const PowOptionsT *>(value) : nullptr;
}
+ ArgMinOptionsT *AsArgMinOptions() {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<ArgMinOptionsT *>(value) : nullptr;
+ }
+ const ArgMinOptionsT *AsArgMinOptions() const {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<const ArgMinOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4486,6 +4507,60 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ArgMinOptionsT : public flatbuffers::NativeTable {
+ typedef ArgMinOptions TableType;
+ TensorType output_type;
+ ArgMinOptionsT()
+ : output_type(TensorType_FLOAT32) {
+ }
+};
+
+struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArgMinOptionsT NativeTableType;
+ enum {
+ VT_OUTPUT_TYPE = 4
+ };
+ TensorType output_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE) &&
+ verifier.EndTable();
+ }
+ ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ArgMinOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ArgMinOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_type(TensorType output_type) {
+ fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &);
+ flatbuffers::Offset<ArgMinOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArgMinOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType output_type = TensorType_FLOAT32) {
+ ArgMinOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct GreaterOptionsT : public flatbuffers::NativeTable {
typedef GreaterOptions TableType;
GreaterOptionsT() {
@@ -5413,6 +5488,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const PowOptions *builtin_options_as_PowOptions() const {
return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr;
}
+ const ArgMinOptions *builtin_options_as_ArgMinOptions() const {
+ return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast<const ArgMinOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5668,6 +5746,10 @@ template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() c
return builtin_options_as_PowOptions();
}
+template<> inline const ArgMinOptions *Operator::builtin_options_as<ArgMinOptions>() const {
+ return builtin_options_as_ArgMinOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7333,6 +7415,32 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ArgMinOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = output_type(); _o->output_type = _e; };
+}
+
+inline flatbuffers::Offset<ArgMinOptions> ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateArgMinOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _output_type = _o->output_type;
+ return tflite::CreateArgMinOptions(
+ _fbb,
+ _output_type);
+}
+
inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new GreaterOptionsT();
UnPackTo(_o, _resolver);
@@ -8083,6 +8191,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8325,6 +8437,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8555,6 +8671,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const PowOptionsT *>(value);
return CreatePowOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptionsT *>(value);
+ return CreateArgMinOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8785,6 +8905,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ value = new ArgMinOptionsT(*reinterpret_cast<ArgMinOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9072,6 +9196,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<ArgMinOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 35e0328e9b..1093bd2cbe 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2224,7 +2224,7 @@ def make_topk_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_arg_max_tests(zip_path):
+def make_arg_min_max_tests(zip_path):
"""Make a set of tests to do arg_max."""
test_parameters = [{
@@ -2232,6 +2232,7 @@ def make_arg_max_tests(zip_path):
"input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"axis_is_last_dim": [True, False],
+ "is_arg_max": [True],
}]
def build_graph(parameters):
@@ -2244,7 +2245,10 @@ def make_arg_max_tests(zip_path):
axis = len(parameters["input_shape"]) - 1
else:
axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0))
- out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ if parameters["is_arg_max"]:
+ out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ else:
+ out = tf.arg_min(input_value, axis, output_type=parameters["output_type"])
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index c4e20312d8..5bc6b53416 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -97,11 +97,12 @@ std::map<string, string> kBrokenTests = {
{R"(^\/gather.*axis=1)", "76910444"},
// No support for arbitrary dimensions in ArgMax.
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ "77546240"},
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"},
};
// Allows test data to be unzipped into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6be6b25f93..a08cdbfba6 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1135,6 +1135,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
+void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
+ argmin_op->set_op("ArgMin");
+ argmin_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *argmin_op->add_input() = src_op.inputs[0];
+ *argmin_op->add_input() = src_op.inputs[1];
+ (*argmin_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+ (*argmin_op->mutable_attr())["Tidx"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[1]));
+ (*argmin_op->mutable_attr())["output_type"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+}
+
void ConvertTransposeOperator(const Model& model,
const TransposeOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1964,6 +1980,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kArgMax) {
ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kArgMin) {
+ ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kTopK_V2) {
ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
tensorflow_graph);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 00ab7cbaa9..670bcf64e7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -100,6 +100,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
break;
}
+ case OperatorType::kArgMin: {
+ // Data type of the ArgMin op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmin_op = static_cast<ArgMinOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type;
+ break;
+ }
case OperatorType::kRange: {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 8eb0423283..4f95c57451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1404,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1696,7 +1697,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
+ break;
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
break;
case OperatorType::kUnsupported:
break;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5c32a39035..bc439a2feb 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1230,10 +1230,11 @@ tensorflow::Status ConvertGatherOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertArgMaxOperator(
+template <typename Op, const char* op_name>
+tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "ArgMax");
+ CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1242,7 +1243,7 @@ tensorflow::Status ConvertArgMaxOperator(
: DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
- auto* op = new ArgMaxOperator;
+ auto* op = new Op;
op->output_data_type = ConvertDataType(output_type);
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1833,12 +1834,16 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+constexpr char kArgMax[] = "ArgMax";
+constexpr char kArgMin[] = "ArgMin";
+
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"ArgMax", ConvertArgMaxOperator},
+ {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
+ {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3a1d243f87..8660464fdb 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -140,6 +140,7 @@ enum class OperatorType : uint8 {
kEqual,
kNotEqual,
kPow,
+ kArgMin,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1528,6 +1529,17 @@ struct ArgMaxOperator : Operator {
ArrayDataType output_data_type = ArrayDataType::kInt64;
};
+// ArgMin operator. It returns the index of the minimum value along axis.
+//
+// Inputs:
+// inputs[0]: required: the input tensor
+//
+// TensorFlow equivalent: ArgMin
+struct ArgMinOperator : Operator {
+ ArgMinOperator() : Operator(OperatorType::kArgMin) {}
+ ArrayDataType output_data_type = ArrayDataType::kInt64;
+};
+
// ResizeBilinear operator. It resizes input images with bilinear interpolation.
// It does not support align_corners at the moment.
//
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7e55ae92bd..19e2fb0fa1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -885,6 +885,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
+ ::tflite::BuiltinOptions_ArgMinOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateArgMinOptions(
+ *builder, DataType::Serialize(op.output_data_type));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->output_data_type = DataType::Deserialize(options.output_type());
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TransposeConv
: public BuiltinOperator<TransposeConvOperator,
::tflite::TransposeConvOptions,
@@ -1175,6 +1194,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
ops.emplace_back(
+ new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin));
+ ops.emplace_back(
new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS,
OperatorType::kExpandDims));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 8b6808d3c7..ff2d35b1f5 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -416,6 +416,13 @@ TEST_F(OperatorTest, BuiltinArgMax) {
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
+TEST_F(OperatorTest, BuiltinArgMin) {
+ ArgMinOperator op;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ARG_MIN", OperatorType::kArgMin), op);
+ EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
+}
+
TEST_F(OperatorTest, BuiltinTransposeConv) {
TransposeConvOperator op;
op.stride_width = 123;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 8abdb014e4..4ec74e351f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -387,6 +387,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Mean)
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
+ HANDLE_OPERATORTYPENAME_CASE(ArgMin)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)