diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/relu_op_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 8cd1f52d80..dd11ba700d 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -441,6 +441,24 @@ class CreluTest(test.TestCase): np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), use_gpu=True) + def testNumbersWithAxis0(self): + with self.test_session(): + crelu = nn_ops.crelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0) + tf_relu = crelu.eval() + np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1], + [0, 3, 0, 7, 0]]) + self.assertAllEqual(np_crelu, tf_relu) + + def testNumbersWithAxis1(self): + with self.test_session(): + crelu = nn_ops.crelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1) + tf_relu = crelu.eval() + np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1], + [1, 0, 5, 0, 9, 0, 3, 0, 7, 0]]) + self.assertAllEqual(np_crelu, tf_relu) + if __name__ == "__main__": test.main() |