aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-27 09:00:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-27 09:04:01 -0700
commit4198e27be8115585ad6b5b141383fb7dc7856c24 (patch)
tree244405e6ef96cb098d8abbf2547a8f22dfb4c72d /tensorflow/compiler/xla/service/shape_inference_test.cc
parent4ae245a7db3d0457c4324ee7df8d020ba83b3c60 (diff)
[XLA:CPU] [XLA:GPU] Adds compiler support for C64 primitive type, including relevant elementwise unary and binary op lowering for CPU and GPU.
We use a named LLVM struct "complex64", laid out the same as std::complex<float>. This named struct is accessed via the llvm::Module, which required changes to accessors of PrimitiveTypeToIrType & friends. Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides this for GPU. PiperOrigin-RevId: 173676849
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 8df4a73229..d12f7bd145 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
// Some handy scalar shapes.
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
+ const Shape f64_ = ShapeUtil::MakeShape(F64, {});
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
// Some handy vector and matrix shapes of F32 type.
@@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
.ok());
}
+TEST_F(ShapeInferenceTest, Complex) {
+ auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
+ const tensorflow::gtl::ArraySlice<int64>& bcast) {
+ return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
+ lhs, rhs, bcast);
+ };
+ // Inputs must be FP.
+ ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
+ ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
+ // Component types must match.
+ ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
+ // Only F32->C64 supported.
+ ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
+ // Validate correct uses.
+ Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
+ TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
+ TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
+ TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
+ TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
+
+ Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
+ TF_ASSERT_OK_AND_ASSIGN(result,
+ complex_shape(vector_64_, matrix_32_64_, {1}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
+ TF_ASSERT_OK_AND_ASSIGN(result,
+ complex_shape(matrix_32_64_, vector_64_, {1}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
+ TF_ASSERT_OK_AND_ASSIGN(result,
+ complex_shape(matrix_32_64_, matrix_32_64_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
+ TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
+ ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
+}
+
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});