aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/gather_op.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-03-01 15:49:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-01 16:35:38 -0800
commit3f8a7616452346ef4de5a0da7b79fdb1cc8d6594 (patch)
tree8be48d0b9dacc43b8bdfb8cb22936336f90a8eba /tensorflow/core/kernels/gather_op.cc
parent40ccc0aa632ebbd3b34f2f5da5f1ed576d50f57a (diff)
Rollback of "Adds GPU kernel for gather ops."
Change: 116064181
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r--tensorflow/core/kernels/gather_op.cc183
1 files changed, 66 insertions, 117 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index d4d8d86ef6..d7a4e20fbd 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#include "tensorflow/core/kernels/gather_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -27,10 +26,49 @@ limitations under the License.
namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
+namespace {
+// Returns -1 on success or a nonnegative i s.t., indices[i] is bad.
+template <typename T, typename Index, int static_slice_elems>
+Index HandleCopies(const Tensor& params,
+ typename TTypes<Index>::ConstVec indices, Index slice_elems,
+ typename TTypes<T>::Matrix out) {
+ const int N = indices.dimension(0);
+ const auto& params_flat = params.flat_outer_dims<T>();
+ const Index limit = params.dim_size(0);
+ T* out_base = &out(0, 0);
+ const T* params_base = &params_flat(0, 0);
+ if (static_slice_elems >= 0) {
+ // Give compiler static knowledge of the number of elements/bytes
+ CHECK_EQ(static_slice_elems, slice_elems);
+ slice_elems = static_slice_elems;
+ }
+ // Compute slice_bytes here so that static knowledge is available
+ const size_t slice_bytes = slice_elems * sizeof(T);
+ for (int i = 0; i < N; i++) {
+ const int j = i + 1;
+ if (j < N) {
+ port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
+ }
+ // Grab the index and check its validity. An earlier version of the
+ // code checked it and then grabbed it from memory a second time, which
+ // was a security risk since it could have changed in between.
+ const Index index = indices(i);
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Copy using memcpy if possible, otherwise an Eigen loop
+ if (Allocator::is_simple<T>::value) {
+ memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
+ slice_bytes);
+ } else {
+ out.template chip<0>(i) = params_flat.template chip<0>(index);
+ }
+ }
+ return -1;
+}
+
+} // anonymous namespace
-template <typename Device, typename T, typename Index>
+template <typename T, typename Index>
class GatherOp : public OpKernel {
public:
// QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
@@ -78,13 +116,23 @@ class GatherOp : public OpKernel {
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0) {
- auto params_flat = params.flat_outer_dims<T>();
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
+ const int64 slice_size = out->NumElements() / N;
+ Index bad_i;
+
+#define CALL(elems) \
+ bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
+ out_flat)
- functor::Gather<Device, T, Index> functor;
- Index bad_i = functor(c->eigen_device<Device>(), params_flat,
- indices_flat, out_flat);
+ if (slice_size == 10)
+ CALL(10);
+ else if (slice_size == 20)
+ CALL(20);
+ else
+ CALL(-1);
+
+#undef CALL
OP_REQUIRES(
c, bad_i < 0,
@@ -95,120 +143,21 @@ class GatherOp : public OpKernel {
}
};
-namespace functor {
-
-// Helper method to copy using memcpy.
-template <typename T, typename Index, int static_slice_elems>
-Index HandleCopies(typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices, Index slice_elems,
- typename TTypes<T>::Matrix out) {
- const int first_dim_size = indices.dimension(0);
- const Index limit = params.dimension(0);
- T* out_base = &out(0, 0);
- const T* params_base = &params(0, 0);
- if (static_slice_elems >= 0) {
- // Give compiler static knowledge of the number of elements/bytes
- CHECK_EQ(static_slice_elems, slice_elems);
- slice_elems = static_slice_elems;
- }
- // Compute slice_bytes here so that static knowledge is available
- const size_t slice_bytes = slice_elems * sizeof(T);
- for (int i = 0; i < first_dim_size; i++) {
- const int j = i + 1;
- if (j < first_dim_size) {
- port::prefetch<port::PREFETCH_HINT_T0>(&params(indices(j), 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
- }
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
- const Index index = indices(i);
- if (!FastBoundsCheck(index, limit)) return i;
- // Copy using memcpy if possible, otherwise an Eigen loop
- if (Allocator::is_simple<T>::value) {
- memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
- slice_bytes);
- } else {
- out.template chip<0>(i) = params.template chip<0>(index);
- }
- }
- return -1;
-}
-
-// Specialization gather functor for CPU.
-template <typename T, typename Index>
-struct Gather<CPUDevice, T, Index> {
- Index operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices,
- typename TTypes<T>::Matrix out) {
- const int N = indices.size();
- const int64 slice_size = out.size() / N;
- Index bad_i;
-
-#define CALL(elems) \
- bad_i = HandleCopies<T, Index, elems>(params, indices, slice_size, out)
-
- if (slice_size == 10)
- CALL(10);
- else if (slice_size == 20)
- CALL(20);
- else
- CALL(-1);
-#undef CALL
-
- return bad_i;
- }
-};
-} // namespace functor
-
-#define REGISTER_GATHER_FULL(dev, type, index_type) \
+#define REGISTER_GATHER(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("Gather") \
- .Device(DEVICE_##dev) \
+ .Device(DEVICE_CPU) \
.TypeConstraint<type>("Tparams") \
.TypeConstraint<index_type>("Tindices"), \
- GatherOp<dev##Device, type, index_type>)
-
-#define REGISTER_GATHER_ALL_INDICES(dev, type) \
- REGISTER_GATHER_FULL(dev, type, int32); \
- REGISTER_GATHER_FULL(dev, type, int64)
-
-#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
-
-TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
-
-#undef REGISTER_GATHER_CPU
-
-#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPECS_INDEX(T, Index) \
- template <> \
- Index Gather<GPUDevice, T, Index>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \
- typename TTypes<Index>::ConstFlat Tindices, \
- typename TTypes<T>::Matrix Tout); \
- extern template struct Gather<GPUDevice, T, Index>
-
-#define DECLARE_GPU_SPECS(T) \
- DECLARE_GPU_SPECS_INDEX(T, int32); \
- DECLARE_GPU_SPECS_INDEX(T, int64)
-
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-
-#undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_SPECS_INDEX
-} // namespace functor
-
-// Registration of the GPU implementations.
-#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
-
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
+ GatherOp<type, index_type>)
-#undef REGISTER_GATHER_GPU
+#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32)
+#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64)
-#endif // GOOGLE_CUDA
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32);
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64);
-#undef REGISTER_GATHER_ALL_INDICES
-#undef REGISTER_GATHER_FULL
+#undef REGISTER_GATHER_INT32
+#undef REGISTER_GATHER_INT64
+#undef REGISTER_GATHER
} // namespace tensorflow