author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-04-13 00:12:41 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-04-13 00:14:59 -0700
commit     1b0c277405171a34c7f41e17cd76459dc36f7f82 (patch)
tree       ffed1ee7e0314ca98a8aed222b9d0214bf6ae21a /tensorflow
parent     73cc1d5b6f95ff56207e4c42b62d383c2427fb75 (diff)
Implementation of Less
PiperOrigin-RevId: 192728635
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h                               |   1
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md               |  13
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD                               |  19
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons.cc                      | 119
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons_test.cc                 |  98
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  |  45
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc                         |   2
-rw-r--r--  tensorflow/contrib/lite/model.cc                                    |   3
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc                           |   1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs                           |   5
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h                   | 124
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD                               |   1
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py                |  33
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc      |   1
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                     |   2
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc                |   2
16 files changed, 463 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 1ceefafc56..859bc7ab70 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -82,6 +82,7 @@ typedef enum {
kTfLiteBuiltinMaximum = 55,
kTfLiteBuiltinArgMax = 56,
kTfLiteBuiltinMinimum = 57,
+ kTfLiteBuiltinLess = 58,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 61ea5231e3..203924f03d 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -302,6 +302,19 @@ Options {
}
```
+**LESS**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is less
+ than the corresponding element of the second tensor.
+}
+```
+
**LOCAL_RESPONSE_NORMALIZATION**
```
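The compatibility-table entry above only states the contract: LESS takes two tensors and returns a bool tensor that is true wherever an element of the first input is less than the corresponding element of the second. A minimal standalone sketch of that contract in plain C++, mirroring the float test case added later in this commit (illustrative only; the actual kernel is added in kernels/comparisons.cc below):

```
// Plain-C++ sketch of the documented LESS behaviour for same-shaped inputs
// (illustrative only; the real kernel lives in kernels/comparisons.cc below).
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<bool> ElementwiseLess(const std::vector<float>& a,
                                  const std::vector<float>& b) {
  std::vector<bool> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] < b[i];
  return out;
}

int main() {
  // Same values as the LessFloat unit test added in this commit.
  const auto out =
      ElementwiseLess({0.1f, 0.9f, 0.7f, 0.3f}, {0.1f, 0.2f, 0.6f, 0.5f});
  for (bool v : out) std::cout << v << ' ';  // prints: 0 0 0 1
  std::cout << '\n';
}
```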
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 914893cd90..800e2a9558 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -136,6 +136,7 @@ cc_library(
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
"cast.cc",
+ "comparisons.cc",
"concatenation.cc",
"conv.cc",
"depthwise_conv.cc",
@@ -818,6 +819,24 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "comparisons_test",
+ size = "small",
+ srcs = [
+ "comparisons_test.cc",
+ ],
+ tags = [
+ "tflite_not_portable_ios_arm64",
+ "tflite_not_portable_ios_x86_64",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
new file mode 100644
index 0000000000..87c413cb98
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace comparisons {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Don't support string and bool.
+ TF_LITE_ENSURE(context,
+ input1->type != kTfLiteString || input1->type != kTfLiteBool);
+ // Currently only support tensors have the same type.
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+ output->type = kTfLiteBool;
+
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+
+#define TF_LITE_LESS(type, opname) \
+ reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output));
+
+ // TODO(renjieliu): Support quantized data.
+ if (requires_broadcast) {
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_LESS(float, BroadcastLess);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_LESS(int32_t, BroadcastLess);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_LESS(int64_t, BroadcastLess);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ } else {
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_LESS(float, Less);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_LESS(int32_t, Less);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_LESS(int64_t, Less);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_LESS
+ return kTfLiteOk;
+}
+
+} // namespace comparisons
+
+TfLiteRegistration* Register_LESS() {
+ static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare,
+ comparisons::LessEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
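LessPrepare pins the node to two inputs and one output, requires both inputs to share a type, forces the output type to bool, and sizes the output either by copying input1's shape or, when the shapes differ, by asking CalculateShapeForBroadcast for the broadcast shape; LessEval then dispatches on the element type through the TF_LITE_LESS macro. (As written, the guard `input1->type != kTfLiteString || input1->type != kTfLiteBool` is always true, so string and bool inputs are not actually rejected; the check would need a logical AND to take effect.) A simplified sketch of the broadcast-shape rule that Prepare relies on, using a hypothetical helper name rather than the real TFLite function:

```
// Simplified sketch of the broadcast-shape rule used in Prepare
// (illustrative only; the kernel calls the TFLite helper
// CalculateShapeForBroadcast, not this function).
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int> BroadcastShape(std::vector<int> a, std::vector<int> b) {
  // Right-align the shapes, padding the shorter one with 1s.
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  std::vector<int> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      throw std::invalid_argument("shapes are not broadcast-compatible");
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

int main() {
  // {1, 1, 2, 4} vs {1, 1, 1, 4} -> {1, 1, 2, 4}, the case exercised by the
  // LessBroadcastTwoD test below.
  for (int d : BroadcastShape({1, 1, 2, 4}, {1, 1, 1, 4})) std::cout << d << ' ';
  std::cout << '\n';
}
```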
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
new file mode 100644
index 0000000000..da2d7f8589
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LessOpModel : public SingleOpModel {
+ public:
+ LessOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape, TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
+ CreateLessOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ArgMaxOpTest, LessFloat) {
+ LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessInt) {
+ LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessBroadcast) {
+ LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessBroadcastTwoD) {
+ LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
+ true, false, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
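The new tests cover float32 and int32 inputs, including two broadcasting cases, but not the int64 path that LessEval also dispatches on (the suite name ArgMaxOpTest also looks like a leftover from the arg_max tests this file was modelled on). A hypothetical int64 case in the same style, shown here only as a sketch and not part of the commit, would reuse the LessOpModel harness above:

```
// Hypothetical int64 coverage in the same style (sketch, not in this commit).
TEST(ComparisonsTest, LessInt64) {
  LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT64);
  model.PopulateTensor<int64_t>(model.input1(), {-1, 9, 7, 3});
  model.PopulateTensor<int64_t>(model.input2(), {1, 2, 6, 5});
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
```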
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index c6019390f2..6a89dbc803 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3378,6 +3378,51 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+template <typename T>
+inline void Less(int64_t num_elements, const T* input1, const T* input2,
+ bool* output) {
+ for (int64_t i = 0; i < num_elements; ++i) {
+ output[i] = input1[i] < input2[i];
+ }
+}
+
+template <typename T>
+inline void Less(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ const int64_t batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int64_t width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int64_t depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ Less(batches * height * width * depth, input1_data, input2_data, output_data);
+}
+
+template <typename T1, typename T2>
+inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims,
+ T2* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastLess");
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] <
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
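The reference Less flattens same-shaped inputs and compares element by element, while BroadcastLess walks the output volume and maps each output coordinate back into the two inputs via NdArrayDescsForElementwiseBroadcast and SubscriptToIndex, so size-1 dimensions are effectively stretched. A stripped-down, self-contained illustration of that indexing idea for a 2-D case (hypothetical helper, not the TFLite Dims<4> machinery):

```
// Stripped-down illustration of the BroadcastLess indexing idea in 2-D:
// a dimension of size 1 in an input maps every output index back to 0
// (hypothetical helper, not the TFLite NdArrayDesc code).
#include <iostream>
#include <vector>

void BroadcastLess2D(const std::vector<int>& in1, int rows1, int cols1,
                     const std::vector<int>& in2, int rows2, int cols2,
                     std::vector<bool>* out, int out_rows, int out_cols) {
  for (int r = 0; r < out_rows; ++r) {
    for (int c = 0; c < out_cols; ++c) {
      // Size-1 dimensions are "stretched": their index is clamped to 0.
      const int r1 = rows1 == 1 ? 0 : r, c1 = cols1 == 1 ? 0 : c;
      const int r2 = rows2 == 1 ? 0 : r, c2 = cols2 == 1 ? 0 : c;
      (*out)[r * out_cols + c] = in1[r1 * cols1 + c1] < in2[r2 * cols2 + c2];
    }
  }
}

int main() {
  // Mirrors the LessBroadcastTwoD test: a {2, 4} input against a {1, 4} input.
  const std::vector<int> in1 = {-1, 9, 7, 3, 2, 4, 6, 8};
  const std::vector<int> in2 = {7, 1, 2, 4};
  std::vector<bool> out(8);
  BroadcastLess2D(in1, 2, 4, in2, 1, 4, &out, 2, 4);
  for (bool v : out) std::cout << v << ' ';  // prints: 1 0 0 1 1 0 0 0
  std::cout << '\n';
}
```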
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 67ba8d0f39..b07e7b6ff3 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -79,6 +79,7 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_LESS();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -139,6 +140,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_LESS, Register_LESS());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
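With AddBuiltin(BuiltinOperator_LESS, Register_LESS()) in place, any interpreter constructed with BuiltinOpResolver can resolve LESS nodes in a model. A minimal lookup sketch, assuming the FindOp(BuiltinOperator) overload exposed by OpResolver in this source tree:

```
// Minimal sketch: the resolver now knows about the LESS builtin.
// Assumes the FindOp(BuiltinOperator) lookup available on OpResolver here.
#include <cassert>

#include "tensorflow/contrib/lite/kernels/register.h"

int main() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  const TfLiteRegistration* less_kernel =
      resolver.FindOp(tflite::BuiltinOperator_LESS);
  assert(less_kernel != nullptr);  // The registration returned by Register_LESS().
  return 0;
}
```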
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 0b65884025..54b1460173 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -665,6 +665,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_LESS: {
+ break;
+ }
case BuiltinOperator_DELEGATE: {
// TODO(ycling): Revisit when supporting saving delegated models.
error_reporter->Report("DELEGATE op shouldn't exist in model.");
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 08fb820767..eab82ea8ef 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -353,6 +353,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_LESS:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index fa825500fd..93980b15f0 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -135,6 +135,7 @@ enum BuiltinOperator : byte {
MAXIMUM = 55,
ARG_MAX = 56,
MINIMUM = 57,
+ LESS = 58,
}
// Options for the builtin operators.
@@ -179,6 +180,7 @@ union BuiltinOptions {
DequantizeOptions,
MaximumMinimumOptions,
ArgMaxOptions,
+ LessOptions,
}
enum Padding : byte { SAME, VALID }
@@ -399,6 +401,9 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table LessOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 909c4ccb3b..b2a799d0ef 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -151,6 +151,9 @@ struct MaximumMinimumOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct LessOptions;
+struct LessOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -267,11 +270,12 @@ enum BuiltinOperator {
BuiltinOperator_MAXIMUM = 55,
BuiltinOperator_ARG_MAX = 56,
BuiltinOperator_MINIMUM = 57,
+ BuiltinOperator_LESS = 58,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_MINIMUM
+ BuiltinOperator_MAX = BuiltinOperator_LESS
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[56] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[57] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -328,7 +332,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[56] {
BuiltinOperator_PRELU,
BuiltinOperator_MAXIMUM,
BuiltinOperator_ARG_MAX,
- BuiltinOperator_MINIMUM
+ BuiltinOperator_MINIMUM,
+ BuiltinOperator_LESS
};
return values;
}
@@ -393,6 +398,7 @@ inline const char **EnumNamesBuiltinOperator() {
"MAXIMUM",
"ARG_MAX",
"MINIMUM",
+ "LESS",
nullptr
};
return names;
@@ -445,11 +451,12 @@ enum BuiltinOptions {
BuiltinOptions_DequantizeOptions = 38,
BuiltinOptions_MaximumMinimumOptions = 39,
BuiltinOptions_ArgMaxOptions = 40,
+ BuiltinOptions_LessOptions = 41,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_ArgMaxOptions
+ BuiltinOptions_MAX = BuiltinOptions_LessOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[41] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[42] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -491,7 +498,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[41] {
BuiltinOptions_CastOptions,
BuiltinOptions_DequantizeOptions,
BuiltinOptions_MaximumMinimumOptions,
- BuiltinOptions_ArgMaxOptions
+ BuiltinOptions_ArgMaxOptions,
+ BuiltinOptions_LessOptions
};
return values;
}
@@ -539,6 +547,7 @@ inline const char **EnumNamesBuiltinOptions() {
"DequantizeOptions",
"MaximumMinimumOptions",
"ArgMaxOptions",
+ "LessOptions",
nullptr
};
return names;
@@ -713,6 +722,10 @@ template<> struct BuiltinOptionsTraits<ArgMaxOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions;
};
+template<> struct BuiltinOptionsTraits<LessOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_LessOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1064,6 +1077,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ArgMaxOptions ?
reinterpret_cast<const ArgMaxOptionsT *>(value) : nullptr;
}
+ LessOptionsT *AsLessOptions() {
+ return type == BuiltinOptions_LessOptions ?
+ reinterpret_cast<LessOptionsT *>(value) : nullptr;
+ }
+ const LessOptionsT *AsLessOptions() const {
+ return type == BuiltinOptions_LessOptions ?
+ reinterpret_cast<const LessOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -3927,6 +3948,46 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct LessOptionsT : public flatbuffers::NativeTable {
+ typedef LessOptions TableType;
+ LessOptionsT() {
+ }
+};
+
+struct LessOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LessOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ LessOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LessOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<LessOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LessOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LessOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LessOptionsBuilder &operator=(const LessOptionsBuilder &);
+ flatbuffers::Offset<LessOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LessOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LessOptions> CreateLessOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LessOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -4164,6 +4225,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ArgMaxOptions *builtin_options_as_ArgMaxOptions() const {
return builtin_options_type() == BuiltinOptions_ArgMaxOptions ? static_cast<const ArgMaxOptions *>(builtin_options()) : nullptr;
}
+ const LessOptions *builtin_options_as_LessOptions() const {
+ return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast<const LessOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -4350,6 +4414,10 @@ template<> inline const ArgMaxOptions *Operator::builtin_options_as<ArgMaxOption
return builtin_options_as_ArgMaxOptions();
}
+template<> inline const LessOptions *Operator::builtin_options_as<LessOptions>() const {
+ return builtin_options_as_LessOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5933,6 +6001,29 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LessOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LessOptions::UnPackTo(LessOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<LessOptions> LessOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLessOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateLessOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -6273,6 +6364,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_LessOptions: {
+ auto ptr = reinterpret_cast<const LessOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -6451,6 +6546,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_LessOptions: {
+ auto ptr = reinterpret_cast<const LessOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -6617,6 +6716,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ArgMaxOptionsT *>(value);
return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_LessOptions: {
+ auto ptr = reinterpret_cast<const LessOptionsT *>(value);
+ return CreateLessOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -6783,6 +6886,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ArgMaxOptionsT(*reinterpret_cast<ArgMaxOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_LessOptions: {
+ value = new LessOptionsT(*reinterpret_cast<LessOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -6990,6 +7097,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_LessOptions: {
+ auto ptr = reinterpret_cast<LessOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
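On the generated-code side, LessOptions is an empty table, so the API reduces to CreateLessOptions(builder) plus the usual union plumbing added above (enum value, traits, verifier, and BuiltinOptionsUnion cases); this is exactly what the new unit test leans on via SetBuiltinOp. A small serialization sketch using the generated helpers:

```
// Building the (empty) LessOptions table with the generated flatbuffers API.
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

int main() {
  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::Offset<tflite::LessOptions> options =
      tflite::CreateLessOptions(builder);
  // On an Operator table these travel as a union value:
  //   builtin_options_type = BuiltinOptions_LessOptions
  //   builtin_options      = options.Union()
  flatbuffers::Offset<void> as_union = options.Union();
  (void)as_union;  // Attached to an Operator by the exporter or test harness.
  return 0;
}
```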
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 2c226e76d4..bd888a415b 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -34,6 +34,7 @@ gen_zipped_test_files(
"global_batch_norm.zip",
"l2_pool.zip",
"l2norm.zip",
+ "less.zip",
"local_response_norm.zip",
"log_softmax.zip",
"max_pool.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 4b4ccc0c37..53b41d2358 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1997,6 +1997,39 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_less_tests(zip_path):
+ """Make a set of tests to do less."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the less op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.less(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 84ae1d58fe..9da8bd7a28 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -280,6 +280,7 @@ INSTANTIATE_TESTS(squeeze)
INSTANTIATE_TESTS(strided_slice)
INSTANTIATE_TESTS(sub)
INSTANTIATE_TESTS(transpose)
+INSTANTIATE_TESTS(less)
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0e057fd252..f41a312b47 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -895,6 +895,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"MAXIMUM", OperatorType::kTensorFlowMaximum));
ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
"MINIMUM", OperatorType::kTensorFlowMinimum));
+ ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
+ "LESS", OperatorType::kTensorFlowLess));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index a947630e28..36ed741541 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -113,6 +113,8 @@ TEST_F(OperatorTest, SimpleOperators) {
"MAXIMUM", OperatorType::kTensorFlowMaximum);
CheckSimpleOperator<TensorFlowMinimumOperator>(
"MINIMUM", OperatorType::kTensorFlowMinimum);
+ CheckSimpleOperator<TensorFlowLessOperator>("LESS",
+ OperatorType::kTensorFlowLess);
}
TEST_F(OperatorTest, BuiltinAdd) {