diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/kernels/gather_op_test.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/kernels/gather_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/gather_op_test.cc | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc new file mode 100644 index 0000000000..d7410169e1 --- /dev/null +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -0,0 +1,213 @@ +#include <functional> +#include <memory> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +class GatherOpTest : public OpsTestBase { + protected: + void MakeOp(DataType index_type) { + RequireDefaultOps(); + ASSERT_OK(NodeDefBuilder("myop", "Gather") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(index_type)) + .Finalize(node_def())); + ASSERT_OK(InitOp()); + } +}; + +TEST_F(GatherOpTest, ScalarIndices) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4}); + AddInputFromArray<int32>(TensorShape({}), {3}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({})); + test::FillValues<float>(&expected, {3}); + test::ExpectTensorEqual<float>(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Simple_TwoD32) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); + test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); + test::ExpectTensorEqual<float>(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Simple_TwoD64) { + MakeOp(DT_INT64); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray<int64>(TensorShape({4}), {0, 4, 0, 2}); + ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); + test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); + test::ExpectTensorEqual<float>(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, HighRank) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3}); + AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 0, 2, 3, 0}); + ASSERT_OK(RunOpKernel()); + + // Check the output + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + test::FillValues<float>(&expected, {1, 2, 0, 2, 3, 0}); + test::ExpectTensorEqual<float>(expected, *GetOutput(0)); +} + +TEST_F(GatherOpTest, Error_IndexOutOfRange) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("Index 99 at offset 2 in Tindices is out of range")) + << s; +} + +class GatherOpForBenchmark : public GatherOpTest { + public: + void TestBody() override { // not used } + } + void PublicMakeOp(DataType index_type) { MakeOp(index_type); } +}; + +static const int kSorted = 0x8000; // Mask for arg to specify sorting vs. not + +template <typename Index> +void BM_Gather(int iters, int arg) { + testing::StopTiming(); + + bool sorted = ((arg & kSorted) != 0); + int dim = arg & ~kSorted; + + GatherOpForBenchmark t; + t.PublicMakeOp(DataTypeToEnum<Index>::v()); + // Use a 512 MB table, regardless of dim + const int kRows = ((1 << 29) / sizeof(float)) / dim; + std::vector<float> data(kRows * dim, 1.0f); + t.AddInputFromArray<float>(TensorShape({kRows, dim}), data); + const int kLookups = 2000; + const int kBatches = 1000000 / kLookups; + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector<std::vector<Index>> all_ids(kBatches); + for (int i = 0; i < kBatches; ++i) { + std::vector<Index>* ids = &all_ids[i]; + ids->resize(kLookups); + for (int j = 0; j < kLookups; ++j) { + (*ids)[j] = rnd.Uniform(kRows); + } + if (sorted) { + sort(ids->begin(), ids->end()); + } + } + + t.AddInput<Index>(TensorShape({kLookups}), [](int i) { return 0; }); + if (sorted) { + testing::SetLabel("sorted by id"); + } + testing::BytesProcessed(static_cast<int64>(iters) * kLookups * dim * + sizeof(float)); + testing::StartTiming(); + while (--iters > 0) { + const std::vector<Index>& b = all_ids[iters % kBatches]; + TensorValue input = t.mutable_input(1); + gtl::MutableArraySlice<Index> slice(&input->vec<Index>()(0), + input->NumElements()); + for (int i = 0; i < kLookups; i++) { + slice[i] = b[i]; + } + Status s = t.RunOpKernel(); + } +} + +static void BM_Gather32(int iters, int arg) { BM_Gather<int32>(iters, arg); } + +static void BM_Gather64(int iters, int arg) { BM_Gather<int64>(iters, arg); } + +BENCHMARK(BM_Gather32) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); + +BENCHMARK(BM_Gather64) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); + +} // namespace +} // namespace tensorflow |