diff options
author | 2017-09-26 08:57:33 -0700 | |
---|---|---|
committer | 2017-09-26 09:01:21 -0700 | |
commit | f5ceb90e7f08fbe7605a002a546b22ef893f248c (patch) | |
tree | 6b3c73e78d6119468b7e9984ab98028299c1028b /tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc | |
parent | 36649e842908d89a3dc44a840bd6305fe401123f (diff) |
TF: GatherNd and ScatterNd updates.
* Factor out GatherNd and ScatterNd functionality into reusable functors.
* Add complex64 and complex128 GatherNd and ScatterNd support.
* Add CudaAtomicAdd for complex64 and complex128.
PiperOrigin-RevId: 170059406
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc | 72 |
1 file changed, 43 insertions, 29 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index dbd6791bd2..0eb3cf32dd 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -17,6 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/scatter_nd_op.h" #include "tensorflow/core/platform/types.h" @@ -26,18 +27,44 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; +namespace { + +template <typename T, scatter_nd_op::UpdateOp Op> +struct LeftUpdate { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val); +}; + +template <typename T> +struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) { + *out = val; + } +}; + +template <typename T> +struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) { + CudaAtomicAdd(out, val); + } +}; + +template <typename T> +struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) { + CudaAtomicSub(out, val); + } +}; + +} // namespace + template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM> __global__ void ScatterNdOpKernel( const Index* indices, const T* updates, T* out, const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, const Eigen::array<int64, IXDIM> batch_strides, const int64 num_indices, const Index slice_size) { -#define ASSIGN(dst, src) (*(dst) = src) + auto update = LeftUpdate<T, op>(); -#define OP_OVER_SLICE(op) \ - for (int si = 0; si < slice_size; si++) { \ - op(out + i + si, ldg(updates + (index * slice_size + si))); \ - } CUDA_1D_KERNEL_LOOP(index, num_indices) { Index i = 0; bool out_of_bounds = 
false; @@ -49,32 +76,12 @@ __global__ void ScatterNdOpKernel( i += ix_d * batch_strides[dim] * slice_size; } if (!out_of_bounds) { - switch (op) { - case scatter_nd_op::UpdateOp::ASSIGN: -#pragma unroll - OP_OVER_SLICE(ASSIGN); - break; - case scatter_nd_op::UpdateOp::ADD: #pragma unroll - OP_OVER_SLICE(CudaAtomicAdd); - break; - case scatter_nd_op::UpdateOp::SUB: -#pragma unroll - OP_OVER_SLICE(CudaAtomicSub); - break; - case scatter_nd_op::UpdateOp::MUL: -#pragma unroll - OP_OVER_SLICE(CudaAtomicMul); - break; - case scatter_nd_op::UpdateOp::DIV: -#pragma unroll - OP_OVER_SLICE(CudaAtomicDiv); - break; + for (int si = 0; si < slice_size; si++) { + update(out + i + si, ldg(updates + (index * slice_size + si))); } } } -#undef OP_OVER_SLICE -#undef ASSIGN } namespace functor { @@ -89,6 +96,11 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> { typename TTypes<Index, 2>::ConstTensor Tindices, typename TTypes<T, 2>::ConstTensor Tupdates, typename TTypes<T, 2>::Tensor Toutput) { + // TODO(ebrevdo): The performance of this for small indices (large + // slices) is poor. Write a kernel whose splitting is + // independent of the slice size. Same for CPU. See the + // gather_nd kernel for an example. 
+ const Eigen::DenseIndex batch_size = Tindices.dimension(0); // Index batch_strides[IXDIM]; @@ -124,7 +136,7 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> { DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \ - DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5) + DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); #define DECLARE_GPU_SPECS_INDEX(T, Index) \ DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \ @@ -135,7 +147,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> { DECLARE_GPU_SPECS_INDEX(T, int32); \ DECLARE_GPU_SPECS_INDEX(T, int64) -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_complex64(DECLARE_GPU_SPECS); +TF_CALL_complex128(DECLARE_GPU_SPECS); #undef DECLARE_GPU_SPECS #undef DECLARE_GPU_SPECS_INDEX |