aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/binary_ops_test.py
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2017-07-26 18:39:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 18:43:32 -0700
commit32e198f2d5787ca81aba89bf073e4eb380769253 (patch)
tree72051598edeaaedde73ae4658cabaa4cadc1b3a8 /tensorflow/compiler/tests/binary_ops_test.py
parent9b30dc3a824fd277fcd622a458b25f26c0db7b72 (diff)
[TF:XLA] Add tf.cross support.
See #11788 PiperOrigin-RevId: 163287731
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9eaede7f40..83cfd2ea75 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -765,6 +765,24 @@ class BinaryOpsTest(XLATestCase):
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
+ def testCross(self):
+ for dtype in self.float_types:
+ self._testBinary(
+ gen_math_ops.cross,
+ np.zeros((4, 3), dtype=dtype),
+ np.zeros((4, 3), dtype=dtype),
+ expected=np.zeros((4, 3), dtype=dtype))
+ self._testBinary(
+ gen_math_ops.cross,
+ np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype),
+ expected=np.array([-3, 6, -3], dtype=dtype))
+ self._testBinary(
+ gen_math_ops.cross,
+ np.array([[1, 2, 3], [10, 11, 12]], dtype=dtype),
+ np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
+ expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
+
if __name__ == "__main__":
googletest.main()