aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/python/xla_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/python/xla_client.py')
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py45
1 files changed, 39 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index abb97d0c6f..c0105b385b 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -99,12 +99,27 @@ _UNARY_OPS = [
'Cos',
'Sin',
'Tanh',
- 'SqrtF32',
- 'SquareF32',
'IsFinite',
- 'ReciprocalF32',
+ 'Sqrt',
+ 'Rsqrt',
+ 'Square',
+ 'Reciprocal',
'Neg',
'Sort',
+ 'Erf',
+ 'Erfc',
+ 'ErfInv',
+ 'Lgamma',
+ 'Digamma',
+ 'Acos',
+ 'Asin',
+ 'Atan',
+ 'Tan',
+ 'Acosh',
+ 'Asinh',
+ 'Atanh',
+ 'Cosh',
+ 'Sinh',
]
_BINARY_OPS = [
@@ -125,6 +140,10 @@ _BINARY_OPS = [
'Or',
'Xor',
'Pow',
+ 'ShiftLeft',
+ 'ShiftRightArithmetic',
+ 'ShiftRightLogical',
+ 'Atan2',
]
@@ -461,14 +480,16 @@ class LocalComputation(object):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
+ result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
+
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
- result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
result_shape = result_shape.map_leaves(layout_fn)
- compile_options = compile_options or CompileOptions()
- compile_options.result_shape = result_shape
+
+ compile_options = compile_options or CompileOptions()
+ compile_options.result_shape = result_shape
return LocalComputation(
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)
@@ -700,6 +721,18 @@ class ComputationBuilder(object):
"""
return self._client.ConvertElementType(operand, new_element_type)
+ def BitcastConvertType(self, operand, new_element_type):
+ """Enqueues a bitcast type conversion operation onto the computation.
+
+ Args:
+ operand: the operand to convert.
+ new_element_type: the target primitive type.
+
+ Returns:
+ A LocalOp representing the added conversion op.
+ """
+ return self._client.BitcastConvertType(operand, new_element_type)
+
def GetShape(self, operand):
return _wrap_shape(self._client.GetShape(operand))