about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/scatter_op.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-03-23 16:00:14 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-03-25 04:21:07 -0700
commit 084c10784887d7c4d467416430626cf7eb333cb8 (patch)
tree 69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_op.cc
parent 95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff)
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r-- tensorflow/core/kernels/scatter_op.cc 126
1 files changed, 87 insertions, 39 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 282165349f..0fbde764d5 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice;
// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
const Tensor& indices) {
+ if (updates.dims() == 0) return true;
if (updates.dims() != indices.dims() + params.dims() - 1) return false;
for (int d = 0; d < indices.dims(); d++) {
if (updates.dim_size(d) != indices.dim_size(d)) {
@@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
params.shape().DebugString()));
OP_REQUIRES(
c, ValidShapes(params, updates, indices),
- errors::InvalidArgument(
- "Must have updates.shape = indices.shape + params.shape[1:], got ",
- "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
- indices.shape().DebugString(), ", params.shape ",
- params.shape().DebugString()));
+ errors::InvalidArgument("Must have updates.shape = indices.shape + "
+ "params.shape[1:] or updates.shape = [], got ",
+ "updates.shape ", updates.shape().DebugString(),
+ ", indices.shape ", indices.shape().DebugString(),
+ ", params.shape ", params.shape().DebugString()));
}
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
@@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape()) ||
+ IsLegacyScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
auto indices_flat = indices_host.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
- auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
- functor::ScatterFunctorSYCL<T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(
- c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i), " = ",
- indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ } else {
+ auto updates_flat =
+ updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+ functor::ScatterFunctorSYCL<T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params.dim_size(0), ")"));
+ }
}
}
};
@@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);
+
#define REGISTER_SCATTER_UPDATE(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
// Registers GPU kernels.
#if TENSORFLOW_USE_SYCL
-#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, SYCL);
+
+#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);
#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);
-#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_ARITHMETIC_SYCL
+#undef REGISTER_SCATTER_MINMAX_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif // TENSORFLOW_USE_SYCL
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
-#undef REGISTER_SCATTER_ARITHEMTIC_GPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC_GPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
+#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU