diff options
author | 2017-09-28 20:26:42 -0700 | |
---|---|---|
committer | 2017-09-28 20:29:59 -0700 | |
commit | 3ab081d65caa3801db82f417ea52345b87b07844 (patch) | |
tree | 3322d63b05a8816211742a76e9f2eff055d4d9cc /tensorflow/core/kernels/scatter_nd_op.cc | |
parent | 872917e78f7628c00f93162c70d74e8b659e0123 (diff) |
Add complex kernel registrations for GatherNd and ScatterNd.
PiperOrigin-RevId: 170436916
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 2d8db7298d..484932ab01 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -205,9 +205,17 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU); #define REGISTER_SCATTER_ND_UPDATE_GPU(type) \ REGISTER_SCATTER_ND_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_GPU); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_GPU); -TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_GPU); +#define REGISTER_SCATTER_ND_ALL_GPU(type) \ + REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \ + REGISTER_SCATTER_ND_UPDATE_GPU(type); \ + REGISTER_SCATTER_ND_GPU(type); + +// TODO(b/66916790): Support half types in ScatterNd. +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ALL_GPU); +TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU); +TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU); + +#undef REGISTER_SCATTER_ND_ALL_GPU #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \ |