diff options
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 81 |
1 file changed, 55 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index aecad0185f..e134e476f6 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat<Index>(); auto params_flat = params->flat_outer_dims<T>(); - int64 num_updates = updates.NumElements(); - OP_REQUIRES(c, num_updates % N == 0, - errors::InvalidArgument( - "shape of indices (", indices.shape().DebugString(), - ") is not compatible with the shape of updates (", - updates.shape().DebugString(), ")")); - auto updates_flat = updates.shaped<T, 2>({N, num_updates / N}); - - functor::ScatterFunctor<Device, T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<Device>(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES(c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), - " = ", indices_flat(bad_i), " is not in [0, ", - params->dim_size(0), ")")); + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + + functor::ScatterScalarFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } else { + int64 num_updates = updates.NumElements(); + OP_REQUIRES(c, num_updates % N == 0, + errors::InvalidArgument( + "shape of indices (", indices.shape().DebugString(), + ") is not compatible with the shape of updates (", + updates.shape().DebugString(), ")")); + auto updates_flat = updates.shaped<T, 2>({N, num_updates / N}); + + functor::ScatterFunctor<Device, T, Index, op> functor; + const Index bad_i = 
functor(c, c->template eigen_device<Device>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } } } }; @@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -// TODO(apassos) add the other types here. -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ scatter_op::UpdateOp::ADD); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \ + scatter_op::UpdateOp::SUB); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \ + scatter_op::UpdateOp::MUL); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \ + scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \ + scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \ + scatter_op::UpdateOp::MAX); // Registers CPU kernels. -#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); // Registers GPU kernels. 
#if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); #endif // GOOGLE_CUDA -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU #undef REGISTER_SCATTER_KERNEL #undef REGISTER_SCATTER_KERNEL_INDEX |