diff options
author | 2016-09-30 09:30:22 -0800 | |
---|---|---|
committer | 2016-09-30 10:33:08 -0700 | |
commit | 4b8be072eaf0c66298d0f9b1657bc166399d1108 (patch) | |
tree | 6c8cc559739e654a286332d3616afe4a3aaa196f /tensorflow/core/kernels/gather_op.cc | |
parent | 7d9c0c891d82fb5d35dc4669abe832708940a810 (diff) |
Refactor Gather and Scatter functors into separate targets.
Change: 134798900
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r-- | tensorflow/core/kernels/gather_op.cc | 35 |
1 file changed, 2 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index bde0054ac5..d8182218af 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -15,7 +15,6 @@ limitations under the License. // See docs in ../ops/array_ops.cc. -#include "tensorflow/core/kernels/gather_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -77,7 +76,7 @@ class GatherOp : public OpKernel { auto indices_flat = indices.flat<Index>(); auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N}); - functor::Gather<Device, T, Index> functor; + functor::GatherFunctor<Device, T, Index> functor; int64 bad_i = functor(c->eigen_device<Device>(), params_flat, indices_flat, out_flat); @@ -90,18 +89,6 @@ class GatherOp : public OpKernel { } }; -namespace functor { - -template <typename T, typename Index> -struct Gather<CPUDevice, T, Index> { - int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out) { - return GatherFunctor<CPUDevice, T, Index>()(d, params, indices, out); - } -}; -} // namespace functor - #define REGISTER_GATHER_FULL(dev, type, index_type) \ REGISTER_KERNEL_BUILDER(Name("Gather") \ .Device(DEVICE_##dev) \ @@ -115,31 +102,13 @@ struct Gather<CPUDevice, T, Index> { #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type) +// Registration of the CPU implementations. TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); #undef REGISTER_GATHER_CPU #if GOOGLE_CUDA -// Forward declarations of the functor specializations for GPU. 
-namespace functor { -#define DECLARE_GPU_SPECS_INDEX(T, Index) \ - template <> \ - Index Gather<GPUDevice, T, Index>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstMatrix Tparams, \ - typename TTypes<Index>::ConstFlat Tindices, \ - typename TTypes<T>::Matrix Tout); \ - extern template struct Gather<GPUDevice, T, Index>; - -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPECS_INDEX(T, int32); \ - DECLARE_GPU_SPECS_INDEX(T, int64) - -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); - -#undef DECLARE_GPU_SPECS -#undef DECLARE_GPU_SPECS_INDEX -} // namespace functor // Registration of the GPU implementations. #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) |