diff options
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tests/randomized_tests.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/tests/unary_ops_test.py | 5 |
3 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 83cfd2ea75..c23ee5f037 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -114,6 +114,14 @@ class BinaryOpsTest(XLATestCase): expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) self._testBinary( + gen_nn_ops._selu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), + expected=np.array( + [1.158099340847, 2.7161986816948, 4.67429802254, + 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) + + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index d3821ad02e..825fd9de2e 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) { }); } +TEST_F(OpTest, Selu) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SeluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose<DataType>({DT_INT32, DT_FLOAT}); diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ce35eb9197..81ff18f302 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -230,6 +230,11 @@ class UnaryOpsTest(XLATestCase): expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) self._assertOpOutputMatchesExpected( + nn_ops.selu, + np.array([[-1, 0, 1]], dtype=dtype), + expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( nn_ops.relu, np.array([[-1, 1]], dtype=dtype), expected=np.array([[0, 1]], dtype=dtype)) |