diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-27 09:00:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-27 09:04:01 -0700 |
commit | 4198e27be8115585ad6b5b141383fb7dc7856c24 (patch) | |
tree | 244405e6ef96cb098d8abbf2547a8f22dfb4c72d /tensorflow/compiler/tests/binary_ops_test.py | |
parent | 4ae245a7db3d0457c4324ee7df8d020ba83b3c60 (diff) |
[XLA:CPU] [XLA:GPU] Adds compiler support for C64 primitive type, including relevant elementwise unary and binary op lowering for CPU and GPU.
We use a named LLVM struct "complex64", laid out the same as std::complex<float>. This named struct is accessed via the llvm::Module, which required changes to accessors of PrimitiveTypeToIrType & friends.
Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides this for GPU.
PiperOrigin-RevId: 173676849
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 248 |
1 files changed, 197 insertions, 51 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9a225b32f8..d412c572ae 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -94,6 +94,15 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( gen_math_ops._reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), @@ -259,37 +268,38 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([10, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([5, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[10], [7]], dtype=dtype)) + if dtype not in self.complex_types: # min/max not supported for complex + self._testBinary( + math_ops.maximum, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([10, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([5, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[10], [7]], dtype=dtype)) - self._testBinary( - math_ops.minimum, - np.array([1, 20], dtype=dtype), - np.array([10, 2], dtype=dtype), - expected=np.array([1, 2], dtype=dtype)) - self._testBinary( - math_ops.minimum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([1, 5], dtype=dtype)) - self._testBinary( - math_ops.minimum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[7], [2]], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([1, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testBinary( + math_ops.minimum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([1, 5], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[7], [2]], dtype=dtype)) self._testBinary( math_ops.multiply, @@ -307,21 +317,23 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[70], [14]], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([81, 324], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - dtype(5), - np.array([1, 2], dtype=dtype), - expected=np.array([16, 9], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([[1], [2]], dtype=dtype), - dtype(7), - expected=np.array([[36], [25]], dtype=dtype)) + # Complex support for squared_difference is incidental, see b/68205550 + if dtype not in self.complex_types: + self._testBinary( + math_ops.squared_difference, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([81, 324], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([16, 9], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[36], [25]], dtype=dtype)) self._testBinary( nn_ops.bias_add, @@ -334,6 +346,139 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + def testComplexOps(self): + for dtype in self.complex_types: + ctypes = {np.complex64: np.float32} + self._testBinary( + math_ops.complex, + np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), + np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]), + expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype)) + + self._testBinary( + lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), + np.array( + [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]], + dtype=dtype), + np.array( + [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype), + expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) + + self._testBinary( + gen_math_ops._real_div, + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + expected=np.array( + [ + 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, + float("inf") + ], + dtype=dtype)) + + # TODO(b/65408531): support+test pow for cplx + + lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) + rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) + self._testBinary( + gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + + self._testBinary( + gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + + # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + + self._testBinary( + gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + + self._testBinary( + gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + + def testComplexMath(self): + for dtype in self.complex_types: + self._testBinary( + math_ops.add, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([11 - 1j, 22 + 24j], dtype=dtype)) + self._testBinary( + math_ops.add, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([6 - 5j, 7 - 3j], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype)) + + self._testBinary( + math_ops.subtract, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([4 - 9j, 3 - 11j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype)) + + self._testBinary( + math_ops.multiply, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array( + [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array( + [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array( + [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype)) + + self._testBinary( + math_ops.div, + np.array([8 - 1j, 2 + 16j], dtype=dtype), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + dtype(1 + 2j), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + np.array([2 + 4j, 4 - 8j], dtype=dtype), + dtype(1 + 2j), + expected=np.array( + [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype)) + + # TODO(b/68205550): math_ops.squared_difference shouldn't be supported. + + self._testBinary( + nn_ops.bias_add, + np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype), + np.array([2 + 6j, -1 - 3j], dtype=dtype), + expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype)) + self._testBinary( + nn_ops.bias_add, + np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype), + np.array([2 + 1j, -1 + 2j], dtype=dtype), + expected=np.array( + [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype)) + def _testDivision(self, dtype): """Test cases for division operators.""" self._testBinary( @@ -352,18 +497,19 @@ class BinaryOpsTest(XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) - self._testBinary( - gen_math_ops._floor_div, - np.array([3, 3, -1, -9, -8], dtype=dtype), - np.array([2, -2, 7, 2, -4], dtype=dtype), - expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) + if dtype not in self.complex_types: # floordiv unsupported for complex. + self._testBinary( + gen_math_ops._floor_div, + np.array([3, 3, -1, -9, -8], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): for dtype in self.int_types: self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types: + for dtype in self.float_types + self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): |