aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/resource_variable_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc81
1 files changed, 55 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index aecad0185f..e134e476f6 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel {
if (N > 0) {
auto indices_flat = indices.flat<Index>();
auto params_flat = params->flat_outer_dims<T>();
- int64 num_updates = updates.NumElements();
- OP_REQUIRES(c, num_updates % N == 0,
- errors::InvalidArgument(
- "shape of indices (", indices.shape().DebugString(),
- ") is not compatible with the shape of updates (",
- updates.shape().DebugString(), ")"));
- auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
-
- functor::ScatterFunctor<Device, T, Index, op> functor;
- const Index bad_i = functor(c, c->template eigen_device<Device>(),
- params_flat, updates_flat, indices_flat);
- OP_REQUIRES(c, bad_i < 0,
- errors::InvalidArgument(
- "indices", SliceDebugString(indices.shape(), bad_i),
- " = ", indices_flat(bad_i), " is not in [0, ",
- params->dim_size(0), ")"));
+ if (TensorShapeUtils::IsScalar(updates.shape())) {
+ const auto update = updates.scalar<T>();
+
+ functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, update, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ } else {
+ int64 num_updates = updates.NumElements();
+ OP_REQUIRES(c, num_updates % N == 0,
+ errors::InvalidArgument(
+ "shape of indices (", indices.shape().DebugString(),
+ ") is not compatible with the shape of updates (",
+ updates.shape().DebugString(), ")"));
+ auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
+
+ functor::ScatterFunctor<Device, T, Index, op> functor;
+ const Index bad_i = functor(c, c->template eigen_device<Device>(),
+ params_flat, updates_flat, indices_flat);
+ OP_REQUIRES(c, bad_i < 0,
+ errors::InvalidArgument(
+ "indices", SliceDebugString(indices.shape(), bad_i),
+ " = ", indices_flat(bad_i), " is not in [0, ",
+ params->dim_size(0), ")"));
+ }
}
}
};
@@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-// TODO(apassos) add the other types here.
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
+ scatter_op::UpdateOp::SUB); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
+ scatter_op::UpdateOp::MUL); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
+ scatter_op::UpdateOp::DIV); \
REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
scatter_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_MINMAX(type, dev) \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
+ scatter_op::UpdateOp::MIN); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
+ scatter_op::UpdateOp::MAX);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, CPU);
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
- REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+ REGISTER_SCATTER_ARITHMETIC(type, GPU);
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
#endif // GOOGLE_CUDA
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX