aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/relu_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 8cd1f52d80..dd11ba700d 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -441,6 +441,24 @@ class CreluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
+ def testNumbersWithAxis0(self):
+ with self.test_session():
+ crelu = nn_ops.crelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
+ tf_relu = crelu.eval()
+ np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1],
+ [0, 3, 0, 7, 0]])
+ self.assertAllEqual(np_crelu, tf_relu)
+
+ def testNumbersWithAxis1(self):
+ with self.test_session():
+ crelu = nn_ops.crelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
+ tf_relu = crelu.eval()
+ np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1],
+ [1, 0, 5, 0, 9, 0, 3, 0, 7, 0]])
+ self.assertAllEqual(np_crelu, tf_relu)
+
if __name__ == "__main__":
test.main()