diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-20 17:44:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 17:47:49 -0700 |
commit | b840e5ac84319e6e091a0f9351b7691390275f2f (patch) | |
tree | 539954adfd1203f35928452cc43bafc147922d6d /tensorflow/compiler/xla/python | |
parent | 57e5dfa76a32ff0ee6ec4b72a2461487b7969a3e (diff) |
[XLA] add BitcastConvertType to local Python client
PiperOrigin-RevId: 205479860
Diffstat (limited to 'tensorflow/compiler/xla/python')
5 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index f25348e735..8aefc4cd5e 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -486,6 +486,11 @@ LocalOp LocalComputationBuilder::ConvertElementType( return xla::ConvertElementType(operand.op(), new_element_type); } +LocalOp LocalComputationBuilder::BitcastConvertType( + const LocalOp& operand, PrimitiveType new_element_type) { + return xla::BitcastConvertType(operand.op(), new_element_type); +} + LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<LocalOp> operands) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 0e0d8ac29a..dd9e2fbe72 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -259,6 +259,9 @@ class LocalComputationBuilder { LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); + LocalOp BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type); + LocalOp Call(const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<LocalOp> operands); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index eeccbd7cfa..9b8b0aa7f2 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -957,6 +957,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Tuple; %unignore xla::swig::LocalComputationBuilder::GetTupleElement; %unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; %unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::LocalComputationBuilder::Rev; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index ef043e4ca0..c0105b385b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -721,6 +721,18 @@ class ComputationBuilder(object): """ return self._client.ConvertElementType(operand, new_element_type) + def BitcastConvertType(self, operand, new_element_type): + """Enqueues a bitcast type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A LocalOp representing the added conversion op. + """ + return self._client.BitcastConvertType(operand, new_element_type) + def GetShape(self, operand): return _wrap_shape(self._client.GetShape(operand)) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 93177aa647..fd98e19457 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -489,6 +489,34 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype) + def testBitcastConvertType(self): + xla_x32_types = { + np.int32: xla_client.xla_data_pb2.S32, + np.float32: xla_client.xla_data_pb2.F32, + } + + xla_x64_types = { + np.int64: xla_client.xla_data_pb2.S64, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.BitcastConvertType(x, dst_etype) + + result = c.Build().Compile().Execute() + expected = np.array(template, src_dtype).view(dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for xla_types in [xla_x32_types, xla_x64_types]: + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), |