diff options
Diffstat (limited to 'tensorflow/compiler/xla/python/xla_client.py')
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client.py | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index abb97d0c6f..c0105b385b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -99,12 +99,27 @@ _UNARY_OPS = [ 'Cos', 'Sin', 'Tanh', - 'SqrtF32', - 'SquareF32', 'IsFinite', - 'ReciprocalF32', + 'Sqrt', + 'Rsqrt', + 'Square', + 'Reciprocal', 'Neg', 'Sort', + 'Erf', + 'Erfc', + 'ErfInv', + 'Lgamma', + 'Digamma', + 'Acos', + 'Asin', + 'Atan', + 'Tan', + 'Acosh', + 'Asinh', + 'Atanh', + 'Cosh', + 'Sinh', ] _BINARY_OPS = [ @@ -125,6 +140,10 @@ _BINARY_OPS = [ 'Or', 'Xor', 'Pow', + 'ShiftLeft', + 'ShiftRightArithmetic', + 'ShiftRightLogical', + 'Atan2', ] @@ -461,14 +480,16 @@ class LocalComputation(object): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) result_shape = result_shape.map_leaves(layout_fn) - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape + + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) @@ -700,6 +721,18 @@ class ComputationBuilder(object): """ return self._client.ConvertElementType(operand, new_element_type) + def BitcastConvertType(self, operand, new_element_type): + """Enqueues a bitcast type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A LocalOp representing the added conversion op. + """ + return self._client.BitcastConvertType(operand, new_element_type) + def GetShape(self, operand): return _wrap_shape(self._client.GetShape(operand)) |