aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_nd_op.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-09-28 20:26:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 20:29:59 -0700
commit3ab081d65caa3801db82f417ea52345b87b07844 (patch)
tree3322d63b05a8816211742a76e9f2eff055d4d9cc /tensorflow/core/kernels/scatter_nd_op.cc
parent872917e78f7628c00f93162c70d74e8b659e0123 (diff)
Add complex kernel registrations for GatherNd and ScatterNd.
PiperOrigin-RevId: 170436916
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 2d8db7298d..484932ab01 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -205,9 +205,17 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
REGISTER_SCATTER_ND_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_GPU);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_GPU);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ND_GPU);
+#define REGISTER_SCATTER_ND_ALL_GPU(type) \
+ REGISTER_SCATTER_ND_ADD_SUB_GPU(type); \
+ REGISTER_SCATTER_ND_UPDATE_GPU(type); \
+ REGISTER_SCATTER_ND_GPU(type);
+
+// TODO(b/66916790): Support half types in ScatterNd.
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ALL_GPU);
+TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
+TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
+
+#undef REGISTER_SCATTER_ND_ALL_GPU
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \