diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-31 18:30:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 18:34:39 -0700 |
commit | a738649f8092dda124b6a4f3ccc31bf4159651ea (patch) | |
tree | c5959939b8a1f85dbc9e32f6be9cfd64bc682565 | |
parent | 855b0b403643cfef7b6b8e4543a070a512a389fa (diff) |
[XLA] add complex number operations to the local Python client
PiperOrigin-RevId: 206863335
5 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 434d78d78d..8246f76d34 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -624,6 +624,7 @@ _FORWARD_BINOP(ShiftRightArithmetic) _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) +_FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -658,6 +659,9 @@ _FORWARD_UNOP(Asinh) _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) +_FORWARD_UNOP(Real) +_FORWARD_UNOP(Imag) +_FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 545aa63f9d..a568c24c63 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -341,6 +341,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(ShiftRightLogical) _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) + _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -375,6 +376,9 @@ class LocalComputationBuilder { _FORWARD_UNOP(Atanh) _FORWARD_UNOP(Cosh) _FORWARD_UNOP(Sinh) + _FORWARD_UNOP(Real) + _FORWARD_UNOP(Imag) + _FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 9b8b0aa7f2..5d5a955bfe 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -1029,6 +1029,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Atanh; %unignore xla::swig::LocalComputationBuilder::Cosh; %unignore xla::swig::LocalComputationBuilder::Sinh; +%unignore xla::swig::LocalComputationBuilder::Real; +%unignore xla::swig::LocalComputationBuilder::Imag; +%unignore xla::swig::LocalComputationBuilder::Conj; +%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 71351abd59..6f665faf61 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -50,6 +50,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT32; case F64: return NPY_FLOAT64; + case C64: + return NPY_COMPLEX64; case TUPLE: return NPY_OBJECT; default: @@ -83,6 +85,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F32; case NPY_FLOAT64: return F64; + case NPY_COMPLEX64: + return C64; case NPY_OBJECT: return TUPLE; default: @@ -104,6 +108,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT16: case NPY_FLOAT32: case NPY_FLOAT64: + case NPY_COMPLEX64: case NPY_OBJECT: return true; default: @@ -425,6 +430,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_FLOAT64: CopyNumpyArrayToLiteral<double>(py_array, literal); break; + case NPY_COMPLEX64: + CopyNumpyArrayToLiteral<complex64>(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -462,6 +470,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_FLOAT64: CopyLiteralToNumpyArray<double>(literal, py_array); break; + case NPY_COMPLEX64: + CopyLiteralToNumpyArray<complex64>(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c0105b385b..a2c6fc344d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -120,6 +120,9 @@ _UNARY_OPS = [ 'Atanh', 'Cosh', 'Sinh', + 'Real', + 'Imag', + 'Conj', ] _BINARY_OPS = [ @@ -144,6 +147,7 @@ _BINARY_OPS = [ 'ShiftRightArithmetic', 'ShiftRightLogical', 'Atan2', + 'Complex', ] |