aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/python/xla_client_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/python/xla_client_test.py')
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 0564ddcb85..fd98e19457 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -171,6 +171,24 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
+ def testShiftLeft(self):
+ c = self._NewComputation()
+ c.ShiftLeft(c.Constant(NumpyArrayS32([3])),
+ c.Constant(NumpyArrayS32([2])))
+ self._ExecuteAndCompareClose(c, expected=[12])
+
+ def testShiftRightArithmetic(self):
+ c = self._NewComputation()
+ c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])),
+ c.Constant(NumpyArrayS32([1])))
+ self._ExecuteAndCompareClose(c, expected=[-1])
+
+ def testShiftRightLogical(self):
+ c = self._NewComputation()
+ c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])),
+ c.Constant(NumpyArrayS32([1])))
+ self._ExecuteAndCompareClose(c, expected=[2**31 - 1])
+
def testGetProto(self):
c = self._NewComputation()
c.Add(
@@ -471,6 +489,34 @@ class SingleOpTest(LocalComputationTest):
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
_ConvertAndTest(x, src_dtype, dst_dtype)
+ def testBitcastConvertType(self):
+ xla_x32_types = {
+ np.int32: xla_client.xla_data_pb2.S32,
+ np.float32: xla_client.xla_data_pb2.F32,
+ }
+
+ xla_x64_types = {
+ np.int64: xla_client.xla_data_pb2.S64,
+ np.float64: xla_client.xla_data_pb2.F64,
+ }
+
+ def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype):
+ c = self._NewComputation()
+ x = c.Constant(np.array(template, dtype=src_dtype))
+ c.BitcastConvertType(x, dst_etype)
+
+ result = c.Build().Compile().Execute()
+ expected = np.array(template, src_dtype).view(dst_dtype)
+
+ self.assertEqual(result.shape, expected.shape)
+ self.assertEqual(result.dtype, expected.dtype)
+ np.testing.assert_equal(result, expected)
+
+ x = [0, 1, 0, 0, 1]
+ for xla_types in [xla_x32_types, xla_x64_types]:
+ for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
+ _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype])
+
def testCrossReplicaSumOneReplica(self):
samples = [
NumpyArrayF32(42.0),