aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py8
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc17
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py5
3 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 83cfd2ea75..c23ee5f037 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -114,6 +114,14 @@ class BinaryOpsTest(XLATestCase):
expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
self._testBinary(
+ gen_nn_ops._selu_grad,
+ np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype),
+ expected=np.array(
+ [1.158099340847, 2.7161986816948, 4.67429802254,
+ 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype))
+
+ self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index d3821ad02e..825fd9de2e 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1434,6 +1434,23 @@ TEST_F(OpTest, EluGrad) {
});
}
+TEST_F(OpTest, Selu) {
+ Repeatedly([this]() {
+ return ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, SeluGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims();
+ return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad")
+ .RandomInput(DT_FLOAT, dims)
+ .RandomInput(DT_FLOAT, dims)
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ce35eb9197..81ff18f302 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -230,6 +230,11 @@ class UnaryOpsTest(XLATestCase):
expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
+ nn_ops.selu,
+ np.array([[-1, 0, 1]], dtype=dtype),
+ expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0, 1]], dtype=dtype))