diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-03-01 15:49:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-01 16:35:38 -0800 |
commit | 3f8a7616452346ef4de5a0da7b79fdb1cc8d6594 (patch) | |
tree | 8be48d0b9dacc43b8bdfb8cb22936336f90a8eba /tensorflow/core/kernels/gather_op.cc | |
parent | 40ccc0aa632ebbd3b34f2f5da5f1ed576d50f57a (diff) |
Rollback of "Adds GPU kernel for gather ops."
Change: 116064181
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r-- | tensorflow/core/kernels/gather_op.cc | 183 |
1 files changed, 66 insertions, 117 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index d4d8d86ef6..d7a4e20fbd 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -15,7 +15,6 @@ limitations under the License. // See docs in ../ops/array_ops.cc. -#include "tensorflow/core/kernels/gather_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -27,10 +26,49 @@ limitations under the License. namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; +namespace { +// Returns -1 on success or a nonnegative i s.t., indices[i] is bad. +template <typename T, typename Index, int static_slice_elems> +Index HandleCopies(const Tensor& params, + typename TTypes<Index>::ConstVec indices, Index slice_elems, + typename TTypes<T>::Matrix out) { + const int N = indices.dimension(0); + const auto& params_flat = params.flat_outer_dims<T>(); + const Index limit = params.dim_size(0); + T* out_base = &out(0, 0); + const T* params_base = ¶ms_flat(0, 0); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + CHECK_EQ(static_slice_elems, slice_elems); + slice_elems = static_slice_elems; + } + // Compute slice_bytes here so that static knowledge is available + const size_t slice_bytes = slice_elems * sizeof(T); + for (int i = 0; i < N; i++) { + const int j = i + 1; + if (j < N) { + port::prefetch<port::PREFETCH_HINT_T0>(¶ms_flat(indices(j), 0)); + port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); + } + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = indices(i); + if (!FastBoundsCheck(index, limit)) return i; + // Copy using memcpy if possible, otherwise an Eigen loop + if (Allocator::is_simple<T>::value) { + memcpy(out_base + i * slice_elems, params_base + index * slice_elems, + slice_bytes); + } else { + out.template chip<0>(i) = params_flat.template chip<0>(index); + } + } + return -1; +} + +} // anonymous namespace -template <typename Device, typename T, typename Index> +template <typename T, typename Index> class GatherOp : public OpKernel { public: // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, @@ -78,13 +116,23 @@ class GatherOp : public OpKernel { Tensor* out = nullptr; OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); if (N > 0) { - auto params_flat = params.flat_outer_dims<T>(); auto indices_flat = indices.flat<Index>(); auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N}); + const int64 slice_size = out->NumElements() / N; + Index bad_i; + +#define CALL(elems) \ + bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \ + out_flat) - functor::Gather<Device, T, Index> functor; - Index bad_i = functor(c->eigen_device<Device>(), params_flat, - indices_flat, out_flat); + if (slice_size == 10) + CALL(10); + else if (slice_size == 20) + CALL(20); + else + CALL(-1); + +#undef CALL OP_REQUIRES( c, bad_i < 0, @@ -95,120 +143,21 @@ class GatherOp : public OpKernel { } }; -namespace functor { - -// Helper method to copy using memcpy. -template <typename T, typename Index, int static_slice_elems> -Index HandleCopies(typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, Index slice_elems, - typename TTypes<T>::Matrix out) { - const int first_dim_size = indices.dimension(0); - const Index limit = params.dimension(0); - T* out_base = &out(0, 0); - const T* params_base = ¶ms(0, 0); - if (static_slice_elems >= 0) { - // Give compiler static knowledge of the number of elements/bytes - CHECK_EQ(static_slice_elems, slice_elems); - slice_elems = static_slice_elems; - } - // Compute slice_bytes here so that static knowledge is available - const size_t slice_bytes = slice_elems * sizeof(T); - for (int i = 0; i < first_dim_size; i++) { - const int j = i + 1; - if (j < first_dim_size) { - port::prefetch<port::PREFETCH_HINT_T0>(¶ms(indices(j), 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); - } - // Grab the index and check its validity. An earlier version of the - // code checked it and then grabbed it from memory a second time, which - // was a security risk since it could have changed in between. - const Index index = indices(i); - if (!FastBoundsCheck(index, limit)) return i; - // Copy using memcpy if possible, otherwise an Eigen loop - if (Allocator::is_simple<T>::value) { - memcpy(out_base + i * slice_elems, params_base + index * slice_elems, - slice_bytes); - } else { - out.template chip<0>(i) = params.template chip<0>(index); - } - } - return -1; -} - -// Specialization gather functor for CPU. -template <typename T, typename Index> -struct Gather<CPUDevice, T, Index> { - Index operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out) { - const int N = indices.size(); - const int64 slice_size = out.size() / N; - Index bad_i; - -#define CALL(elems) \ - bad_i = HandleCopies<T, Index, elems>(params, indices, slice_size, out) - - if (slice_size == 10) - CALL(10); - else if (slice_size == 20) - CALL(20); - else - CALL(-1); -#undef CALL - - return bad_i; - } -}; -} // namespace functor - -#define REGISTER_GATHER_FULL(dev, type, index_type) \ +#define REGISTER_GATHER(type, index_type) \ REGISTER_KERNEL_BUILDER(Name("Gather") \ - .Device(DEVICE_##dev) \ + .Device(DEVICE_CPU) \ .TypeConstraint<type>("Tparams") \ .TypeConstraint<index_type>("Tindices"), \ - GatherOp<dev##Device, type, index_type>) - -#define REGISTER_GATHER_ALL_INDICES(dev, type) \ - REGISTER_GATHER_FULL(dev, type, int32); \ - REGISTER_GATHER_FULL(dev, type, int64) - -#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type) - -TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); - -#undef REGISTER_GATHER_CPU - -#if GOOGLE_CUDA -// Forward declarations of the functor specializations for GPU. -namespace functor { -#define DECLARE_GPU_SPECS_INDEX(T, Index) \ - template <> \ - Index Gather<GPUDevice, T, Index>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \ - typename TTypes<Index>::ConstFlat Tindices, \ - typename TTypes<T>::Matrix Tout); \ - extern template struct Gather<GPUDevice, T, Index> - -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPECS_INDEX(T, int32); \ - DECLARE_GPU_SPECS_INDEX(T, int64) - -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); - -#undef DECLARE_GPU_SPECS -#undef DECLARE_GPU_SPECS_INDEX -} // namespace functor - -// Registration of the GPU implementations. -#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU); + GatherOp<type, index_type>) -#undef REGISTER_GATHER_GPU +#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32) +#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64) -#endif // GOOGLE_CUDA +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32); +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64); -#undef REGISTER_GATHER_ALL_INDICES -#undef REGISTER_GATHER_FULL +#undef REGISTER_GATHER_INT32 +#undef REGISTER_GATHER_INT64 +#undef REGISTER_GATHER } // namespace tensorflow |