diff options
author | Alan Chiao <alanchiao@google.com> | 2018-05-04 10:31:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-04 10:56:20 -0700 |
commit | a5f44b3519627859fb476a9cad1acc354bfa649f (patch) | |
tree | 616e06d92971a8452cbb179ece80d569bf244ef8 /tensorflow/contrib/lite/kernels/neg_test.cc | |
parent | 3db0e545d2460be0392dfcaa304231cd2105648e (diff) |
Implement neg op
PiperOrigin-RevId: 195435079
Diffstat (limited to 'tensorflow/contrib/lite/kernels/neg_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/neg_test.cc | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc new file mode 100644 index 0000000000..3c95ac8cc2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/neg_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class NegOpModel : public SingleOpModel { + public: + NegOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_NEG, BuiltinOptions_NegOptions, + CreateNegOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template <class T> + void SetInput(std::initializer_list<T> data) { + PopulateTensor<T>(input_, data); + } + + template <class T> + std::vector<T> GetOutput() { + return ExtractVector<T>(output_); + } + + protected: + int input_; + int output_; +}; + +TEST(NegOpModel, NegFloat) { + NegOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.SetInput<float>({-2.0f, -1.0f, 0.f, 1.0f, 2.0f, 3.0f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput<float>(), + ElementsAreArray({2.0f, 1.0f, 0.f, -1.0f, -2.0f, -3.0f})); +} + +TEST(NegOpModel, NegInt32) { + NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.SetInput<int32>({-2, -1, 0, 1, 2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({2, 1, 0, -1, -2, -3})); +} + +TEST(NegOpModel, NegInt64) { + NegOpModel m({TensorType_INT64, {2, 3}}, {TensorType_INT64, {2, 3}}); + m.SetInput<int64_t>({-2, -1, 0, 1, 2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({2, 1, 0, -1, -2, -3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |