aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/binary_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-27 09:00:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-27 09:04:01 -0700
commit4198e27be8115585ad6b5b141383fb7dc7856c24 (patch)
tree244405e6ef96cb098d8abbf2547a8f22dfb4c72d /tensorflow/compiler/tests/binary_ops_test.py
parent4ae245a7db3d0457c4324ee7df8d020ba83b3c60 (diff)
[XLA:CPU] [XLA:GPU] Adds compiler support for C64 primitive type, including relevant elementwise unary and binary op lowering for CPU and GPU.
We use a named LLVM struct "complex64", laid out the same as std::complex<float>. This named struct is accessed via the llvm::Module, which required changes to accessors of PrimitiveTypeToIrType & friends. Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides this for GPU. PiperOrigin-RevId: 173676849
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py248
1 files changed, 197 insertions, 51 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9a225b32f8..d412c572ae 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -94,6 +94,15 @@ class BinaryOpsTest(XLATestCase):
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
+ atan2_supported = self.device == "XLA_GPU"
+ if atan2_supported:
+ self._testBinary(
+ math_ops.atan2,
+ np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
+ np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
+ expected=np.array(
+ [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
+
self._testBinary(
gen_math_ops._reciprocal_grad,
np.array([4, -3, -2, 1], dtype=dtype),
@@ -259,37 +268,38 @@ class BinaryOpsTest(XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
- self._testBinary(
- math_ops.maximum,
- np.array([1, 2], dtype=dtype),
- np.array([10, 20], dtype=dtype),
- expected=np.array([10, 20], dtype=dtype))
- self._testBinary(
- math_ops.maximum,
- dtype(5),
- np.array([1, 20], dtype=dtype),
- expected=np.array([5, 20], dtype=dtype))
- self._testBinary(
- math_ops.maximum,
- np.array([[10], [2]], dtype=dtype),
- dtype(7),
- expected=np.array([[10], [7]], dtype=dtype))
+ if dtype not in self.complex_types: # min/max not supported for complex
+ self._testBinary(
+ math_ops.maximum,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([10, 20], dtype=dtype))
+ self._testBinary(
+ math_ops.maximum,
+ dtype(5),
+ np.array([1, 20], dtype=dtype),
+ expected=np.array([5, 20], dtype=dtype))
+ self._testBinary(
+ math_ops.maximum,
+ np.array([[10], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[10], [7]], dtype=dtype))
- self._testBinary(
- math_ops.minimum,
- np.array([1, 20], dtype=dtype),
- np.array([10, 2], dtype=dtype),
- expected=np.array([1, 2], dtype=dtype))
- self._testBinary(
- math_ops.minimum,
- dtype(5),
- np.array([1, 20], dtype=dtype),
- expected=np.array([1, 5], dtype=dtype))
- self._testBinary(
- math_ops.minimum,
- np.array([[10], [2]], dtype=dtype),
- dtype(7),
- expected=np.array([[7], [2]], dtype=dtype))
+ self._testBinary(
+ math_ops.minimum,
+ np.array([1, 20], dtype=dtype),
+ np.array([10, 2], dtype=dtype),
+ expected=np.array([1, 2], dtype=dtype))
+ self._testBinary(
+ math_ops.minimum,
+ dtype(5),
+ np.array([1, 20], dtype=dtype),
+ expected=np.array([1, 5], dtype=dtype))
+ self._testBinary(
+ math_ops.minimum,
+ np.array([[10], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[7], [2]], dtype=dtype))
self._testBinary(
math_ops.multiply,
@@ -307,21 +317,23 @@ class BinaryOpsTest(XLATestCase):
dtype(7),
expected=np.array([[70], [14]], dtype=dtype))
- self._testBinary(
- math_ops.squared_difference,
- np.array([1, 2], dtype=dtype),
- np.array([10, 20], dtype=dtype),
- expected=np.array([81, 324], dtype=dtype))
- self._testBinary(
- math_ops.squared_difference,
- dtype(5),
- np.array([1, 2], dtype=dtype),
- expected=np.array([16, 9], dtype=dtype))
- self._testBinary(
- math_ops.squared_difference,
- np.array([[1], [2]], dtype=dtype),
- dtype(7),
- expected=np.array([[36], [25]], dtype=dtype))
+ # Complex support for squared_difference is incidental, see b/68205550
+ if dtype not in self.complex_types:
+ self._testBinary(
+ math_ops.squared_difference,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([81, 324], dtype=dtype))
+ self._testBinary(
+ math_ops.squared_difference,
+ dtype(5),
+ np.array([1, 2], dtype=dtype),
+ expected=np.array([16, 9], dtype=dtype))
+ self._testBinary(
+ math_ops.squared_difference,
+ np.array([[1], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[36], [25]], dtype=dtype))
self._testBinary(
nn_ops.bias_add,
@@ -334,6 +346,139 @@ class BinaryOpsTest(XLATestCase):
np.array([2, -1], dtype=dtype),
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
+ def testComplexOps(self):
+ for dtype in self.complex_types:
+ ctypes = {np.complex64: np.float32}
+ self._testBinary(
+ math_ops.complex,
+ np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
+ np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]),
+ expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype))
+
+ self._testBinary(
+ lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
+ np.array(
+ [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]],
+ dtype=dtype),
+ np.array(
+ [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype),
+ expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops._real_div,
+ np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
+ np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
+ expected=np.array(
+ [
+ 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
+ float("inf")
+ ],
+ dtype=dtype))
+
+ # TODO(b/65408531): support+test pow for cplx
+
+ lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
+ rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
+ self._testBinary(
+ gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs)
+
+ self._testBinary(
+ gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
+
+ # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
+
+ self._testBinary(
+ gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
+
+ self._testBinary(
+ gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs))
+
+ def testComplexMath(self):
+ for dtype in self.complex_types:
+ self._testBinary(
+ math_ops.add,
+ np.array([1 + 3j, 2 + 7j], dtype=dtype),
+ np.array([10 - 4j, 20 + 17j], dtype=dtype),
+ expected=np.array([11 - 1j, 22 + 24j], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ dtype(5 - 7j),
+ np.array([1 + 2j, 2 + 4j], dtype=dtype),
+ expected=np.array([6 - 5j, 7 - 3j], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
+ dtype(7 + 5j),
+ expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.subtract,
+ np.array([1 + 3j, 2 + 7j], dtype=dtype),
+ np.array([10 - 4j, 20 + 17j], dtype=dtype),
+ expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype))
+ self._testBinary(
+ math_ops.subtract,
+ dtype(5 - 7j),
+ np.array([1 + 2j, 2 + 4j], dtype=dtype),
+ expected=np.array([4 - 9j, 3 - 11j], dtype=dtype))
+ self._testBinary(
+ math_ops.subtract,
+ np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
+ dtype(7 + 5j),
+ expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.multiply,
+ np.array([1 + 3j, 2 + 7j], dtype=dtype),
+ np.array([10 - 4j, 20 + 17j], dtype=dtype),
+ expected=np.array(
+ [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype))
+ self._testBinary(
+ math_ops.multiply,
+ dtype(5 - 7j),
+ np.array([1 + 2j, 2 + 4j], dtype=dtype),
+ expected=np.array(
+ [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype))
+ self._testBinary(
+ math_ops.multiply,
+ np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
+ dtype(7 + 5j),
+ expected=np.array(
+ [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.div,
+ np.array([8 - 1j, 2 + 16j], dtype=dtype),
+ np.array([2 + 4j, 4 - 8j], dtype=dtype),
+ expected=np.array(
+ [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype))
+ self._testBinary(
+ math_ops.div,
+ dtype(1 + 2j),
+ np.array([2 + 4j, 4 - 8j], dtype=dtype),
+ expected=np.array(
+ [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype))
+ self._testBinary(
+ math_ops.div,
+ np.array([2 + 4j, 4 - 8j], dtype=dtype),
+ dtype(1 + 2j),
+ expected=np.array(
+ [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype))
+
+ # TODO(b/68205550): math_ops.squared_difference shouldn't be supported.
+
+ self._testBinary(
+ nn_ops.bias_add,
+ np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype),
+ np.array([2 + 6j, -1 - 3j], dtype=dtype),
+ expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype))
+ self._testBinary(
+ nn_ops.bias_add,
+ np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype),
+ np.array([2 + 1j, -1 + 2j], dtype=dtype),
+ expected=np.array(
+ [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype))
+
def _testDivision(self, dtype):
"""Test cases for division operators."""
self._testBinary(
@@ -352,18 +497,19 @@ class BinaryOpsTest(XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
- self._testBinary(
- gen_math_ops._floor_div,
- np.array([3, 3, -1, -9, -8], dtype=dtype),
- np.array([2, -2, 7, 2, -4], dtype=dtype),
- expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
+ if dtype not in self.complex_types: # floordiv unsupported for complex.
+ self._testBinary(
+ gen_math_ops._floor_div,
+ np.array([3, 3, -1, -9, -8], dtype=dtype),
+ np.array([2, -2, 7, 2, -4], dtype=dtype),
+ expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
for dtype in self.int_types:
self._testDivision(dtype)
def testFloatDivision(self):
- for dtype in self.float_types:
+ for dtype in self.float_types + self.complex_types:
self._testDivision(dtype)
def _testRemainder(self, dtype):