diff options
author | 2018-08-06 18:12:42 -0700 | |
---|---|---|
committer | 2018-08-06 18:16:44 -0700 | |
commit | 0fc1de719ba0ad6d953f618475c73d35dd27767d (patch) | |
tree | 303aecb6c46c5888d9c155c7676aa4a384755133 /tensorflow/contrib/lite/kernels/elementwise_test.cc | |
parent | 56a82b00f461dc8bc2b3e8e63fa768144795a7b2 (diff) |
Implementation of logical_and logical_not
PiperOrigin-RevId: 207642985
Diffstat (limited to 'tensorflow/contrib/lite/kernels/elementwise_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/elementwise_test.cc | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index ce4c602ee5..b9d7d73c52 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,26 +24,40 @@ namespace { using ::testing::ElementsAreArray; -class ElementWiseOpModel : public SingleOpModel { +class ElementWiseOpBaseModel : public SingleOpModel { public: - ElementWiseOpModel(BuiltinOperator op, - std::initializer_list<int> input_shape) { + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +class ElementWiseOpFloatModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpFloatModel(BuiltinOperator op, + std::initializer_list<int> input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } +}; - int input() const { return input_; } - int output() const { return output_; } - - private: - int input_; - int output_; +class ElementWiseOpBoolModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpBoolModel(BuiltinOperator op, + std::initializer_list<int> input_shape) { + input_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } }; TEST(ElementWise, Sin) { - ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector<float>(m.output()), @@ -52,7 +66,7 @@ TEST(ElementWise, Sin) { } TEST(ElementWise, Log) { - ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector<float>(m.output()), @@ -61,7 +75,7 @@ TEST(ElementWise, Log) { } TEST(ElementWise, Sqrt) { - ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); m.PopulateTensor<float>(m.input(), {0, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.ExtractVector<float>(m.output()), @@ -70,7 +84,7 @@ TEST(ElementWise, Sqrt) { } TEST(ElementWise, Rsqrt) { - ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); m.PopulateTensor<float>(m.input(), {1, 2, 4, 9}); m.Invoke(); EXPECT_THAT(m.ExtractVector<float>(m.output()), @@ -78,6 +92,15 @@ TEST(ElementWise, Rsqrt) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, LogicalNot) { + ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); + m.PopulateTensor<bool>(m.input(), {true, false, true, false}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<bool>(m.output()), + ElementsAreArray({false, true, false, true})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite |