diff options
author | 2018-05-07 11:49:08 -0700 | |
---|---|---|
committer | 2018-05-07 16:48:12 -0700 | |
commit | aa57960b545ca25223568e366d99b0a4be7a03da (patch) | |
tree | 98981f45e9e7014f83aaea71214959c0b4feecc4 /tensorflow/core/kernels/resource_variable_ops.cc | |
parent | f14123dc19be468b6776f057d45ddd4d40fef9b2 (diff) |
Register bool scatter_update for resource variables
Fixes #17784
PiperOrigin-RevId: 195696915
Diffstat (limited to 'tensorflow/core/kernels/resource_variable_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index a8bcc7f7dc..03cc414905 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -703,6 +703,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); +REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate", + scatter_op::UpdateOp::ASSIGN); REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); @@ -728,6 +730,13 @@ REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") .Device(DEVICE_GPU) .HostMemory("resource") + .TypeConstraint<bool>("dtype") + .TypeConstraint<int32>("Tindices"), + ResourceScatterUpdateOp<GPUDevice, bool, int32, + scatter_op::UpdateOp::ASSIGN>) +REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") + .Device(DEVICE_GPU) + .HostMemory("resource") .HostMemory("indices") .TypeConstraint<Variant>("dtype") .TypeConstraint<int64>("Tindices"), |