aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/relu_op_test.py
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-11 19:59:11 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-11 19:59:11 +0800
commit9b3a93edf5a1f259bfe5230cc3b6c076573d4ec9 (patch)
treecbb0548282ba1584ed91a1be8f89b03ec882f287 /tensorflow/python/kernel_tests/relu_op_test.py
parent90cf7fb7786c8a9c135ef73482856b082e80f61a (diff)
parente18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 3e24b8a2c4..86d9c90e83 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -24,6 +24,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@@ -72,6 +73,35 @@ class ReluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
+ def _testReluInt8x4(self, np_inputs):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+ np_relu = self._npRelu(np_inputs)
+ with self.test_session(use_gpu=True):
+ relu = nn_ops.relu(constant_op.constant(np_inputs, dtypes.qint8))
+ if np_inputs.size % 4 == 0:
+ tf_relu = relu.eval()
+ self.assertAllClose(np_relu, tf_relu)
+ self.assertShapeEqual(np_relu, relu)
+ else:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Tensor size must be a multiple of 4 for Relu<qint8>. Got %d" %
+ np_inputs.size):
+ tf_relu = relu.eval()
+
+ def testReluInt8x4GoodShape(self):
+ self._testReluInt8x4(np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]]))
+
+ def testReluInt8x4BadShape(self):
+ np_inputs = np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]])
+ self.assertEqual(np_inputs.size, 9)
+ self._testReluInt8x4(np_inputs)
+ np_inputs = np.array(
+ [1, -2, 3, -4, 5, -6, 7, -8, 9, -8, 7, -6, 5, -4, 3, -2, 1])
+ self.assertEqual(np_inputs.size, 17)
+ self._testReluInt8x4(np_inputs)
+
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat32(self):