aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/resource_variable_ops.cc
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-05-07 11:49:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 16:48:12 -0700
commitaa57960b545ca25223568e366d99b0a4be7a03da (patch)
tree98981f45e9e7014f83aaea71214959c0b4feecc4 /tensorflow/core/kernels/resource_variable_ops.cc
parentf14123dc19be468b6776f057d45ddd4d40fef9b2 (diff)
Register bool scatter_update for resource variables
Fixes #17784 PiperOrigin-RevId: 195696915
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index a8bcc7f7dc..03cc414905 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -703,6 +703,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
@@ -728,6 +730,13 @@ REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
.Device(DEVICE_GPU)
.HostMemory("resource")
+ .TypeConstraint<bool>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, bool, int32,
+ scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
.HostMemory("indices")
.TypeConstraint<Variant>("dtype")
.TypeConstraint<int64>("Tindices"),