diff options
author | 2016-03-01 15:49:41 -0800 | |
---|---|---|
committer | 2016-03-01 16:35:38 -0800 | |
commit | 3f8a7616452346ef4de5a0da7b79fdb1cc8d6594 (patch) | |
tree | 8be48d0b9dacc43b8bdfb8cb22936336f90a8eba /tensorflow | |
parent | 40ccc0aa632ebbd3b34f2f5da5f1ed576d50f57a (diff) |
Rollback of "Adds GPU kernel for gather ops."
Change: 116064181
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/graph/testlib.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/graph/testlib.h | 3 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op.cc | 183 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op.h | 41 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op_gpu.cu.cc | 86 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op_test.cc | 138 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 15 |
7 files changed, 168 insertions, 307 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index f3164009fc..63ac4b46a7 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -351,15 +351,6 @@ Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1) { return ret; } -Node* Gather(Graph* g, Node* in0, Node* in1) { - Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Gather") - .Input(in0) - .Input(in1) - .Finalize(g, &ret)); - return ret; -} - void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index cb6f2468f2..57f9e224a8 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -155,9 +155,6 @@ Node* Select(Graph* g, Node* c, Node* inx, Node* iny); // Casts "in" into data type "dst". Node* Cast(Graph* g, Node* in, DataType dst); -// Perform gather op on params "in0" with indicies "in1". -Node* Gather(Graph* g, Node* in0, Node* in1); - // Computes the args needed broadcast gradient function. Node* BroadcastGradientArgs(Graph* g, Node* s0, Node* s1); diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index d4d8d86ef6..d7a4e20fbd 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -15,7 +15,6 @@ limitations under the License. // See docs in ../ops/array_ops.cc. -#include "tensorflow/core/kernels/gather_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -27,10 +26,49 @@ limitations under the License. namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; +namespace { +// Returns -1 on success or a nonnegative i s.t., indices[i] is bad. +template <typename T, typename Index, int static_slice_elems> +Index HandleCopies(const Tensor& params, + typename TTypes<Index>::ConstVec indices, Index slice_elems, + typename TTypes<T>::Matrix out) { + const int N = indices.dimension(0); + const auto& params_flat = params.flat_outer_dims<T>(); + const Index limit = params.dim_size(0); + T* out_base = &out(0, 0); + const T* params_base = ¶ms_flat(0, 0); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + CHECK_EQ(static_slice_elems, slice_elems); + slice_elems = static_slice_elems; + } + // Compute slice_bytes here so that static knowledge is available + const size_t slice_bytes = slice_elems * sizeof(T); + for (int i = 0; i < N; i++) { + const int j = i + 1; + if (j < N) { + port::prefetch<port::PREFETCH_HINT_T0>(¶ms_flat(indices(j), 0)); + port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); + } + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = indices(i); + if (!FastBoundsCheck(index, limit)) return i; + // Copy using memcpy if possible, otherwise an Eigen loop + if (Allocator::is_simple<T>::value) { + memcpy(out_base + i * slice_elems, params_base + index * slice_elems, + slice_bytes); + } else { + out.template chip<0>(i) = params_flat.template chip<0>(index); + } + } + return -1; +} + +} // anonymous namespace -template <typename Device, typename T, typename Index> +template <typename T, typename Index> class GatherOp : public OpKernel { public: // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, @@ -78,13 +116,23 @@ class GatherOp : public OpKernel { Tensor* out = nullptr; OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); if (N > 0) { - auto params_flat = params.flat_outer_dims<T>(); auto indices_flat = indices.flat<Index>(); auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N}); + const int64 slice_size = out->NumElements() / N; + Index bad_i; + +#define CALL(elems) \ + bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \ + out_flat) - functor::Gather<Device, T, Index> functor; - Index bad_i = functor(c->eigen_device<Device>(), params_flat, - indices_flat, out_flat); + if (slice_size == 10) + CALL(10); + else if (slice_size == 20) + CALL(20); + else + CALL(-1); + +#undef CALL OP_REQUIRES( c, bad_i < 0, @@ -95,120 +143,21 @@ class GatherOp : public OpKernel { } }; -namespace functor { - -// Helper method to copy using memcpy. -template <typename T, typename Index, int static_slice_elems> -Index HandleCopies(typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, Index slice_elems, - typename TTypes<T>::Matrix out) { - const int first_dim_size = indices.dimension(0); - const Index limit = params.dimension(0); - T* out_base = &out(0, 0); - const T* params_base = ¶ms(0, 0); - if (static_slice_elems >= 0) { - // Give compiler static knowledge of the number of elements/bytes - CHECK_EQ(static_slice_elems, slice_elems); - slice_elems = static_slice_elems; - } - // Compute slice_bytes here so that static knowledge is available - const size_t slice_bytes = slice_elems * sizeof(T); - for (int i = 0; i < first_dim_size; i++) { - const int j = i + 1; - if (j < first_dim_size) { - port::prefetch<port::PREFETCH_HINT_T0>(¶ms(indices(j), 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); - } - // Grab the index and check its validity. An earlier version of the - // code checked it and then grabbed it from memory a second time, which - // was a security risk since it could have changed in between. - const Index index = indices(i); - if (!FastBoundsCheck(index, limit)) return i; - // Copy using memcpy if possible, otherwise an Eigen loop - if (Allocator::is_simple<T>::value) { - memcpy(out_base + i * slice_elems, params_base + index * slice_elems, - slice_bytes); - } else { - out.template chip<0>(i) = params.template chip<0>(index); - } - } - return -1; -} - -// Specialization gather functor for CPU. -template <typename T, typename Index> -struct Gather<CPUDevice, T, Index> { - Index operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out) { - const int N = indices.size(); - const int64 slice_size = out.size() / N; - Index bad_i; - -#define CALL(elems) \ - bad_i = HandleCopies<T, Index, elems>(params, indices, slice_size, out) - - if (slice_size == 10) - CALL(10); - else if (slice_size == 20) - CALL(20); - else - CALL(-1); -#undef CALL - - return bad_i; - } -}; -} // namespace functor - -#define REGISTER_GATHER_FULL(dev, type, index_type) \ +#define REGISTER_GATHER(type, index_type) \ REGISTER_KERNEL_BUILDER(Name("Gather") \ - .Device(DEVICE_##dev) \ + .Device(DEVICE_CPU) \ .TypeConstraint<type>("Tparams") \ .TypeConstraint<index_type>("Tindices"), \ - GatherOp<dev##Device, type, index_type>) - -#define REGISTER_GATHER_ALL_INDICES(dev, type) \ - REGISTER_GATHER_FULL(dev, type, int32); \ - REGISTER_GATHER_FULL(dev, type, int64) - -#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type) - -TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); - -#undef REGISTER_GATHER_CPU - -#if GOOGLE_CUDA -// Forward declarations of the functor specializations for GPU. -namespace functor { -#define DECLARE_GPU_SPECS_INDEX(T, Index) \ - template <> \ - Index Gather<GPUDevice, T, Index>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \ - typename TTypes<Index>::ConstFlat Tindices, \ - typename TTypes<T>::Matrix Tout); \ - extern template struct Gather<GPUDevice, T, Index> - -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPECS_INDEX(T, int32); \ - DECLARE_GPU_SPECS_INDEX(T, int64) - -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); - -#undef DECLARE_GPU_SPECS -#undef DECLARE_GPU_SPECS_INDEX -} // namespace functor - -// Registration of the GPU implementations. -#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU); + GatherOp<type, index_type>) -#undef REGISTER_GATHER_GPU +#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32) +#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64) -#endif // GOOGLE_CUDA +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32); +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64); -#undef REGISTER_GATHER_ALL_INDICES -#undef REGISTER_GATHER_FULL +#undef REGISTER_GATHER_INT32 +#undef REGISTER_GATHER_INT64 +#undef REGISTER_GATHER } // namespace tensorflow diff --git a/tensorflow/core/kernels/gather_op.h b/tensorflow/core/kernels/gather_op.h deleted file mode 100644 index 4b82815f54..0000000000 --- a/tensorflow/core/kernels/gather_op.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_KERNELS_GATHER_OP_H_ -#define TENSORFLOW_KERNELS_GATHER_OP_H_ -// Functor definition for GatherOp, must be compilable by nvcc. - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor_types.h" - -namespace tensorflow { - -class OpKernelContext; - -namespace functor { -template <typename Device, typename T, typename Index> -struct Gather { - // Performs gather op on (Tparams, Tindices), writing to Tout. - // Returns an index to Tindices if the value at that index is out of range. - // Returns -1 if all values of Tindices are in range. - Index operator()(const Device& d, typename TTypes<T>::ConstMatrix Tparams, - typename TTypes<Index>::ConstFlat Tindices, - typename TTypes<T>::Matrix Tout); -}; - -} // namespace functor -} // namespace tensorflow - -#endif // TENSORFLOW_KERNELS_GATHER_OP_H_ diff --git a/tensorflow/core/kernels/gather_op_gpu.cu.cc b/tensorflow/core/kernels/gather_op_gpu.cu.cc deleted file mode 100644 index 5bc54a90f9..0000000000 --- a/tensorflow/core/kernels/gather_op_gpu.cu.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/core/kernels/gather_op.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" - -namespace tensorflow { - -typedef Eigen::GpuDevice GPUDevice; - -template <typename T, typename Index> -__global__ void GatherOpKernel(const T* params, const Index* indices, T* out, - int64 first_dim_size, int64 indices_size, - int64 out_size) { - const int32 slice_size = out_size / indices_size; - CUDA_1D_KERNEL_LOOP(i, out_size) { - Index indices_i = i / slice_size; - Index indices_slice_i = i - indices_i * slice_size; - Index params_first_index = ldg(indices + indices_i); - if (!(params_first_index >= 0 && params_first_index < first_dim_size)) { - // Ignore indices that are out of range. - continue; - } - Index params_i = params_first_index * slice_size + indices_slice_i; - out[i] = ldg(params + params_i); - } -} - -namespace functor { -template <typename T, typename Index> -struct Gather<GPUDevice, T, Index> { - Index operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, - typename TTypes<Index>::ConstFlat Tindices, - typename TTypes<T>::Matrix Tout) { - const int64 first_dim_size = Tparams.dimension(0); - const int64 indices_size = Tindices.size(); - const int64 out_size = Tout.size(); - CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); - // clang-format off - GatherOpKernel<T, Index> - <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - Tparams.data(), Tindices.data(), Tout.data(), first_dim_size, - indices_size, out_size); - // clang-format on - // TODO(fpmc): enable indices validation on GPU. - // Right now checking for indicies out of bound in the kernel would - // require copying code between GPU/CPU, and thus slow. - return -1; - } -}; - -} // namespace functor - -#define DEFINE_GPU_SPECS_INDEX(T, Index) \ - template struct functor::Gather<GPUDevice, T, Index> - -#define DEFINE_GPU_SPECS(T) \ - DEFINE_GPU_SPECS_INDEX(T, int32); \ - DEFINE_GPU_SPECS_INDEX(T, int64); - -TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); - -#undef DEFINE_GPU_SPECS -#undef DEFINE_GPU_SPECS_INDEX - -} // namespace tensorflow - -#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index b183510445..af8bfc432f 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include <memory> #include <vector> -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/graph.pb.h" @@ -26,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -120,52 +118,110 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) { << s; } -template <typename Index> -static Graph* Gather(int lookups, int dim) { - Graph* g = new Graph(OpRegistry::Global()); - // Always use a 512MB buffer. - const int kRows = ((512 << 20) / sizeof(float)) / dim; - Tensor params(DT_FLOAT, TensorShape({kRows, dim})); - params.flat<float>().setRandom(); +class GatherOpForBenchmark : public GatherOpTest { + public: + void TestBody() override { // not used } + } + void PublicMakeOp(DataType index_type) { MakeOp(index_type); } +}; + +static const int kSorted = 0x8000; // Mask for arg to specify sorting vs. not +template <typename Index> +void BM_Gather(int iters, int arg) { + testing::StopTiming(); + + bool sorted = ((arg & kSorted) != 0); + int dim = arg & ~kSorted; + + GatherOpForBenchmark t; + t.PublicMakeOp(DataTypeToEnum<Index>::v()); + // Use a 512 MB table, regardless of dim + const int kRows = ((1 << 29) / sizeof(float)) / dim; + std::vector<float> data(kRows * dim, 1.0f); + t.AddInputFromArray<float>(TensorShape({kRows, dim}), data); + const int kLookups = 2000; + const int kBatches = 1000000 / kLookups; random::PhiloxRandom philox(301, 17); random::SimplePhilox rnd(&philox); - std::vector<Index> indices_vec; - for (int i = 0; i < lookups; i++) { - indices_vec.push_back(rnd.Uniform(kRows)); - } - Tensor indices(DataTypeToEnum<Index>::value, TensorShape({lookups})); - for (int i = 0; i < indices_vec.size(); i++) { - indices.flat<Index>()(i) = indices_vec[i]; + std::vector<std::vector<Index>> all_ids(kBatches); + for (int i = 0; i < kBatches; ++i) { + std::vector<Index>* ids = &all_ids[i]; + ids->resize(kLookups); + for (int j = 0; j < kLookups; ++j) { + (*ids)[j] = rnd.Uniform(kRows); + } + if (sorted) { + sort(ids->begin(), ids->end()); + } } - test::graph::Gather(g, test::graph::Constant(g, params), - test::graph::Constant(g, indices)); - return g; + t.AddInput<Index>(TensorShape({kLookups}), [](int i) { return 0; }); + if (sorted) { + testing::SetLabel("sorted by id"); + } + testing::BytesProcessed(static_cast<int64>(iters) * kLookups * dim * + sizeof(float)); + testing::StartTiming(); + while (--iters > 0) { + const std::vector<Index>& b = all_ids[iters % kBatches]; + TensorValue input = t.mutable_input(1); + gtl::MutableArraySlice<Index> slice(&input->vec<Index>()(0), + input->NumElements()); + for (int i = 0; i < kLookups; i++) { + slice[i] = b[i]; + } + Status s = t.RunOpKernel(); + } } -#define BM_GATHER(DEVICE, INDEX) \ - static void BM_##DEVICE##_gather_##INDEX(int iters, int lookups, int dim) { \ - const int64 tot = static_cast<int64>(iters) * lookups * dim; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - testing::UseRealTime(); \ - test::Benchmark(#DEVICE, Gather<INDEX>(lookups, dim)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_gather_##INDEX) \ - ->ArgPair(2000, 1) \ - ->ArgPair(2000, 10) \ - ->ArgPair(2000, 20) \ - ->ArgPair(2000, 100) \ - ->ArgPair(200, 1000) \ - ->ArgPair(20, 10000) \ - ->ArgPair(20000, 10) \ - ->ArgPair(2000, 1) - -BM_GATHER(cpu, int32); -BM_GATHER(gpu, int32); -BM_GATHER(cpu, int64); -BM_GATHER(gpu, int64); +static void BM_Gather32(int iters, int arg) { BM_Gather<int32>(iters, arg); } + +static void BM_Gather64(int iters, int arg) { BM_Gather<int64>(iters, arg); } + +BENCHMARK(BM_Gather32) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); + +BENCHMARK(BM_Gather64) + ->Arg(10) + ->Arg(10 | kSorted) + ->Arg(20) + ->Arg(40) + ->Arg(63) + ->Arg(63 | kSorted) + ->Arg(64) + ->Arg(64 | kSorted) + ->Arg(65) + ->Arg(65 | kSorted) + ->Arg(100) + ->Arg(100 | kSorted) + ->Arg(127) + ->Arg(127 | kSorted) + ->Arg(128) + ->Arg(128 | kSorted) + ->Arg(129) + ->Arg(129 | kSorted) + ->Arg(1000) + ->Arg(1000 | kSorted); } // namespace } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index 33a9554711..5292a5bbad 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -23,10 +23,9 @@ import tensorflow as tf class GatherTest(tf.test.TestCase): - use_gpu = False def testScalar1D(self): - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): params = tf.constant([0, 1, 2, 3, 7, 5]) indices = tf.constant(4) gather_t = tf.gather(params, indices) @@ -35,7 +34,7 @@ class GatherTest(tf.test.TestCase): self.assertEqual([], gather_t.get_shape()) def testScalar2D(self): - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) indices = tf.constant(2) @@ -45,7 +44,7 @@ class GatherTest(tf.test.TestCase): self.assertEqual([3], gather_t.get_shape()) def testSimpleTwoD32(self): - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) indices = tf.constant([0, 4, 0, 2]) @@ -60,7 +59,7 @@ class GatherTest(tf.test.TestCase): shape = (4, 3, 2) params = np.random.randn(*shape) indices = np.random.randint(shape[0], size=15).reshape(3, 5) - with self.test_session(use_gpu=self.use_gpu): + with self.test_session(): tf_params = tf.constant(params) tf_indices = tf.constant(indices) gather = tf.gather(tf_params, tf_indices) @@ -85,7 +84,7 @@ class GatherTest(tf.test.TestCase): self.assertEqual(None, gather_t.get_shape()) def testBadIndices(self): - with self.test_session(use_gpu=False): + with self.test_session(): params = [0, 1, 2] indices = [[7]] gather = tf.gather(params, indices) @@ -93,9 +92,5 @@ class GatherTest(tf.test.TestCase): gather.eval() -class GatherGpuTest(GatherTest): - use_gpu = True - - if __name__ == "__main__": tf.test.main() |