aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-09-26 08:57:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 09:01:21 -0700
commitf5ceb90e7f08fbe7605a002a546b22ef893f248c (patch)
tree6b3c73e78d6119468b7e9984ab98028299c1028b /tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
parent36649e842908d89a3dc44a840bd6305fe401123f (diff)
TF: GatherNd and ScatterNd updates.
* Factor out GatherNd and ScatterNd functionality into reusable functors. * Add complex64 and complex128 GatherNd and ScatterNd support. * Add CudaAtomicAdd for complex64 and complex128. PiperOrigin-RevId: 170059406
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc72
1 files changed, 43 insertions, 29 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
index dbd6791bd2..0eb3cf32dd 100644
--- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/types.h"
@@ -26,18 +27,44 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+namespace {
+
+template <typename T, scatter_nd_op::UpdateOp Op>
+struct LeftUpdate {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val);
+};
+
+template <typename T>
+struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
+ *out = val;
+ }
+};
+
+template <typename T>
+struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
+ CudaAtomicAdd(out, val);
+ }
+};
+
+template <typename T>
+struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
+ CudaAtomicSub(out, val);
+ }
+};
+
+} // namespace
+
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
__global__ void ScatterNdOpKernel(
const Index* indices, const T* updates, T* out,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
const Eigen::array<int64, IXDIM> batch_strides, const int64 num_indices,
const Index slice_size) {
-#define ASSIGN(dst, src) (*(dst) = src)
+ auto update = LeftUpdate<T, op>();
-#define OP_OVER_SLICE(op) \
- for (int si = 0; si < slice_size; si++) { \
- op(out + i + si, ldg(updates + (index * slice_size + si))); \
- }
CUDA_1D_KERNEL_LOOP(index, num_indices) {
Index i = 0;
bool out_of_bounds = false;
@@ -49,32 +76,12 @@ __global__ void ScatterNdOpKernel(
i += ix_d * batch_strides[dim] * slice_size;
}
if (!out_of_bounds) {
- switch (op) {
- case scatter_nd_op::UpdateOp::ASSIGN:
-#pragma unroll
- OP_OVER_SLICE(ASSIGN);
- break;
- case scatter_nd_op::UpdateOp::ADD:
#pragma unroll
- OP_OVER_SLICE(CudaAtomicAdd);
- break;
- case scatter_nd_op::UpdateOp::SUB:
-#pragma unroll
- OP_OVER_SLICE(CudaAtomicSub);
- break;
- case scatter_nd_op::UpdateOp::MUL:
-#pragma unroll
- OP_OVER_SLICE(CudaAtomicMul);
- break;
- case scatter_nd_op::UpdateOp::DIV:
-#pragma unroll
- OP_OVER_SLICE(CudaAtomicDiv);
- break;
+ for (int si = 0; si < slice_size; si++) {
+ update(out + i + si, ldg(updates + (index * slice_size + si)));
}
}
}
-#undef OP_OVER_SLICE
-#undef ASSIGN
}
namespace functor {
@@ -89,6 +96,11 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, 2>::Tensor Toutput) {
+ // TODO(ebrevdo): The performance of this for small indices (large
+ // slices) is poor. Write a kernel whose splitting is
+ // independent of the slice size. Same for CPU. See the
+ // gather_nd kernel for an example.
+
const Eigen::DenseIndex batch_size = Tindices.dimension(0);
// Index batch_strides[IXDIM];
@@ -124,7 +136,7 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
- DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5)
+ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5);
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
@@ -135,7 +147,9 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64)
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_complex64(DECLARE_GPU_SPECS);
+TF_CALL_complex128(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX