aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-20 17:44:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 17:47:49 -0700
commitb840e5ac84319e6e091a0f9351b7691390275f2f (patch)
tree539954adfd1203f35928452cc43bafc147922d6d /tensorflow/compiler/xla/python
parent57e5dfa76a32ff0ee6ec4b72a2461487b7969a3e (diff)
[XLA] add BitcastConvertType to local Python client
PiperOrigin-RevId: 205479860
Diffstat (limited to 'tensorflow/compiler/xla/python')
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc5
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h3
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i1
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py12
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py28
5 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index f25348e735..8aefc4cd5e 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -486,6 +486,11 @@ LocalOp LocalComputationBuilder::ConvertElementType(
return xla::ConvertElementType(operand.op(), new_element_type);
}
+LocalOp LocalComputationBuilder::BitcastConvertType(
+ const LocalOp& operand, PrimitiveType new_element_type) {
+ return xla::BitcastConvertType(operand.op(), new_element_type);
+}
+
LocalOp LocalComputationBuilder::Call(
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<LocalOp> operands) {
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 0e0d8ac29a..dd9e2fbe72 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -259,6 +259,9 @@ class LocalComputationBuilder {
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
+ LocalOp BitcastConvertType(const LocalOp& operand,
+ PrimitiveType new_element_type);
+
LocalOp Call(const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<LocalOp> operands);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index eeccbd7cfa..9b8b0aa7f2 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -957,6 +957,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Tuple;
%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
+%unignore xla::swig::LocalComputationBuilder::BitcastConvertType;
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index ef043e4ca0..c0105b385b 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -721,6 +721,18 @@ class ComputationBuilder(object):
"""
return self._client.ConvertElementType(operand, new_element_type)
+ def BitcastConvertType(self, operand, new_element_type):
+ """Enqueues a bitcast type conversion operation onto the computation.
+
+ Args:
+ operand: the operand to convert.
+ new_element_type: the target primitive type.
+
+ Returns:
+ A LocalOp representing the added conversion op.
+ """
+ return self._client.BitcastConvertType(operand, new_element_type)
+
def GetShape(self, operand):
return _wrap_shape(self._client.GetShape(operand))
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index 93177aa647..fd98e19457 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -489,6 +489,34 @@ class SingleOpTest(LocalComputationTest):
for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
_ConvertAndTest(x, src_dtype, dst_dtype)
+ def testBitcastConvertType(self):
+ xla_x32_types = {
+ np.int32: xla_client.xla_data_pb2.S32,
+ np.float32: xla_client.xla_data_pb2.F32,
+ }
+
+ xla_x64_types = {
+ np.int64: xla_client.xla_data_pb2.S64,
+ np.float64: xla_client.xla_data_pb2.F64,
+ }
+
+ def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype):
+ c = self._NewComputation()
+ x = c.Constant(np.array(template, dtype=src_dtype))
+ c.BitcastConvertType(x, dst_etype)
+
+ result = c.Build().Compile().Execute()
+ expected = np.array(template, src_dtype).view(dst_dtype)
+
+ self.assertEqual(result.shape, expected.shape)
+ self.assertEqual(result.dtype, expected.dtype)
+ np.testing.assert_equal(result, expected)
+
+ x = [0, 1, 0, 0, 1]
+ for xla_types in [xla_x32_types, xla_x64_types]:
+ for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
+ _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype])
+
def testCrossReplicaSumOneReplica(self):
samples = [
NumpyArrayF32(42.0),