aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/unary_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/unary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py60
1 files changed, 58 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index a24abd7547..5f25ff9002 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -23,7 +23,7 @@ import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
@@ -44,7 +44,7 @@ def nhwc_to_format(x, data_format):
raise ValueError("Unknown format {}".format(data_format))
-class UnaryOpsTest(XLATestCase):
+class UnaryOpsTest(xla_test.XLATestCase):
"""Test cases for unary operators."""
def _assertOpOutputMatchesExpected(self,
@@ -382,6 +382,62 @@ class UnaryOpsTest(XLATestCase):
expected=np.array(
[[True, False, True], [False, True, True]], dtype=np.bool))
+ self._assertOpOutputMatchesExpected(
+ math_ops.lgamma,
+ np.array(
+ [[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2],
+ [-3 / 2, -7 / 2, -11 / 2]],
+ dtype=dtype),
+ expected=np.array(
+ [
+ [0, 0, np.log(2.0)],
+ [np.log(6.0), np.log(24.0),
+ np.log(120)],
+ [
+ np.log(np.pi) / 2,
+ np.log(np.pi) / 2 - np.log(2),
+ np.log(np.pi) / 2 - np.log(4) + np.log(3)
+ ],
+ [
+ np.log(np.pi) / 2 - np.log(3) + np.log(4),
+ np.log(np.pi) / 2 - np.log(105) + np.log(16),
+ np.log(np.pi) / 2 - np.log(10395) + np.log(64),
+ ],
+ ],
+ dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ math_ops.digamma,
+ np.array(
+ [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0],
+ [6.0, 8.0, 9.0]],
+ dtype=dtype),
+ expected=np.array(
+ [
+ [
+ -np.euler_gamma, -2 * np.log(2) - np.euler_gamma,
+ -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 -
+ np.euler_gamma
+ ],
+ [
+ -np.pi / 2 - 3 * np.log(2) - np.euler_gamma,
+ -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) -
+ 3 * np.log(3) / 2 - np.euler_gamma,
+ -np.pi / 2 - 4 * np.log(2) -
+ (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2)))
+ / np.sqrt(2) - np.euler_gamma
+ ],
+ [
+ 1 - np.euler_gamma, 1.5 - np.euler_gamma,
+ 11 / 6.0 - np.euler_gamma
+ ],
+ [
+ 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma,
+ 761 / 280.0 - np.euler_gamma
+ ],
+ ],
+ dtype=dtype))
+
def quantize_and_dequantize_v2(x):
return array_ops.quantize_and_dequantize_v2(
x, -127, 127, signed_input=True, num_bits=8)