diff options
author | 2018-03-23 16:00:14 -0700 | |
---|---|---|
committer | 2018-03-25 04:21:07 -0700 | |
commit | 084c10784887d7c4d467416430626cf7eb333cb8 (patch) | |
tree | 69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_op.cc | |
parent | 95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff) |
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_op.cc | 126 |
1 files changed, 87 insertions, 39 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 282165349f..0fbde764d5 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice; // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, const Tensor& indices) { + if (updates.dims() == 0) return true; if (updates.dims() != indices.dims() + params.dims() - 1) return false; for (int d = 0; d < indices.dims(); d++) { if (updates.dim_size(d) != indices.dim_size(d)) { @@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params, params.shape().DebugString())); OP_REQUIRES( c, ValidShapes(params, updates, indices), - errors::InvalidArgument( - "Must have updates.shape = indices.shape + params.shape[1:], got ", - "updates.shape ", updates.shape().DebugString(), ", indices.shape ", - indices.shape().DebugString(), ", params.shape ", - params.shape().DebugString())); + errors::InvalidArgument("Must have updates.shape = indices.shape + " + "params.shape[1:] or updates.shape = [], got ", + "updates.shape ", updates.shape().DebugString(), + ", indices.shape ", indices.shape().DebugString(), + ", params.shape ", params.shape().DebugString())); } template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> @@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat<Index>(); auto params_flat = params.flat_outer_dims<T>(); - auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); - - functor::ScatterFunctor<Device, T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<Device>(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape()) || + IsLegacyScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + functor::ScatterScalarFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped<T, 2>({N, updates.NumElements() / N}); + + functor::ScatterFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel { auto indices_flat = indices_host.flat<Index>(); auto params_flat = params.flat_outer_dims<T>(); - auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); - - functor::ScatterFunctorSYCL<T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + + functor::ScatterScalarFunctorSYCL<T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped<T, 2>({N, updates.NumElements() / N}); + + functor::ScatterFunctorSYCL<T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX); + #define REGISTER_SCATTER_UPDATE(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); // Registers CPU kernels. -#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); + +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); + +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA // Registers GPU kernels. #if TENSORFLOW_USE_SYCL -#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, SYCL); +#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \ + REGISTER_SCATTER_ARITHMETIC(type, SYCL); + +#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL); #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL); -#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_ARITHMETIC_SYCL +#undef REGISTER_SCATTER_MINMAX_SYCL #undef REGISTER_SCATTER_UPDATE_SYCL #endif // TENSORFLOW_USE_SYCL -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU -#undef REGISTER_SCATTER_ARITHEMTIC_GPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC_GPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU +#undef REGISTER_SCATTER_MINMAX_GPU #undef REGISTER_SCATTER_UPDATE #undef REGISTER_SCATTER_UPDATE_CPU #undef REGISTER_SCATTER_UPDATE_GPU |