aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-04 06:00:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 06:04:16 -0700
commit2c9369c8d878c913b5dfcd3c27849bcd3d6af6c9 (patch)
tree98890e6a4b6bbcdc56ec0eab953777bd9cbbd4ad /tensorflow/compiler/tf2xla
parent28f239fdfa0c94f715fccf0197ab6c3c8df27d28 (diff)
[TF:XLA] Don't expand complex64 tensors during TF/XLA lowering, if possible.
PiperOrigin-RevId: 215724324
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index da8cf3fc6f..2628ef8e24 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel {
return;
}
break;
+ case DT_COMPLEX64:
+ if (proto_.scomplex_val_size() == 2) {
+ ctx->SetOutput(
+ 0,
+ xla::Broadcast(xla::ConstantR0<xla::complex64>(
+ b, xla::complex64(proto_.scomplex_val(0),
+ proto_.scomplex_val(1))),
+ shape.dim_sizes()));
+ return;
+ }
+ break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
ctx->SetOutput(