aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-04-09 11:08:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 11:10:29 -0700
commit2138a691abfa726b0b6ef28d7f3482e94ada38aa (patch)
tree04eefcca6561f0a3d2f0b799f3636af2ca66084e
parent9d1bf2bd4723fd3d0a012891bc54cc9db54bd9cd (diff)
Adds complex64/128 Fill kernel registrations for GPU.
PiperOrigin-RevId: 192153935
-rw-r--r--tensorflow/core/kernels/constant_op.cc4
-rw-r--r--tensorflow/core/kernels/fill_functor.cu.cc2
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py8
3 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 312c1a41d3..fe1a1ba5a3 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -258,13 +258,15 @@ REGISTER_KERNEL(GPU, Eigen::half);
REGISTER_KERNEL(GPU, bfloat16);
REGISTER_KERNEL(GPU, float);
REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, complex64);
+REGISTER_KERNEL(GPU, complex128);
REGISTER_KERNEL(GPU, uint8);
REGISTER_KERNEL(GPU, int8);
REGISTER_KERNEL(GPU, uint16);
REGISTER_KERNEL(GPU, int16);
REGISTER_KERNEL(GPU, int64);
REGISTER_KERNEL(GPU, bool);
-// Currently we do not support filling strings and complex64 on GPU
+// Currently we do not support filling strings on GPU
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/fill_functor.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc
index 3487606778..050c95cf40 100644
--- a/tensorflow/core/kernels/fill_functor.cu.cc
+++ b/tensorflow/core/kernels/fill_functor.cu.cc
@@ -76,7 +76,7 @@ struct FillFunctor<GPUDevice, T> {
};
#define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>;
-TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU);
+TF_CALL_NUMBER_TYPES(DEFINE_FILL_GPU);
TF_CALL_bool(DEFINE_FILL_GPU);
#undef DEFINE_FILL_GPU
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 18796f7095..749313b00d 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -653,12 +653,12 @@ class FillTest(test.TestCase):
self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillComplex64(self):
- np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64)
- self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+ np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillComplex128(self):
- np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128)
- self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+ np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex128)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillString(self):
np_ans = np.array([[b"yolo"] * 3] * 2)