about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/gather_op_test.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <nobody@tensorflow.org> 2016-03-01 14:04:34 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-01 16:35:05 -0800
commit: f16aabc6d9c1f11d1b2c393544999d1ac960a80b (patch)
tree: a8f1c43f6035035cf1b178d23ede46f733087a49 /tensorflow/core/kernels/gather_op_test.cc
parent: dd0a68cb77167890ac5a939020cdf72b234c395e (diff)
Adds GPU kernel for gather ops.
Change: 116048950
Diffstat (limited to 'tensorflow/core/kernels/gather_op_test.cc')
-rw-r--r-- tensorflow/core/kernels/gather_op_test.cc | 138
1 files changed, 41 insertions, 97 deletions
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
index af8bfc432f..b183510445 100644
--- a/tensorflow/core/kernels/gather_op_test.cc
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -118,110 +120,52 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) {
<< s;
}
-class GatherOpForBenchmark : public GatherOpTest {
- public:
- void TestBody() override { // not used }
- }
- void PublicMakeOp(DataType index_type) { MakeOp(index_type); }
-};
-
-static const int kSorted = 0x8000; // Mask for arg to specify sorting vs. not
-
template <typename Index>
-void BM_Gather(int iters, int arg) {
- testing::StopTiming();
-
- bool sorted = ((arg & kSorted) != 0);
- int dim = arg & ~kSorted;
-
- GatherOpForBenchmark t;
- t.PublicMakeOp(DataTypeToEnum<Index>::v());
- // Use a 512 MB table, regardless of dim
- const int kRows = ((1 << 29) / sizeof(float)) / dim;
- std::vector<float> data(kRows * dim, 1.0f);
- t.AddInputFromArray<float>(TensorShape({kRows, dim}), data);
- const int kLookups = 2000;
- const int kBatches = 1000000 / kLookups;
+static Graph* Gather(int lookups, int dim) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // Always use a 512MB buffer.
+ const int kRows = ((512 << 20) / sizeof(float)) / dim;
+ Tensor params(DT_FLOAT, TensorShape({kRows, dim}));
+ params.flat<float>().setRandom();
+
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
- std::vector<std::vector<Index>> all_ids(kBatches);
- for (int i = 0; i < kBatches; ++i) {
- std::vector<Index>* ids = &all_ids[i];
- ids->resize(kLookups);
- for (int j = 0; j < kLookups; ++j) {
- (*ids)[j] = rnd.Uniform(kRows);
- }
- if (sorted) {
- sort(ids->begin(), ids->end());
- }
- }
-
- t.AddInput<Index>(TensorShape({kLookups}), [](int i) { return 0; });
- if (sorted) {
- testing::SetLabel("sorted by id");
+ std::vector<Index> indices_vec;
+ for (int i = 0; i < lookups; i++) {
+ indices_vec.push_back(rnd.Uniform(kRows));
}
- testing::BytesProcessed(static_cast<int64>(iters) * kLookups * dim *
- sizeof(float));
- testing::StartTiming();
- while (--iters > 0) {
- const std::vector<Index>& b = all_ids[iters % kBatches];
- TensorValue input = t.mutable_input(1);
- gtl::MutableArraySlice<Index> slice(&input->vec<Index>()(0),
- input->NumElements());
- for (int i = 0; i < kLookups; i++) {
- slice[i] = b[i];
- }
- Status s = t.RunOpKernel();
+ Tensor indices(DataTypeToEnum<Index>::value, TensorShape({lookups}));
+ for (int i = 0; i < indices_vec.size(); i++) {
+ indices.flat<Index>()(i) = indices_vec[i];
}
+
+ test::graph::Gather(g, test::graph::Constant(g, params),
+ test::graph::Constant(g, indices));
+ return g;
}
-static void BM_Gather32(int iters, int arg) { BM_Gather<int32>(iters, arg); }
-
-static void BM_Gather64(int iters, int arg) { BM_Gather<int64>(iters, arg); }
-
-BENCHMARK(BM_Gather32)
- ->Arg(10)
- ->Arg(10 | kSorted)
- ->Arg(20)
- ->Arg(40)
- ->Arg(63)
- ->Arg(63 | kSorted)
- ->Arg(64)
- ->Arg(64 | kSorted)
- ->Arg(65)
- ->Arg(65 | kSorted)
- ->Arg(100)
- ->Arg(100 | kSorted)
- ->Arg(127)
- ->Arg(127 | kSorted)
- ->Arg(128)
- ->Arg(128 | kSorted)
- ->Arg(129)
- ->Arg(129 | kSorted)
- ->Arg(1000)
- ->Arg(1000 | kSorted);
-
-BENCHMARK(BM_Gather64)
- ->Arg(10)
- ->Arg(10 | kSorted)
- ->Arg(20)
- ->Arg(40)
- ->Arg(63)
- ->Arg(63 | kSorted)
- ->Arg(64)
- ->Arg(64 | kSorted)
- ->Arg(65)
- ->Arg(65 | kSorted)
- ->Arg(100)
- ->Arg(100 | kSorted)
- ->Arg(127)
- ->Arg(127 | kSorted)
- ->Arg(128)
- ->Arg(128 | kSorted)
- ->Arg(129)
- ->Arg(129 | kSorted)
- ->Arg(1000)
- ->Arg(1000 | kSorted);
+#define BM_GATHER(DEVICE, INDEX) \
+ static void BM_##DEVICE##_gather_##INDEX(int iters, int lookups, int dim) { \
+ const int64 tot = static_cast<int64>(iters) * lookups * dim; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ testing::UseRealTime(); \
+ test::Benchmark(#DEVICE, Gather<INDEX>(lookups, dim)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_gather_##INDEX) \
+ ->ArgPair(2000, 1) \
+ ->ArgPair(2000, 10) \
+ ->ArgPair(2000, 20) \
+ ->ArgPair(2000, 100) \
+ ->ArgPair(200, 1000) \
+ ->ArgPair(20, 10000) \
+ ->ArgPair(20000, 10) \
+ ->ArgPair(2000, 1)
+
+BM_GATHER(cpu, int32);
+BM_GATHER(gpu, int32);
+BM_GATHER(cpu, int64);
+BM_GATHER(gpu, int64);
} // namespace
} // namespace tensorflow