aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 11:36:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 11:36:10 -0700
commit1f6f5e3b9fe092c5218b872648a5fb65334a2af8 (patch)
tree23885dbf24ace141cf75ebfa22255c95d81216ea
parent792a367350318a8d695dd626fe1e32a07b1023e1 (diff)
parent05a35adbe4456e6cf7854dcb4b2a113c6c05fd24 (diff)
Merge pull request #21784 from facaiy:ENH/add_gpu_kernel_for_div_no_nan
PiperOrigin-RevId: 209799240
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_div.cu.cc1
-rw-r--r--tensorflow/python/ops/math_ops_test.py2
3 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index 35662e278f..313d976e2c 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -33,6 +33,7 @@ REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
complex64, complex128);
+REGISTER2(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, float, double);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
index 0b05416274..25ccdcfb00 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -21,6 +21,7 @@ namespace tensorflow {
namespace functor {
DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32,
int64, complex64, complex128);
+DEFINE_BINARY2(div_no_nan, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 6bd41020c5..1b01d1d37f 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -483,7 +483,7 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
np_result = np.true_divide(nums, divs)
np_result[:, divs[0] == 0] = 0
- with self.cached_session():
+ with self.cached_session(use_gpu=True):
tf_result = math_ops.div_no_nan(nums, divs).eval()
self.assertAllEqual(tf_result, np_result)