aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
author: Vijay Vasudevan <vrv@google.com> 2016-03-01 15:49:41 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-01 16:35:38 -0800
commit: 3f8a7616452346ef4de5a0da7b79fdb1cc8d6594 (patch)
tree: 8be48d0b9dacc43b8bdfb8cb22936336f90a8eba /tensorflow
parent: 40ccc0aa632ebbd3b34f2f5da5f1ed576d50f57a (diff)
Rollback of "Adds GPU kernel for gather ops."
Change: 116064181
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/core/graph/testlib.cc | 9
-rw-r--r-- tensorflow/core/graph/testlib.h | 3
-rw-r--r-- tensorflow/core/kernels/gather_op.cc | 183
-rw-r--r-- tensorflow/core/kernels/gather_op.h | 41
-rw-r--r-- tensorflow/core/kernels/gather_op_gpu.cu.cc | 86
-rw-r--r-- tensorflow/core/kernels/gather_op_test.cc | 138
-rw-r--r-- tensorflow/python/kernel_tests/gather_op_test.py | 15
7 files changed, 168 insertions, 307 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index f3164009fc..63ac4b46a7 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -351,15 +351,6 @@ Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) {
return ret;
}
-Node* Gather(Graph* g, Node* in0, Node* in1) {
- Node* ret;
- TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Gather")
- .Input(in0)
- .Input(in1)
- .Finalize(g, &ret));
- return ret;
-}
-
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index cb6f2468f2..57f9e224a8 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -155,9 +155,6 @@ Node* Select(Graph* g, Node* c, Node* inx, Node* iny);
// Casts "in" into data type "dst".
Node* Cast(Graph* g, Node* in, DataType dst);
-// Perform gather op on params "in0" with indicies "in1".
-Node* Gather(Graph* g, Node* in0, Node* in1);
-
// Computes the args needed broadcast gradient function.
Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1);
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index d4d8d86ef6..d7a4e20fbd 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#include "tensorflow/core/kernels/gather_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -27,10 +26,49 @@ limitations under the License.
namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+namespace {
+
+// Copies the params row selected by each entry of `indices` into the
+// corresponding row of `out`.
+//
+// Template parameters:
+//   T                  - element type of `params` and `out`.
+//   Index              - integer type of the entries of `indices`.
+//   static_slice_elems - when >= 0, the row length known at compile time; it
+//                        is CHECK-ed against `slice_elems`. Pass -1 to use
+//                        the runtime value only.
+//
+// Arguments:
+//   params      - source tensor, addressed through flat_outer_dims<T>().
+//   indices     - vector of row indices; each one is validated against
+//                 params.dim_size(0) with FastBoundsCheck before use.
+//   slice_elems - number of T elements in one row.
+//   out         - destination matrix; row i receives params row indices(i).
+//
+// Returns -1 on success or a nonnegative i s.t., indices[i] is bad.
+template <typename T, typename Index, int static_slice_elems>
+Index HandleCopies(const Tensor& params,
+                   typename TTypes<Index>::ConstVec indices, Index slice_elems,
+                   typename TTypes<T>::Matrix out) {
+  // Hold the count in Index rather than int so that the size of an int64
+  // index vector is not truncated.
+  const Index N = static_cast<Index>(indices.dimension(0));
+  const auto& params_flat = params.flat_outer_dims<T>();
+  // NOTE(review): this narrows dim_size(0) to Index; presumably the caller
+  // guarantees that it fits -- confirm.
+  const Index limit = params.dim_size(0);
+  T* out_base = &out(0, 0);
+  const T* params_base = &params_flat(0, 0);
+  if (static_slice_elems >= 0) {
+    // Give compiler static knowledge of the number of elements/bytes
+    CHECK_EQ(static_slice_elems, slice_elems);
+    slice_elems = static_slice_elems;
+  }
+  // Compute slice_bytes here so that static knowledge is available
+  const size_t slice_bytes = slice_elems * sizeof(T);
+  for (Index i = 0; i < N; i++) {
+    const Index j = i + 1;
+    if (j < N) {
+      // Prefetch the source and destination rows of the next iteration.
+      port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
+    }
+    // Grab the index and check its validity.  An earlier version of the
+    // code checked it and then grabbed it from memory a second time, which
+    // was a security risk since it could have changed in between.
+    const Index index = indices(i);
+    if (!FastBoundsCheck(index, limit)) return i;
+    // Copy using memcpy if possible, otherwise an Eigen loop
+    if (Allocator::is_simple<T>::value) {
+      // Widen before multiplying: with a 32-bit Index the element offset
+      // row * slice_elems can overflow for large tensors.
+      memcpy(out_base + static_cast<int64>(i) * slice_elems,
+             params_base + static_cast<int64>(index) * slice_elems,
+             slice_bytes);
+    } else {
+      out.template chip<0>(i) = params_flat.template chip<0>(index);
+    }
+  }
+  return -1;
+}
+
+}  // anonymous namespace
-template <typename Device, typename T, typename Index>
+template <typename T, typename Index>
class GatherOp : public OpKernel {
public:
// QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
@@ -78,13 +116,23 @@ class GatherOp : public OpKernel {
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0) {
- auto params_flat = params.flat_outer_dims<T>();
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
+ const int64 slice_size = out->NumElements() / N;
+ Index bad_i;
+
+#define CALL(elems) \
+ bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
+ out_flat)
- functor::Gather<Device, T, Index> functor;
- Index bad_i = functor(c->eigen_device<Device>(), params_flat,
- indices_flat, out_flat);
+ if (slice_size == 10)
+ CALL(10);
+ else if (slice_size == 20)
+ CALL(20);
+ else
+ CALL(-1);
+
+#undef CALL
OP_REQUIRES(
c, bad_i < 0,
@@ -95,120 +143,21 @@ class GatherOp : public OpKernel {
}
};
-namespace functor {
-
-// Helper method to copy using memcpy.
-template <typename T, typename Index, int static_slice_elems>
-Index HandleCopies(typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices, Index slice_elems,
- typename TTypes<T>::Matrix out) {
- const int first_dim_size = indices.dimension(0);
- const Index limit = params.dimension(0);
- T* out_base = &out(0, 0);
- const T* params_base = &params(0, 0);
- if (static_slice_elems >= 0) {
- // Give compiler static knowledge of the number of elements/bytes
- CHECK_EQ(static_slice_elems, slice_elems);
- slice_elems = static_slice_elems;
- }
- // Compute slice_bytes here so that static knowledge is available
- const size_t slice_bytes = slice_elems * sizeof(T);
- for (int i = 0; i < first_dim_size; i++) {
- const int j = i + 1;
- if (j < first_dim_size) {
- port::prefetch<port::PREFETCH_HINT_T0>(&params(indices(j), 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
- }
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
- const Index index = indices(i);
- if (!FastBoundsCheck(index, limit)) return i;
- // Copy using memcpy if possible, otherwise an Eigen loop
- if (Allocator::is_simple<T>::value) {
- memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
- slice_bytes);
- } else {
- out.template chip<0>(i) = params.template chip<0>(index);
- }
- }
- return -1;
-}
-
-// Specialization gather functor for CPU.
-template <typename T, typename Index>
-struct Gather<CPUDevice, T, Index> {
- Index operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices,
- typename TTypes<T>::Matrix out) {
- const int N = indices.size();
- const int64 slice_size = out.size() / N;
- Index bad_i;
-
-#define CALL(elems) \
- bad_i = HandleCopies<T, Index, elems>(params, indices, slice_size, out)
-
- if (slice_size == 10)
- CALL(10);
- else if (slice_size == 20)
- CALL(20);
- else
- CALL(-1);
-#undef CALL
-
- return bad_i;
- }
-};
-} // namespace functor
-
-#define REGISTER_GATHER_FULL(dev, type, index_type) \
+#define REGISTER_GATHER(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("Gather") \
- .Device(DEVICE_##dev) \
+ .Device(DEVICE_CPU) \
.TypeConstraint<type>("Tparams") \
.TypeConstraint<index_type>("Tindices"), \
- GatherOp<dev##Device, type, index_type>)
-
-#define REGISTER_GATHER_ALL_INDICES(dev, type) \
- REGISTER_GATHER_FULL(dev, type, int32); \
- REGISTER_GATHER_FULL(dev, type, int64)
-
-#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
-
-TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
-
-#undef REGISTER_GATHER_CPU
-
-#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPECS_INDEX(T, Index) \
- template <> \
- Index Gather<GPUDevice, T, Index>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \
- typename TTypes<Index>::ConstFlat Tindices, \
- typename TTypes<T>::Matrix Tout); \
- extern template struct Gather<GPUDevice, T, Index>
-
-#define DECLARE_GPU_SPECS(T) \
- DECLARE_GPU_SPECS_INDEX(T, int32); \
- DECLARE_GPU_SPECS_INDEX(T, int64)
-
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-
-#undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_SPECS_INDEX
-} // namespace functor
-
-// Registration of the GPU implementations.
-#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
-
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
+ GatherOp<type, index_type>)
-#undef REGISTER_GATHER_GPU
+#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32)  // Fixes the index type to int32.
+#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64)  // Fixes the index type to int64.
-#endif // GOOGLE_CUDA
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32);  // CPU "Gather" kernels with int32 indices, one per type the macro supplies.
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64);  // Same, with int64 indices.
-#undef REGISTER_GATHER_ALL_INDICES
-#undef REGISTER_GATHER_FULL
+#undef REGISTER_GATHER_INT32
+#undef REGISTER_GATHER_INT64
+#undef REGISTER_GATHER
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gather_op.h b/tensorflow/core/kernels/gather_op.h
deleted file mode 100644
index 4b82815f54..0000000000
--- a/tensorflow/core/kernels/gather_op.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_KERNELS_GATHER_OP_H_
-#define TENSORFLOW_KERNELS_GATHER_OP_H_
-// Functor definition for GatherOp, must be compilable by nvcc.
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/tensor_types.h"
-
-namespace tensorflow {
-
-class OpKernelContext;
-
-namespace functor {
-template <typename Device, typename T, typename Index>
-struct Gather {
- // Performs gather op on (Tparams, Tindices), writing to Tout.
- // Returns an index to Tindices if the value at that index is out of range.
- // Returns -1 if all values of Tindices are in range.
- Index operator()(const Device& d, typename TTypes<T>::ConstMatrix Tparams,
- typename TTypes<Index>::ConstFlat Tindices,
- typename TTypes<T>::Matrix Tout);
-};
-
-} // namespace functor
-} // namespace tensorflow
-
-#endif // TENSORFLOW_KERNELS_GATHER_OP_H_
diff --git a/tensorflow/core/kernels/gather_op_gpu.cu.cc b/tensorflow/core/kernels/gather_op_gpu.cu.cc
deleted file mode 100644
index 5bc54a90f9..0000000000
--- a/tensorflow/core/kernels/gather_op_gpu.cu.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#if GOOGLE_CUDA
-
-#define EIGEN_USE_GPU
-
-#include "tensorflow/core/kernels/gather_op.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
-
-namespace tensorflow {
-
-typedef Eigen::GpuDevice GPUDevice;
-
-template <typename T, typename Index>
-__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
- int64 first_dim_size, int64 indices_size,
- int64 out_size) {
- const int32 slice_size = out_size / indices_size;
- CUDA_1D_KERNEL_LOOP(i, out_size) {
- Index indices_i = i / slice_size;
- Index indices_slice_i = i - indices_i * slice_size;
- Index params_first_index = ldg(indices + indices_i);
- if (!(params_first_index >= 0 && params_first_index < first_dim_size)) {
- // Ignore indices that are out of range.
- continue;
- }
- Index params_i = params_first_index * slice_size + indices_slice_i;
- out[i] = ldg(params + params_i);
- }
-}
-
-namespace functor {
-template <typename T, typename Index>
-struct Gather<GPUDevice, T, Index> {
- Index operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams,
- typename TTypes<Index>::ConstFlat Tindices,
- typename TTypes<T>::Matrix Tout) {
- const int64 first_dim_size = Tparams.dimension(0);
- const int64 indices_size = Tindices.size();
- const int64 out_size = Tout.size();
- CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
- // clang-format off
- GatherOpKernel<T, Index>
- <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- Tparams.data(), Tindices.data(), Tout.data(), first_dim_size,
- indices_size, out_size);
- // clang-format on
- // TODO(fpmc): enable indices validation on GPU.
- // Right now checking for indicies out of bound in the kernel would
- // require copying code between GPU/CPU, and thus slow.
- return -1;
- }
-};
-
-} // namespace functor
-
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
- template struct functor::Gather<GPUDevice, T, Index>
-
-#define DEFINE_GPU_SPECS(T) \
- DEFINE_GPU_SPECS_INDEX(T, int32); \
- DEFINE_GPU_SPECS_INDEX(T, int64);
-
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
-
-#undef DEFINE_GPU_SPECS
-#undef DEFINE_GPU_SPECS_INDEX
-
-} // namespace tensorflow
-
-#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
index b183510445..af8bfc432f 100644
--- a/tensorflow/core/kernels/gather_op_test.cc
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -26,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -120,52 +118,110 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) {
<< s;
}
-template <typename Index>
-static Graph* Gather(int lookups, int dim) {
- Graph* g = new Graph(OpRegistry::Global());
- // Always use a 512MB buffer.
- const int kRows = ((512 << 20) / sizeof(float)) / dim;
- Tensor params(DT_FLOAT, TensorShape({kRows, dim}));
- params.flat<float>().setRandom();
+// Variant of the GatherOpTest fixture used by the benchmark code in this
+// file rather than by a test case.
+class GatherOpForBenchmark : public GatherOpTest {
+ public:
+  // Not used; the body is intentionally empty.
+  void TestBody() override {}
+
+  // Forwards to MakeOp with the requested index type.
+  void PublicMakeOp(DataType index_type) {
+    MakeOp(index_type);
+  }
+};
+
+static const int kSorted = 0x8000;  // Mask for arg to specify sorting vs. not
+
+// Benchmarks the Gather kernel with float params and `Index`-typed indices.
+//
+// `arg` packs two values: with the kSorted bit cleared it is `dim`, the width
+// of each params row; when the kSorted bit is set, every batch of lookup ids
+// is sorted before use. `dim` must be nonzero because it is used as a
+// divisor. Each timed iteration copies one pre-generated batch of kLookups
+// ids into the indices input and runs the kernel once.
+template <typename Index>
+void BM_Gather(int iters, int arg) {
+  testing::StopTiming();
+
+  bool sorted = ((arg & kSorted) != 0);
+  int dim = arg & ~kSorted;
+
+  GatherOpForBenchmark t;
+  t.PublicMakeOp(DataTypeToEnum<Index>::v());
+  // Use a 512 MB table, regardless of dim
+  const int kRows = ((1 << 29) / sizeof(float)) / dim;
+  std::vector<float> data(kRows * dim, 1.0f);
+  t.AddInputFromArray<float>(TensorShape({kRows, dim}), data);
+  const int kLookups = 2000;
+  const int kBatches = 1000000 / kLookups;
   random::PhiloxRandom philox(301, 17);
   random::SimplePhilox rnd(&philox);
-  std::vector<Index> indices_vec;
-  for (int i = 0; i < lookups; i++) {
-    indices_vec.push_back(rnd.Uniform(kRows));
-  }
-  Tensor indices(DataTypeToEnum<Index>::value, TensorShape({lookups}));
-  for (int i = 0; i < indices_vec.size(); i++) {
-    indices.flat<Index>()(i) = indices_vec[i];
+  // Generate every batch of ids up front so that random number generation
+  // stays out of the timed region.
+  std::vector<std::vector<Index>> all_ids(kBatches);
+  for (int i = 0; i < kBatches; ++i) {
+    std::vector<Index>* ids = &all_ids[i];
+    ids->resize(kLookups);
+    for (int j = 0; j < kLookups; ++j) {
+      (*ids)[j] = rnd.Uniform(kRows);
+    }
+    if (sorted) {
+      sort(ids->begin(), ids->end());
+    }
   }
-  test::graph::Gather(g, test::graph::Constant(g, params),
-                      test::graph::Constant(g, indices));
-  return g;
+  // Placeholder indices input; it is overwritten on every timed iteration.
+  t.AddInput<Index>(TensorShape({kLookups}), [](int i) { return 0; });
+  if (sorted) {
+    testing::SetLabel("sorted by id");
+  }
+  testing::BytesProcessed(static_cast<int64>(iters) * kLookups * dim *
+                          sizeof(float));
+  testing::StartTiming();
+  // Post-decrement so that exactly `iters` iterations run, matching the
+  // BytesProcessed figure above, and so that batch 0 is also used.
+  while (iters-- > 0) {
+    const std::vector<Index>& b = all_ids[iters % kBatches];
+    TensorValue input = t.mutable_input(1);
+    gtl::MutableArraySlice<Index> slice(&input->vec<Index>()(0),
+                                        input->NumElements());
+    for (int i = 0; i < kLookups; i++) {
+      slice[i] = b[i];
+    }
+    // Fail loudly instead of timing a kernel that is returning errors.
+    TF_CHECK_OK(t.RunOpKernel());
+  }
 }
-#define BM_GATHER(DEVICE, INDEX) \
- static void BM_##DEVICE##_gather_##INDEX(int iters, int lookups, int dim) { \
- const int64 tot = static_cast<int64>(iters) * lookups * dim; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- testing::UseRealTime(); \
- test::Benchmark(#DEVICE, Gather<INDEX>(lookups, dim)).Run(iters); \
- } \
- BENCHMARK(BM_##DEVICE##_gather_##INDEX) \
- ->ArgPair(2000, 1) \
- ->ArgPair(2000, 10) \
- ->ArgPair(2000, 20) \
- ->ArgPair(2000, 100) \
- ->ArgPair(200, 1000) \
- ->ArgPair(20, 10000) \
- ->ArgPair(20000, 10) \
- ->ArgPair(2000, 1)
-
-BM_GATHER(cpu, int32);
-BM_GATHER(gpu, int32);
-BM_GATHER(cpu, int64);
-BM_GATHER(gpu, int64);
+// Benchmark entry point: BM_Gather instantiated with int32 indices.
+static void BM_Gather32(int iters, int arg) {
+  BM_Gather<int32>(iters, arg);
+}
+
+// Benchmark entry point: BM_Gather instantiated with int64 indices.
+static void BM_Gather64(int iters, int arg) {
+  BM_Gather<int64>(iters, arg);
+}
+
+BENCHMARK(BM_Gather32)  // int32 indices; each Arg is dim, optionally OR'd with kSorted.
+ ->Arg(10)
+ ->Arg(10 | kSorted)
+ ->Arg(20)
+ ->Arg(40)
+ ->Arg(63)
+ ->Arg(63 | kSorted)
+ ->Arg(64)
+ ->Arg(64 | kSorted)
+ ->Arg(65)
+ ->Arg(65 | kSorted)
+ ->Arg(100)
+ ->Arg(100 | kSorted)
+ ->Arg(127)
+ ->Arg(127 | kSorted)
+ ->Arg(128)
+ ->Arg(128 | kSorted)
+ ->Arg(129)
+ ->Arg(129 | kSorted)
+ ->Arg(1000)
+ ->Arg(1000 | kSorted);
+
+BENCHMARK(BM_Gather64)  // int64 indices; same argument list as BM_Gather32.
+ ->Arg(10)
+ ->Arg(10 | kSorted)
+ ->Arg(20)
+ ->Arg(40)
+ ->Arg(63)
+ ->Arg(63 | kSorted)
+ ->Arg(64)
+ ->Arg(64 | kSorted)
+ ->Arg(65)
+ ->Arg(65 | kSorted)
+ ->Arg(100)
+ ->Arg(100 | kSorted)
+ ->Arg(127)
+ ->Arg(127 | kSorted)
+ ->Arg(128)
+ ->Arg(128 | kSorted)
+ ->Arg(129)
+ ->Arg(129 | kSorted)
+ ->Arg(1000)
+ ->Arg(1000 | kSorted);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 33a9554711..5292a5bbad 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -23,10 +23,9 @@ import tensorflow as tf
class GatherTest(tf.test.TestCase):
- use_gpu = False
def testScalar1D(self):
- with self.test_session(use_gpu=self.use_gpu):
+ with self.test_session():
params = tf.constant([0, 1, 2, 3, 7, 5])
indices = tf.constant(4)
gather_t = tf.gather(params, indices)
@@ -35,7 +34,7 @@ class GatherTest(tf.test.TestCase):
self.assertEqual([], gather_t.get_shape())
def testScalar2D(self):
- with self.test_session(use_gpu=self.use_gpu):
+ with self.test_session():
params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
[9, 10, 11], [12, 13, 14]])
indices = tf.constant(2)
@@ -45,7 +44,7 @@ class GatherTest(tf.test.TestCase):
self.assertEqual([3], gather_t.get_shape())
def testSimpleTwoD32(self):
- with self.test_session(use_gpu=self.use_gpu):
+ with self.test_session():
params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
[9, 10, 11], [12, 13, 14]])
indices = tf.constant([0, 4, 0, 2])
@@ -60,7 +59,7 @@ class GatherTest(tf.test.TestCase):
shape = (4, 3, 2)
params = np.random.randn(*shape)
indices = np.random.randint(shape[0], size=15).reshape(3, 5)
- with self.test_session(use_gpu=self.use_gpu):
+ with self.test_session():
tf_params = tf.constant(params)
tf_indices = tf.constant(indices)
gather = tf.gather(tf_params, tf_indices)
@@ -85,7 +84,7 @@ class GatherTest(tf.test.TestCase):
self.assertEqual(None, gather_t.get_shape())
def testBadIndices(self):
- with self.test_session(use_gpu=False):
+ with self.test_session():
params = [0, 1, 2]
indices = [[7]]
gather = tf.gather(params, indices)
@@ -93,9 +92,5 @@ class GatherTest(tf.test.TestCase):
gather.eval()
-class GatherGpuTest(GatherTest):
- use_gpu = True
-
-
if __name__ == "__main__":
tf.test.main()