aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-04 06:00:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 06:04:16 -0700
commit2c9369c8d878c913b5dfcd3c27849bcd3d6af6c9 (patch)
tree98890e6a4b6bbcdc56ec0eab953777bd9cbbd4ad /tensorflow/compiler
parent28f239fdfa0c94f715fccf0197ab6c3c8df27d28 (diff)
[TF:XLA] Don't expand complex64 tensors during TF/XLA lowering, if possible.
PiperOrigin-RevId: 215724324
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/nullary_ops_test.py43
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc12
2 files changed, 43 insertions, 12 deletions
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index f985c5d2d9..38cb2f83ef 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase):
output.run()
def testConstants(self):
- constants = [
- np.float32(42),
- np.array([], dtype=np.float32),
- np.array([1, 2], dtype=np.float32),
- np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
- np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
- dtype=np.float32),
- np.array([[[]], [[]]], dtype=np.float32),
- np.array([[[[1]]]], dtype=np.float32),
- ]
- for c in constants:
- self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+ for dtype in self.numeric_types:
+ constants = [
+ dtype(42),
+ np.array([], dtype=dtype),
+ np.array([1, 2], dtype=dtype),
+ np.array([7, 7, 7, 7, 7], dtype=dtype),
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+
+ def testComplexConstants(self):
+ for dtype in self.complex_types:
+ constants = [
+ dtype(42 + 3j),
+ np.array([], dtype=dtype),
+ np.ones([50], dtype=dtype) * (3 + 4j),
+ np.array([1j, 2 + 1j], dtype=dtype),
+ np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4 + 6j], [5, 6]],
+ [[10 + 7j, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1 + 3j]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index da8cf3fc6f..2628ef8e24 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel {
return;
}
break;
+ case DT_COMPLEX64:
+ if (proto_.scomplex_val_size() == 2) {
+ ctx->SetOutput(
+ 0,
+ xla::Broadcast(xla::ConstantR0<xla::complex64>(
+ b, xla::complex64(proto_.scomplex_val(0),
+ proto_.scomplex_val(1))),
+ shape.dim_sizes()));
+ return;
+ }
+ break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
ctx->SetOutput(