diff options
Diffstat (limited to 'tensorflow/core/kernels/reverse_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/reverse_op_test.cc | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc new file mode 100644 index 0000000000..d41c36e693 --- /dev/null +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -0,0 +1,101 @@ +#include <functional> +#include <memory> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/tensor.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +class ReverseOpTest : public OpsTestBase { + protected: + void MakeOp(DataType data_type) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Reverse") + .Input(FakeInput(data_type)) + .Input(FakeInput()) + .Attr("T", data_type) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(ReverseOpTest, Reverse_0) { + MakeOp(DT_FLOAT); + AddInputFromArray<float>(TensorShape({}), {3}); + AddInputFromArray<bool>(TensorShape({}), {true}); + ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({})); + expected.scalar<float>() = expected.scalar<float>().constant(3.f); + test::ExpectTensorEqual<float>(expected, *output); +} + +TEST_F(ReverseOpTest, Reverse_234) { + MakeOp(DT_FLOAT); + + // Feed and run + // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + AddInputFromArray<float>(TensorShape({2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray<bool>(TensorShape({3}), {true, false, true}); + + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 4})); + // Should become + // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] + test::FillValues<float>( + &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, + 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual<float>(expected, *params_tensor); +} + +TEST_F(ReverseOpTest, Reverse_1234) { + MakeOp(DT_FLOAT); + + // Feed and run + // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] + AddInputFromArray<float>(TensorShape({1, 2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray<bool>(TensorShape({4}), {true, true, false, true}); + + ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); + // Should become + // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] + test::FillValues<float>( + &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, + 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual<float>(expected, *params_tensor); +} + +} // namespace +} // namespace tensorflow |