path: root/tensorflow/python/kernel_tests/relu_op_test.py
author	Xiaoqiang Zheng <zhengxq@google.com>	2018-03-01 15:41:11 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-03-01 15:45:09 -0800
commit	45daab910a3c730380594317749d911db5e933e6 (patch)
tree	27716a5ae0de0a15752cd1255ecd685c05a85a03 /tensorflow/python/kernel_tests/relu_op_test.py
parent	ac79486324bda04cc2f3b75e9590935dfe1ef826 (diff)
An fp16 implementation for ReluGrad.
On V100 with CUDA 9, it reduces the average ReluGrad kernel time in ResNet-50 from 249.44 us to 175.60 us, a 42% speedup. On Titan X Pascal with CUDA 9, it reduces the average ReluGrad kernel time in ResNet-50 from 747.98 us to 509.37 us, a 46.8% speedup.

PiperOrigin-RevId: 187545504
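For reference, both percentages are computed relative to the new (fp16) kernel time rather than the old one. A minimal sketch of that arithmetic, reusing the timings quoted above (the speedup helper is illustrative, not part of the change):

# Sanity-check the speedup figures quoted in the commit message.
# Percentages are relative to the new (fp16) kernel time.
def speedup(old_us, new_us):
  return (old_us - new_us) / new_us

print("V100:           %.1f%%" % (100 * speedup(249.44, 175.60)))  # -> 42.1%, quoted as 42%
print("Titan X Pascal: %.1f%%" % (100 * speedup(747.98, 509.37)))  # -> 46.8%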
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r--	tensorflow/python/kernel_tests/relu_op_test.py	31
1 file changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 6b4091ae5d..25e947f09e 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -19,12 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
@@ -87,6 +89,35 @@ class ReluTest(test.TestCase):
print("relu (float32) gradient err = ", err)
self.assertLess(err, 1e-4)
+ # The gradient for fp16 is inaccurate due to the low-precision.
+ # Instead of relying on compute_gradient_error, we compare the fp16 analytical
+ # gradient against their fp32 counterpart.
+ def testGradientFloat16(self):
+ with self.test_session(use_gpu=True) as sess:
+ # Randomly construct a 1D shape from [1, 40)
+ shape = random_ops.random_uniform(
+ [1], minval=1, maxval=40, dtype=dtypes.int32)
+
+ # Construct the fp32 graph and its gradient.
+ x = random_ops.random_uniform(shape, minval=-1, maxval=1, name="x")
+ y1 = nn_ops.relu(x, name="relu_fp32")
+ l1 = nn_ops.l2_loss(y1)
+ dx_f32 = gradients_impl.gradients(l1, x)
+
+ # Construct the fp16 graph and its gradient.
+ # It starts with the same x, in fp32. But before it reaches Relu, it is
+ # cast into fp16. So during backprop, the gradient computation is in fp16.
+ x2 = math_ops.cast(x, dtype=dtypes.float16, name="cast")
+ y2 = nn_ops.relu(x2, name="relu_fp16")
+ l2 = nn_ops.l2_loss(y2)
+ dx_f16 = gradients_impl.gradients(l2, x)
+
+ # Repeat the experiment for 100 times. All tensor shapes and its tensor
+ # values are randomly generated for each run.
+ for _ in xrange(100):
+ dx_f32_v, dx_f16_v = sess.run([dx_f32, dx_f16])
+ self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4)
+
def testGradientFloat64(self):
with self.test_session():
x = constant_op.constant(