diff options
author | Brian Patton <bjp@google.com> | 2018-04-09 11:08:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-09 11:10:29 -0700 |
commit | 2138a691abfa726b0b6ef28d7f3482e94ada38aa (patch) | |
tree | 04eefcca6561f0a3d2f0b799f3636af2ca66084e | |
parent | 9d1bf2bd4723fd3d0a012891bc54cc9db54bd9cd (diff) |
Adds complex64/128 Fill kernel registrations for GPU.
PiperOrigin-RevId: 192153935
-rw-r--r-- | tensorflow/core/kernels/constant_op.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/fill_functor.cu.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/constant_op_test.py | 8 |
3 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 312c1a41d3..fe1a1ba5a3 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -258,13 +258,15 @@ REGISTER_KERNEL(GPU, Eigen::half); REGISTER_KERNEL(GPU, bfloat16); REGISTER_KERNEL(GPU, float); REGISTER_KERNEL(GPU, double); +REGISTER_KERNEL(GPU, complex64); +REGISTER_KERNEL(GPU, complex128); REGISTER_KERNEL(GPU, uint8); REGISTER_KERNEL(GPU, int8); REGISTER_KERNEL(GPU, uint16); REGISTER_KERNEL(GPU, int16); REGISTER_KERNEL(GPU, int64); REGISTER_KERNEL(GPU, bool); -// Currently we do not support filling strings and complex64 on GPU +// Currently we do not support filling strings on GPU // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel diff --git a/tensorflow/core/kernels/fill_functor.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc index 3487606778..050c95cf40 100644 --- a/tensorflow/core/kernels/fill_functor.cu.cc +++ b/tensorflow/core/kernels/fill_functor.cu.cc @@ -76,7 +76,7 @@ struct FillFunctor<GPUDevice, T> { }; #define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>; -TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU); +TF_CALL_NUMBER_TYPES(DEFINE_FILL_GPU); TF_CALL_bool(DEFINE_FILL_GPU); #undef DEFINE_FILL_GPU diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 18796f7095..749313b00d 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -653,12 +653,12 @@ class FillTest(test.TestCase): self._compareAll([2, 3], np_ans[0][0], np_ans) def testFillComplex64(self): - np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64) - self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex64) + self._compareAll([2, 3], np_ans[0][0], np_ans) def testFillComplex128(self): - np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128) - self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex128) + self._compareAll([2, 3], np_ans[0][0], np_ans) def testFillString(self): np_ans = np.array([[b"yolo"] * 3] * 2) |