aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/comparisons_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 12:37:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 16:58:51 -0700
commit6f3a890d91e6dbeb811aed23d0eb59abaa8c469f (patch)
treef05bad7a6c979613b43c495a042a675cbc59a13a /tensorflow/contrib/lite/kernels/comparisons_test.cc
parentc3fef21c4ddf34fd68ab2cd44b0be497b5303b4e (diff)
Adding Greater/GreaterEqual/LessEqual ops to complement Less.
PiperOrigin-RevId: 195704492
Diffstat (limited to 'tensorflow/contrib/lite/kernels/comparisons_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc207
1 files changed, 203 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index da2d7f8589..835d238d36 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -23,6 +23,139 @@ namespace {
using ::testing::ElementsAreArray;
+class GreaterOpModel : public SingleOpModel {
+ public:
+ GreaterOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterFloat) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
+class GreaterEqualOpModel : public SingleOpModel {
+ public:
+ GreaterEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
+ BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, GreaterEqualFloat) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualInt) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcast) {
+ GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
+ GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
+ false, true, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
class LessOpModel : public SingleOpModel {
public:
LessOpModel(std::initializer_list<int> input1_shape,
@@ -47,7 +180,7 @@ class LessOpModel : public SingleOpModel {
int output_;
};
-TEST(ArgMaxOpTest, LessFloat) {
+TEST(ComparisonsTest, LessFloat) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
@@ -57,7 +190,7 @@ TEST(ArgMaxOpTest, LessFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessInt) {
+TEST(ComparisonsTest, LessInt) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
@@ -67,7 +200,7 @@ TEST(ArgMaxOpTest, LessInt) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcast) {
+TEST(ComparisonsTest, LessBroadcast) {
LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
@@ -77,7 +210,7 @@ TEST(ArgMaxOpTest, LessBroadcast) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
-TEST(ArgMaxOpTest, LessBroadcastTwoD) {
+TEST(ComparisonsTest, LessBroadcastTwoD) {
LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
@@ -88,6 +221,72 @@ TEST(ArgMaxOpTest, LessBroadcastTwoD) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
}
+class LessEqualOpModel : public SingleOpModel {
+ public:
+ LessEqualOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type) {
+ input1_ = AddInput(input_type);
+ input2_ = AddInput(input_type);
+ output_ = AddOutput(TensorType_BOOL);
+ SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ BuildInterpreter({input1_shape, input2_shape});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(ComparisonsTest, LessEqualFloat) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualInt) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcast) {
+ LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
+ LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
+ true, false, true, false}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
} // namespace
} // namespace tflite