about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/gather_op.cc
diff options
context:
space:
mode:
author: Jack Rae <jwrae@google.com> 2016-09-30 09:30:22 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-09-30 10:33:08 -0700
commit: 4b8be072eaf0c66298d0f9b1657bc166399d1108 (patch)
tree: 6c8cc559739e654a286332d3616afe4a3aaa196f /tensorflow/core/kernels/gather_op.cc
parent: 7d9c0c891d82fb5d35dc4669abe832708940a810 (diff)
Refactor Gather and Scatter functors into separate targets.
Change: 134798900
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r-- tensorflow/core/kernels/gather_op.cc 35
1 files changed, 2 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index bde0054ac5..d8182218af 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#include "tensorflow/core/kernels/gather_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -77,7 +76,7 @@ class GatherOp : public OpKernel {
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
- functor::Gather<Device, T, Index> functor;
+ functor::GatherFunctor<Device, T, Index> functor;
int64 bad_i = functor(c->eigen_device<Device>(), params_flat,
indices_flat, out_flat);
@@ -90,18 +89,6 @@ class GatherOp : public OpKernel {
}
};
-namespace functor {
-
-template <typename T, typename Index>
-struct Gather<CPUDevice, T, Index> {
- int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices,
- typename TTypes<T>::Matrix out) {
- return GatherFunctor<CPUDevice, T, Index>()(d, params, indices, out);
- }
-};
-} // namespace functor
-
#define REGISTER_GATHER_FULL(dev, type, index_type) \
REGISTER_KERNEL_BUILDER(Name("Gather") \
.Device(DEVICE_##dev) \
@@ -115,31 +102,13 @@ struct Gather<CPUDevice, T, Index> {
#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
+// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU
#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPECS_INDEX(T, Index) \
- template <> \
- Index Gather<GPUDevice, T, Index>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \
- typename TTypes<Index>::ConstFlat Tindices, \
- typename TTypes<T>::Matrix Tout); \
- extern template struct Gather<GPUDevice, T, Index>;
-
-#define DECLARE_GPU_SPECS(T) \
- DECLARE_GPU_SPECS_INDEX(T, int32); \
- DECLARE_GPU_SPECS_INDEX(T, int64)
-
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
-
-#undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_SPECS_INDEX
-} // namespace functor
// Registration of the GPU implementations.
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)