-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h  3
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md  39
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons.cc  160
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons_test.cc  207
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/common.h  14
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h  11
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  191
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc  6
-rw-r--r--  tensorflow/contrib/lite/model.cc  5
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc  3
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs  15
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h  348
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD  3
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py  102
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc  3
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc  21
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  6
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc  9
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc  7
20 files changed, 1051 insertions, 103 deletions
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index d66b72843a..778933f569 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -86,6 +86,9 @@ typedef enum {
kTfLiteBuiltinLess = 58,
kTfLiteBuiltinNeg = 59,
kTfLiteBuiltinPadv2 = 60,
+ kTfLiteBuiltinGreater = 61,
+ kTfLiteBuiltinGreaterEqual = 62,
+ kTfLiteBuiltinLessEqual = 63,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 0051ee84ec..fc57b8f28b 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -281,6 +281,32 @@ Options {
}
```
+**GREATER**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ greater than the corresponding element of the second tensor.
+}
+```
+
+**GREATER_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ greater than or equal to the corresponding element of the second tensor.
+}
+```
+
**L2_NORMALIZATION**
```
@@ -325,6 +351,19 @@ Outputs {
}
```
+**LESS_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is less
+ than or equal to the corresponding element of the second tensor.
+}
+```
+
**LOCAL_RESPONSE_NORMALIZATION**
```
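The compatibility-table entries above define the new comparison ops purely by their elementwise boolean semantics, including the usual broadcasting of a lower-rank operand. A minimal standalone sketch of that behavior (plain C++, not part of the patch), reproducing the GREATER-with-broadcast case:

```
#include <array>
#include <cstdio>

// Elementwise GREATER with a scalar right-hand side broadcast against a
// 4-element left-hand side, mirroring the documented bool-tensor output.
int main() {
  const std::array<int, 4> lhs = {-1, 9, 7, 3};
  const int rhs = 7;  // shape {1} broadcast against shape {4}
  std::array<bool, 4> out{};
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    out[i] = lhs[i] > rhs;
  }
  for (bool b : out) {
    std::printf("%s ", b ? "true" : "false");  // false true false false
  }
  std::printf("\n");
  return 0;
}
```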
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 87c413cb98..2885ce032b 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
-TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ requires_broadcast \
+ ? reference_ops::Broadcast##opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output)) \
+ : reference_ops::opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output));
+
+TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
-#define TF_LITE_LESS(type, opname) \
- reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Less, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
- if (requires_broadcast) {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, BroadcastLess);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, BroadcastLess);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, BroadcastLess);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
- } else {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, Less);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, Less);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, Less);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
}
-#undef TF_LITE_LESS
return kTfLiteOk;
}
} // namespace comparisons
+TfLiteRegistration* Register_GREATER() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_GREATER_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LESS() {
- static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare,
- comparisons::LessEval};
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LESS_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::LessEqualEval};
return &r;
}
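The kernel above folds the broadcast/non-broadcast choice into the TF_LITE_COMPARISON macro: a single ternary picks reference_ops::Broadcast##opname or reference_ops::opname, and each Eval function only switches on the input type. A standalone sketch of that dispatch pattern follows; the SKETCH_COMPARISON macro and the two stand-in functions are illustrative substitutes for the real tensor/Dims<4> plumbing, not the kernel code itself.

```
#include <cstdio>
#include <vector>

// Stand-ins for reference_ops::Greater / reference_ops::BroadcastGreater.
template <typename T>
void Greater(const std::vector<T>& a, const std::vector<T>& b,
             std::vector<bool>* out) {
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] > b[i];
}
template <typename T>
void BroadcastGreater(const std::vector<T>& a, const std::vector<T>& b,
                      std::vector<bool>* out) {
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] > b[i % b.size()];
}

// Like TF_LITE_COMPARISON, this expands to one expression: a ternary that
// selects the broadcast or the same-shape implementation.
#define SKETCH_COMPARISON(type, opname, requires_broadcast) \
  (requires_broadcast) ? Broadcast##opname<type>(in1, in2, &out) \
                       : opname<type>(in1, in2, &out);

int main() {
  std::vector<int> in1 = {-1, 9, 7, 3};
  std::vector<int> in2 = {7};  // different shape, so broadcasting is required
  std::vector<bool> out(in1.size());
  const bool requires_broadcast = in1.size() != in2.size();
  SKETCH_COMPARISON(int, Greater, requires_broadcast);
  for (bool b : out) std::printf("%d ", b ? 1 : 0);  // 0 1 0 0
  std::printf("\n");
  return 0;
}
```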
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index da2d7f8589..835d238d36 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -23,6 +23,139 @@ namespace {
using ::testing::ElementsAreArray;
+class GreaterOpModel : public SingleOpModel {
+ public:
+ GreaterOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterFloat) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
+class GreaterEqualOpModel : public SingleOpModel {
+ public:
+ GreaterEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
+ BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterEqualFloat) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualInt) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcast) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
+ GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
class LessOpModel : public SingleOpModel {
public:
LessOpModel(std::initializer_list<int> input1_shape,
@@ -47,7 +180,7 @@ class LessOpModel : public SingleOpModel {
int output_;
};
-TEST(ArgMaxOpTest, LessFloat) {
+TEST(ComparisonsTest, LessFloat) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
@@ -57,7 +190,7 @@ TEST(ArgMaxOpTest, LessFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessInt) {
+TEST(ComparisonsTest, LessInt) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
@@ -67,7 +200,7 @@ TEST(ArgMaxOpTest, LessInt) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcast) {
+TEST(ComparisonsTest, LessBroadcast) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
@@ -77,7 +210,7 @@ TEST(ArgMaxOpTest, LessBroadcast) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcastTwoD) {
+TEST(ComparisonsTest, LessBroadcastTwoD) {
LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
@@ -88,6 +221,72 @@ TEST(ArgMaxOpTest, LessBroadcastTwoD) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
}
+class LessEqualOpModel : public SingleOpModel {
+ public:
+ LessEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, LessEqualFloat) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualInt) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcast) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
+ LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
+ true, false, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index df29172f83..7ec4782f96 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -157,6 +157,7 @@ cc_library(
":quantization_util",
":strided_slice_logic",
":types",
+ ":reference_base",
":round",
"//third_party/eigen3",
"@gemmlowp",
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index 18601df22c..ede95dfee0 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -113,6 +113,20 @@ inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier,
right_shift);
}
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
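CountLeadingZeros, moved into common.h above so the quantized kernels can share it, counts how many left shifts an unsigned value needs before its most significant bit is set. A self-contained copy of the function with a couple of spot checks (the function body is identical to the one added above; the checks assume a nonzero input, since zero never terminates the loop):

```
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
}

int main() {
  assert(CountLeadingZeros<uint32_t>(1u) == 31);          // only bit 0 set
  assert(CountLeadingZeros<uint32_t>(0x80000000u) == 0);  // MSB already set
  assert(CountLeadingZeros<uint8_t>(uint8_t{1}) == 7);
  return 0;
}
```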
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index e2a1a6996d..c506c5636c 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -38,6 +39,16 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
+// Unoptimized reference ops:
+using reference_ops::BroadcastGreater;
+using reference_ops::BroadcastGreaterEqual;
+using reference_ops::BroadcastLess;
+using reference_ops::BroadcastLessEqual;
+using reference_ops::Greater;
+using reference_ops::GreaterEqual;
+using reference_ops::Less;
+using reference_ops::LessEqual;
+
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
// construct the suitable Eigen type for the constness of the
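The using-declarations added above make the new comparisons callable as optimized_ops::Greater and friends without duplicating code: until dedicated optimized kernels exist, the names simply resolve to the reference implementations. A minimal sketch of that pattern with toy namespaces (not the TFLite headers):

```
#include <cstdio>

namespace reference_ops {
inline bool Greater(int a, int b) { return a > b; }  // simple but correct
}  // namespace reference_ops

namespace optimized_ops {
// Re-export the reference version; callers of optimized_ops::Greater pick up
// a faster implementation automatically if one is added in this namespace.
using reference_ops::Greater;
}  // namespace optimized_ops

int main() {
  std::printf("%d\n", optimized_ops::Greater(3, 2) ? 1 : 0);  // prints 1
  return 0;
}
```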
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 05e6ca8e7e..93dba1cc8e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -35,35 +35,6 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
- int32 x, int32 quantized_multiplier, int right_shift) {
- using gemmlowp::RoundingDivideByPOT;
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return RoundingDivideByPOT(
- SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
-}
-
-inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
- int32 x, int32 quantized_multiplier, int left_shift) {
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
- quantized_multiplier);
-}
-
-template <typename T>
-int CountLeadingZeros(T integer_input) {
- static_assert(std::is_unsigned<T>::value,
- "Only unsigned integer types handled.");
- const T one_in_leading_positive = static_cast<T>(1)
- << (std::numeric_limits<T>::digits - 1);
- int leading_zeros = 0;
- while (integer_input < one_in_leading_positive) {
- integer_input <<= 1;
- ++leading_zeros;
- }
- return leading_zeros;
-}
-
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
// BROADCASTING.
//
@@ -3614,17 +3585,29 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Less(int64_t num_elements, const T* input1, const T* input2,
- bool* output) {
- for (int64_t i = 0; i < num_elements; ++i) {
- output[i] = input1[i] < input2[i];
- }
+inline bool GreaterFn(T lhs, T rhs) {
+ return lhs > rhs;
+}
+template <typename T>
+inline bool GreaterEqualFn(T lhs, T rhs) {
+ return lhs >= rhs;
+}
+template <typename T>
+inline bool LessFn(T lhs, T rhs) {
+ return lhs < rhs;
+}
+template <typename T>
+inline bool LessEqualFn(T lhs, T rhs) {
+ return lhs <= rhs;
}
template <typename T>
-inline void Less(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
+using ComparisonFn = bool (*)(T, T);
+
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
const int64_t batches =
MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
const int64_t height =
@@ -3633,31 +3616,149 @@ inline void Less(const T* input1_data, const Dims<4>& input1_dims,
MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
const int64_t depth =
MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- Less(batches * height * width * depth, input1_data, input2_data, output_data);
+ for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+ output_data[i] = F(input1_data[i], input2_data[i]);
+ }
}
-template <typename T1, typename T2>
-inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims,
- T2* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastLess");
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ const int64_t batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int64_t height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int64_t width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int64_t depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+ const int32 input1_val = input1_offset + input1_data[i];
+ const int32 input2_val = input2_offset + input2_data[i];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
+ output_data[i] = F(scaled_input1_val, scaled_input2_val);
+ }
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
+ const T* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ const int32 input1_val =
+ input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ const int32 input2_val =
+ input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ const int32 shifted_input1_val = input1_val * (1 << left_shift);
+ const int32 shifted_input2_val = input2_val * (1 << left_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input1_val, input1_multiplier, input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[Offset(output_dims, c, x, y, b)] =
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] <
- input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ F(scaled_input1_val, scaled_input2_val);
}
}
}
}
}
+#define TFLITE_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ }
+TFLITE_COMPARISON_OP(Greater);
+TFLITE_COMPARISON_OP(GreaterEqual);
+TFLITE_COMPARISON_OP(Less);
+TFLITE_COMPARISON_OP(LessEqual);
+#undef TFLITE_COMPARISON_OP
+
} // namespace reference_ops
} // namespace tflite
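The refactor above replaces the hand-written Less kernels with a single Comparison template parameterized by a compile-time comparison function (ComparisonFn<T>), and TFLITE_COMPARISON_OP then stamps out Greater, GreaterEqual, Less, and LessEqual, plus their Broadcast and quantized overloads, from the corresponding *Fn helpers. A standalone sketch of the template-over-function-pointer idea, with the Dims<4> descriptors and quantization parameters omitted:

```
#include <cstdio>

template <typename T>
inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; }
template <typename T>
inline bool LessEqualFn(T lhs, T rhs) { return lhs <= rhs; }

template <typename T>
using ComparisonFn = bool (*)(T, T);

// One loop serves every comparison; F is fixed at compile time, so the call
// can be inlined just like the original hand-written kernels.
template <typename T, ComparisonFn<T> F>
void Comparison(const T* input1, const T* input2, bool* output, int count) {
  for (int i = 0; i < count; ++i) output[i] = F(input1[i], input2[i]);
}

// TFLITE_COMPARISON_OP generates named wrappers analogous to these.
template <typename T>
void Greater(const T* a, const T* b, bool* out, int n) {
  Comparison<T, GreaterFn<T>>(a, b, out, n);
}
template <typename T>
void LessEqual(const T* a, const T* b, bool* out, int n) {
  Comparison<T, LessEqualFn<T>>(a, b, out, n);
}

int main() {
  const int a[4] = {-1, 9, 7, 3};
  const int b[4] = {1, 2, 7, 5};
  bool out[4];
  Greater(a, b, out, 4);    // 0 1 0 0
  for (bool v : out) std::printf("%d ", v ? 1 : 0);
  std::printf("\n");
  LessEqual(a, b, out, 4);  // 1 0 1 1
  for (bool v : out) std::printf("%d ", v ? 1 : 0);
  std::printf("\n");
  return 0;
}
```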
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index a6ea874546..40855891a6 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -80,7 +80,10 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_GREATER();
+TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
+TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
TfLiteRegistration* Register_NEG();
@@ -144,7 +147,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
+ AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
+ AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL());
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 6253570fa2..21c2181377 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -672,7 +672,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_LESS: {
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL: {
break;
}
case BuiltinOperator_DELEGATE: {
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index b4c46917bf..e903af87b7 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -372,7 +372,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_GREATER:
+ case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
+ case tflite::BuiltinOperator_LESS_EQUAL:
case tflite::BuiltinOperator_NEG:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 84ff3b16bd..9409e76233 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -138,6 +138,9 @@ enum BuiltinOperator : byte {
LESS = 58,
NEG = 59,
PADV2 = 60,
+ GREATER = 61,
+ GREATER_EQUAL = 62,
+ LESS_EQUAL = 63,
}
// Options for the builtin operators.
@@ -183,7 +186,10 @@ union BuiltinOptions {
DequantizeOptions,
MaximumMinimumOptions,
ArgMaxOptions,
+ GreaterOptions,
+ GreaterEqualOptions,
LessOptions,
+ LessEqualOptions,
NegOptions,
}
@@ -410,9 +416,18 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table GreaterOptions {
+}
+
+table GreaterEqualOptions {
+}
+
table LessOptions {
}
+table LessEqualOptions {
+}
+
table NegOptions {
}
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8855e4ad58..ae3b33063e 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -154,9 +154,18 @@ struct MaximumMinimumOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct GreaterOptions;
+struct GreaterOptionsT;
+
+struct GreaterEqualOptions;
+struct GreaterEqualOptionsT;
+
struct LessOptions;
struct LessOptionsT;
+struct LessEqualOptions;
+struct LessEqualOptionsT;
+
struct NegOptions;
struct NegOptionsT;
@@ -280,11 +289,14 @@ enum BuiltinOperator {
BuiltinOperator_LESS = 58,
BuiltinOperator_NEG = 59,
BuiltinOperator_PADV2 = 60,
+ BuiltinOperator_GREATER = 61,
+ BuiltinOperator_GREATER_EQUAL = 62,
+ BuiltinOperator_LESS_EQUAL = 63,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_PADV2
+ BuiltinOperator_MAX = BuiltinOperator_LESS_EQUAL
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[60] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[63] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -345,7 +357,10 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[60] {
BuiltinOperator_MINIMUM,
BuiltinOperator_LESS,
BuiltinOperator_NEG,
- BuiltinOperator_PADV2
+ BuiltinOperator_PADV2,
+ BuiltinOperator_GREATER,
+ BuiltinOperator_GREATER_EQUAL,
+ BuiltinOperator_LESS_EQUAL
};
return values;
}
@@ -413,6 +428,9 @@ inline const char **EnumNamesBuiltinOperator() {
"LESS",
"NEG",
"PADV2",
+ "GREATER",
+ "GREATER_EQUAL",
+ "LESS_EQUAL",
nullptr
};
return names;
@@ -466,13 +484,16 @@ enum BuiltinOptions {
BuiltinOptions_DequantizeOptions = 39,
BuiltinOptions_MaximumMinimumOptions = 40,
BuiltinOptions_ArgMaxOptions = 41,
- BuiltinOptions_LessOptions = 42,
- BuiltinOptions_NegOptions = 43,
+ BuiltinOptions_GreaterOptions = 42,
+ BuiltinOptions_GreaterEqualOptions = 43,
+ BuiltinOptions_LessOptions = 44,
+ BuiltinOptions_LessEqualOptions = 45,
+ BuiltinOptions_NegOptions = 46,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_NegOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[44] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[47] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -516,7 +537,10 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[44] {
BuiltinOptions_DequantizeOptions,
BuiltinOptions_MaximumMinimumOptions,
BuiltinOptions_ArgMaxOptions,
+ BuiltinOptions_GreaterOptions,
+ BuiltinOptions_GreaterEqualOptions,
BuiltinOptions_LessOptions,
+ BuiltinOptions_LessEqualOptions,
BuiltinOptions_NegOptions
};
return values;
@@ -566,7 +590,10 @@ inline const char **EnumNamesBuiltinOptions() {
"DequantizeOptions",
"MaximumMinimumOptions",
"ArgMaxOptions",
+ "GreaterOptions",
+ "GreaterEqualOptions",
"LessOptions",
+ "LessEqualOptions",
"NegOptions",
nullptr
};
@@ -746,10 +773,22 @@ template<> struct BuiltinOptionsTraits<ArgMaxOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions;
};
+template<> struct BuiltinOptionsTraits<GreaterOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions;
+};
+
+template<> struct BuiltinOptionsTraits<GreaterEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions;
+};
+
template<> struct BuiltinOptionsTraits<LessOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LessOptions;
};
+template<> struct BuiltinOptionsTraits<LessEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions;
+};
+
template<> struct BuiltinOptionsTraits<NegOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_NegOptions;
};
@@ -1113,6 +1152,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ArgMaxOptions ?
reinterpret_cast<const ArgMaxOptionsT *>(value) : nullptr;
}
+ GreaterOptionsT *AsGreaterOptions() {
+ return type == BuiltinOptions_GreaterOptions ?
+ reinterpret_cast<GreaterOptionsT *>(value) : nullptr;
+ }
+ const GreaterOptionsT *AsGreaterOptions() const {
+ return type == BuiltinOptions_GreaterOptions ?
+ reinterpret_cast<const GreaterOptionsT *>(value) : nullptr;
+ }
+ GreaterEqualOptionsT *AsGreaterEqualOptions() {
+ return type == BuiltinOptions_GreaterEqualOptions ?
+ reinterpret_cast<GreaterEqualOptionsT *>(value) : nullptr;
+ }
+ const GreaterEqualOptionsT *AsGreaterEqualOptions() const {
+ return type == BuiltinOptions_GreaterEqualOptions ?
+ reinterpret_cast<const GreaterEqualOptionsT *>(value) : nullptr;
+ }
LessOptionsT *AsLessOptions() {
return type == BuiltinOptions_LessOptions ?
reinterpret_cast<LessOptionsT *>(value) : nullptr;
@@ -1121,6 +1176,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LessOptions ?
reinterpret_cast<const LessOptionsT *>(value) : nullptr;
}
+ LessEqualOptionsT *AsLessEqualOptions() {
+ return type == BuiltinOptions_LessEqualOptions ?
+ reinterpret_cast<LessEqualOptionsT *>(value) : nullptr;
+ }
+ const LessEqualOptionsT *AsLessEqualOptions() const {
+ return type == BuiltinOptions_LessEqualOptions ?
+ reinterpret_cast<const LessEqualOptionsT *>(value) : nullptr;
+ }
NegOptionsT *AsNegOptions() {
return type == BuiltinOptions_NegOptions ?
reinterpret_cast<NegOptionsT *>(value) : nullptr;
@@ -4056,6 +4119,86 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct GreaterOptionsT : public flatbuffers::NativeTable {
+ typedef GreaterOptions TableType;
+ GreaterOptionsT() {
+ }
+};
+
+struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GreaterOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ GreaterOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GreaterOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GreaterOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit GreaterOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GreaterOptionsBuilder &operator=(const GreaterOptionsBuilder &);
+ flatbuffers::Offset<GreaterOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GreaterOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ GreaterOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct GreaterEqualOptionsT : public flatbuffers::NativeTable {
+ typedef GreaterEqualOptions TableType;
+ GreaterEqualOptionsT() {
+ }
+};
+
+struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GreaterEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ GreaterEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GreaterEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GreaterEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit GreaterEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GreaterEqualOptionsBuilder &operator=(const GreaterEqualOptionsBuilder &);
+ flatbuffers::Offset<GreaterEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GreaterEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ GreaterEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct LessOptionsT : public flatbuffers::NativeTable {
typedef LessOptions TableType;
LessOptionsT() {
@@ -4096,6 +4239,46 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(
flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct LessEqualOptionsT : public flatbuffers::NativeTable {
+ typedef LessEqualOptions TableType;
+ LessEqualOptionsT() {
+ }
+};
+
+struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LessEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ LessEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<LessEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LessEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LessEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LessEqualOptionsBuilder &operator=(const LessEqualOptionsBuilder &);
+ flatbuffers::Offset<LessEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LessEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LessEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct NegOptionsT : public flatbuffers::NativeTable {
typedef NegOptions TableType;
NegOptionsT() {
@@ -4376,9 +4559,18 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ArgMaxOptions *builtin_options_as_ArgMaxOptions() const {
return builtin_options_type() == BuiltinOptions_ArgMaxOptions ? static_cast<const ArgMaxOptions *>(builtin_options()) : nullptr;
}
+ const GreaterOptions *builtin_options_as_GreaterOptions() const {
+ return builtin_options_type() == BuiltinOptions_GreaterOptions ? static_cast<const GreaterOptions *>(builtin_options()) : nullptr;
+ }
+ const GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_GreaterEqualOptions ? static_cast<const GreaterEqualOptions *>(builtin_options()) : nullptr;
+ }
const LessOptions *builtin_options_as_LessOptions() const {
return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast<const LessOptions *>(builtin_options()) : nullptr;
}
+ const LessEqualOptions *builtin_options_as_LessEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_LessEqualOptions ? static_cast<const LessEqualOptions *>(builtin_options()) : nullptr;
+ }
const NegOptions *builtin_options_as_NegOptions() const {
return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast<const NegOptions *>(builtin_options()) : nullptr;
}
@@ -4572,10 +4764,22 @@ template<> inline const ArgMaxOptions *Operator::builtin_options_as<ArgMaxOption
return builtin_options_as_ArgMaxOptions();
}
+template<> inline const GreaterOptions *Operator::builtin_options_as<GreaterOptions>() const {
+ return builtin_options_as_GreaterOptions();
+}
+
+template<> inline const GreaterEqualOptions *Operator::builtin_options_as<GreaterEqualOptions>() const {
+ return builtin_options_as_GreaterEqualOptions();
+}
+
template<> inline const LessOptions *Operator::builtin_options_as<LessOptions>() const {
return builtin_options_as_LessOptions();
}
+template<> inline const LessEqualOptions *Operator::builtin_options_as<LessEqualOptions>() const {
+ return builtin_options_as_LessEqualOptions();
+}
+
template<> inline const NegOptions *Operator::builtin_options_as<NegOptions>() const {
return builtin_options_as_NegOptions();
}
@@ -6206,6 +6410,52 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GreaterOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GreaterOptions::UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<GreaterOptions> GreaterOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGreaterOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateGreaterOptions(
+ _fbb);
+}
+
+inline GreaterEqualOptionsT *GreaterEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GreaterEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GreaterEqualOptions::UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<GreaterEqualOptions> GreaterEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGreaterEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateGreaterEqualOptions(
+ _fbb);
+}
+
inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new LessOptionsT();
UnPackTo(_o, _resolver);
@@ -6229,6 +6479,29 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBuffe
_fbb);
}
+inline LessEqualOptionsT *LessEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LessEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LessEqualOptions::UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<LessEqualOptions> LessEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLessEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateLessEqualOptions(
+ _fbb);
+}
+
inline NegOptionsT *NegOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new NegOptionsT();
UnPackTo(_o, _resolver);
@@ -6599,10 +6872,22 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
case BuiltinOptions_LessOptions: {
auto ptr = reinterpret_cast<const LessOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
case BuiltinOptions_NegOptions: {
auto ptr = reinterpret_cast<const NegOptions *>(obj);
return verifier.VerifyTable(ptr);
@@ -6789,10 +7074,22 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
case BuiltinOptions_LessOptions: {
auto ptr = reinterpret_cast<const LessOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
case BuiltinOptions_NegOptions: {
auto ptr = reinterpret_cast<const NegOptions *>(obj);
return ptr->UnPack(resolver);
@@ -6967,10 +7264,22 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ArgMaxOptionsT *>(value);
return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<const GreaterOptionsT *>(value);
+ return CreateGreaterOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<const GreaterEqualOptionsT *>(value);
+ return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
case BuiltinOptions_LessOptions: {
auto ptr = reinterpret_cast<const LessOptionsT *>(value);
return CreateLessOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<const LessEqualOptionsT *>(value);
+ return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
case BuiltinOptions_NegOptions: {
auto ptr = reinterpret_cast<const NegOptionsT *>(value);
return CreateNegOptions(_fbb, ptr, _rehasher).Union();
@@ -7145,10 +7454,22 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ArgMaxOptionsT(*reinterpret_cast<ArgMaxOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_GreaterOptions: {
+ value = new GreaterOptionsT(*reinterpret_cast<GreaterOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ value = new GreaterEqualOptionsT(*reinterpret_cast<GreaterEqualOptionsT *>(u.value));
+ break;
+ }
case BuiltinOptions_LessOptions: {
value = new LessOptionsT(*reinterpret_cast<LessOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_LessEqualOptions: {
+ value = new LessEqualOptionsT(*reinterpret_cast<LessEqualOptionsT *>(u.value));
+ break;
+ }
case BuiltinOptions_NegOptions: {
value = new NegOptionsT(*reinterpret_cast<NegOptionsT *>(u.value));
break;
@@ -7365,11 +7686,26 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_GreaterOptions: {
+ auto ptr = reinterpret_cast<GreaterOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_GreaterEqualOptions: {
+ auto ptr = reinterpret_cast<GreaterEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
case BuiltinOptions_LessOptions: {
auto ptr = reinterpret_cast<LessOptionsT *>(value);
delete ptr;
break;
}
+ case BuiltinOptions_LessEqualOptions: {
+ auto ptr = reinterpret_cast<LessEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
case BuiltinOptions_NegOptions: {
auto ptr = reinterpret_cast<NegOptionsT *>(value);
delete ptr;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index ca1390fdeb..6749e63552 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -33,9 +33,12 @@ gen_zipped_test_files(
"fused_batch_norm.zip",
"gather.zip",
"global_batch_norm.zip",
+ "greater.zip",
+ "greater_equal.zip",
"l2_pool.zip",
"l2norm.zip",
"less.zip",
+ "less_equal.zip",
"local_response_norm.zip",
"log_softmax.zip",
"max_pool.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 6fe0f491d0..7a658d43d3 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2055,6 +2055,74 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_greater_tests(zip_path):
+ """Make a set of tests to do greater."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the greater op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.greater(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_greater_equal_tests(zip_path):
+ """Make a set of tests to do greater_equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the greater_equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.greater_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_less_tests(zip_path):
"""Make a set of tests to do less."""
@@ -2089,6 +2157,40 @@ def make_less_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_less_equal_tests(zip_path):
+ """Make a set of tests to do less_equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the less_equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.less_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_floor_tests(zip_path):
"""Make a set of tests to do floor."""
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 96681952c9..2ce14f3b38 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -258,9 +258,12 @@ INSTANTIATE_TESTS(fully_connected)
INSTANTIATE_TESTS(fused_batch_norm)
INSTANTIATE_TESTS(gather)
INSTANTIATE_TESTS(global_batch_norm)
+INSTANTIATE_TESTS(greater)
+INSTANTIATE_TESTS(greater_equal)
INSTANTIATE_TESTS(l2_pool)
INSTANTIATE_TESTS(l2norm)
INSTANTIATE_TESTS(less)
+INSTANTIATE_TESTS(less_equal)
INSTANTIATE_TESTS(local_response_norm)
INSTANTIATE_TESTS(log_softmax)
INSTANTIATE_TESTS(max_pool)
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 9e899cf977..53df1987b3 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1702,6 +1702,19 @@ void ConvertRandomUniformOperator(const Model& model,
(*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
}
+void ConvertComparisonOperator(const Model& model, const Operator& src_op,
+ const char* op_name,
+ GraphDef* tensorflow_graph) {
+ auto* comparison_op = tensorflow_graph->add_node();
+ comparison_op->set_op(op_name);
+ comparison_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *comparison_op->add_input() = src_op.inputs[0];
+ *comparison_op->add_input() = src_op.inputs[1];
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*comparison_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1893,6 +1906,14 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreater) {
+ ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
+ ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLess) {
+ ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowLessEqual) {
+ ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
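
All four new branches funnel into the shared ConvertComparisonOperator; the only per-op difference is the TensorFlow op name written into the exported NodeDef ("Greater", "GreaterEqual", "Less", "LessEqual"), with the "T" attribute taken from the first input's data type. The same dispatch can be read as a simple type-to-name table; the sketch below spells that out with a stand-in enum rather than toco's real OperatorType.

```
// Stand-in for the OperatorType -> TF op name mapping added above.
// The enum and lookup are illustrative, not toco definitions.
#include <cstdio>
#include <map>

enum class CmpOp { kGreater, kGreaterEqual, kLess, kLessEqual };

const char* TensorFlowOpName(CmpOp type) {
  static const std::map<CmpOp, const char*> kNames = {
      {CmpOp::kGreater, "Greater"},
      {CmpOp::kGreaterEqual, "GreaterEqual"},
      {CmpOp::kLess, "Less"},
      {CmpOp::kLessEqual, "LessEqual"},
  };
  return kNames.at(type);
}

int main() {
  std::printf("%s\n", TensorFlowOpName(CmpOp::kLessEqual));  // prints LessEqual
  return 0;
}
```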
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 9b0e232132..a081abea55 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -499,8 +499,8 @@ void ProcessTensorFlowReshapeOperator(Model* model,
<< op->outputs[0] << "\". Are your input shapes correct?";
}
-void ProcessSimpleOperator(Model* model, Operator* op) {
- const auto& input_array = model->GetArray(op->inputs[0]);
+void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
+ const auto& input_array = model->GetArray(op->inputs[input_index]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -1499,7 +1499,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kExp:
- ProcessSimpleOperator(model, op);
+ ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
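
The signature change gives ProcessSimpleOperator an input_index so callers can choose which input's shape is forwarded to the output; the existing call site passes 0 and keeps its old behaviour. Assuming the usual copy-the-input-shape-through contract (only the start of the function is visible in this hunk), here is a simplified sketch with stand-in structs instead of toco's Model/Operator/Array types.

```
// Simplified shape propagation: yield until the chosen input has a shape,
// then give the single output the same shape. Stand-in types only.
#include <cassert>
#include <vector>

struct Array {
  bool has_shape = false;
  std::vector<int> shape;
};

struct SimpleOp {
  std::vector<Array*> inputs;
  std::vector<Array*> outputs;
};

void ProcessSimpleOperatorSketch(SimpleOp* op, int input_index) {
  const Array& input = *op->inputs[input_index];
  if (!input.has_shape) return;    // input dims not resolved yet: try later
  Array& output = *op->outputs[0];
  output.shape = input.shape;
  output.has_shape = true;
}

int main() {
  Array in0, out;
  in0.has_shape = true;
  in0.shape = {2, 3, 4, 5};
  SimpleOp op{{&in0}, {&out}};
  ProcessSimpleOperatorSketch(&op, /*input_index=*/0);
  assert(out.has_shape && out.shape == in0.shape);
  return 0;
}
```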
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 58e214b76b..a1ca7371c8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -55,7 +55,11 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace ||
type == OperatorType::kLstmCell || type == OperatorType::kGather ||
- type == OperatorType::kTranspose || type == OperatorType::kMean;
+ type == OperatorType::kTranspose || type == OperatorType::kMean ||
+ type == OperatorType::kTensorFlowGreater ||
+ type == OperatorType::kTensorFlowGreaterEqual ||
+ type == OperatorType::kTensorFlowLess ||
+ type == OperatorType::kTensorFlowLessEqual;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -257,8 +261,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput(
IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
return true;
}
- if ((op.type == OperatorType::kLogistic) ||
- (op.type == OperatorType::kSoftmax)) {
+ if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) {
// Logistic and Softmax have range: [0, 1].
//
// For Logistic, 0.5 should be exactly representable, as implementations
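
SupportsQuantization now whitelists the four comparison ops, so a quantized graph does not have to fall back to float around them. Conceptually the inputs may be uint8 while the result stays bool; the sketch below illustrates that intended semantics in terms of the usual real = scale * (q - zero_point) mapping. It shows the semantics only, with made-up scales and zero points, and is not the actual TFLite kernel arithmetic.

```
// Comparing two quantized values by their real-valued interpretation.
// Illustrative semantics only; scales/zero points are invented for the demo.
#include <cstdint>
#include <cstdio>

bool GreaterQuantized(uint8_t x, float x_scale, int32_t x_zero_point,
                      uint8_t y, float y_scale, int32_t y_zero_point) {
  const float real_x = x_scale * (static_cast<int32_t>(x) - x_zero_point);
  const float real_y = y_scale * (static_cast<int32_t>(y) - y_zero_point);
  return real_x > real_y;  // output is bool; only the inputs are quantized
}

int main() {
  // 140 at scale 0.5 / zp 128 represents 6.0; 130 at scale 1.0 / zp 125 is 5.0.
  std::printf("%d\n", GreaterQuantized(140, 0.5f, 128, 130, 1.0f, 125));  // 1
  return 0;
}
```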
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index df784a2a76..a008e63351 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -915,9 +915,16 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"MAXIMUM", OperatorType::kTensorFlowMaximum));
ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
"MINIMUM", OperatorType::kTensorFlowMinimum));
+ ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>(
+ "GREATER", OperatorType::kTensorFlowGreater));
+ ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>(
+ "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual));
ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
"LESS", OperatorType::kTensorFlowLess));
+ ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>(
+ "LESS_EQUAL", OperatorType::kTensorFlowLessEqual));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
+
return ops;
}
} // namespace
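
GREATER, GREATER_EQUAL and LESS_EQUAL are registered through SimpleOperator, i.e. operators that serialize nothing beyond their builtin name and type, which is consistent with the corresponding empty option tables added to the schema. A trimmed-down sketch of that registration pattern, using stand-in classes rather than the real toco/tflite ones:

```
// Minimal stand-in for the SimpleOperator registration pattern above.
#include <memory>
#include <string>
#include <vector>

enum class OpType { kGreater, kGreaterEqual, kLess, kLessEqual };

class BaseOp {
 public:
  BaseOp(std::string name, OpType type) : name_(std::move(name)), type_(type) {}
  virtual ~BaseOp() = default;
  const std::string& name() const { return name_; }
  OpType type() const { return type_; }

 private:
  std::string name_;
  OpType type_;
};

// "Simple" means there are no options to read or write for this op.
template <typename TocoOperator>
class SimpleOp : public BaseOp {
 public:
  using BaseOp::BaseOp;
};

struct GreaterEqualOperatorStub {};  // placeholder for the toco operator struct

std::vector<std::unique_ptr<BaseOp>> BuildList() {
  std::vector<std::unique_ptr<BaseOp>> ops;
  ops.emplace_back(new SimpleOp<GreaterEqualOperatorStub>(
      "GREATER_EQUAL", OpType::kGreaterEqual));
  return ops;
}

int main() { return BuildList().front()->name() == "GREATER_EQUAL" ? 0 : 1; }
```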