author    Benjamin Kramer <kramerb@google.com>  2018-08-03 11:34:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-03 11:38:37 -0700
commit  4e4171bb6fcc08c00dbc6f6ae2dcdc502add6931 (patch)
tree    d19479bbcb5b833e610fee8f03e2c0b34273ccc0
parent  2da1d6ccce8bf9725e9e05a802f1228b6d9abb58 (diff)
[XLA:GPU] cuBLAS supports complex floats; use gemm instead of our O(n^3) implementation

Also increase test coverage for C64 a bit.

PiperOrigin-RevId: 207297946
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.cc          8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc   2
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc       20
3 files changed, 19 insertions(+), 11 deletions(-)
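
For context, gemm_thunk.cc dispatches on the dot's element type with a switch that returns a templated function pointer, and this patch adds a C64 case mapping to std::complex<float> in each of the three dispatchers. The following is a minimal, self-contained sketch of that pattern; PrimitiveType and DoGemm here are simplified stand-ins for the real XLA definitions, not the actual API.

#include <complex>
#include <cstdlib>
#include <iostream>

// Simplified stand-in for XLA's PrimitiveType enum (xla_data.proto).
enum PrimitiveType { F16, F32, F64, C64 };

// Stand-in for DoGemm<T> in gemm_thunk.cc; the real function forwards matrix
// descriptors to StreamExecutor's BLAS GEMM for element type T.
template <typename T>
bool DoGemm() {
  return true;
}

// Mirrors the shape of GetGemmFn: the switch picks the DoGemm instantiation
// for the element type. The patch adds the C64 -> std::complex<float> case.
auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
  switch (type) {
    case F32:
      return &DoGemm<float>;
    case F64:
      return &DoGemm<double>;
    case C64:
      return &DoGemm<std::complex<float>>;
    default:
      std::cerr << "Unsupported type.\n";
      std::abort();
  }
}

int main() {
  // Complex-float dots can now resolve to a GEMM implementation.
  return GetGemmFn(C64)() ? 0 : 1;
}
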
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index a300d5f3fe..74282c568c 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -201,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
return &DoGemm<float>;
case F64:
return &DoGemm<double>;
+ case C64:
+ return &DoGemm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -214,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type)
return &DoGemmWithAlgorithm<float>;
case F64:
return &DoGemmWithAlgorithm<double>;
+ case C64:
+ return &DoGemmWithAlgorithm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -226,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
return &DoGemmAutotune<float>;
case F64:
return &DoGemmAutotune<double>;
+ case C64:
+ return &DoGemmAutotune<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -244,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
return se::blas::ComputationType::kF32;
case F64:
return se::blas::ComputationType::kF64;
+ case C64:
+ return se::blas::ComputationType::kComplexF32;
default:
LOG(FATAL) << "Unsupported type.";
}
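
The "O(n^3) implementation" the commit title refers to is the naive loop-nest lowering XLA falls back to when a dot cannot be handed to a library GEMM. Illustratively only (the real fallback is emitted as LLVM IR, not written as C++ like this), the work a single cuBLAS complex-GEMM call replaces looks like:

#include <complex>
#include <vector>

using C64 = std::complex<float>;

// Naive matrix multiply over row-major buffers: three nested loops and
// O(m * n * k) complex multiply-adds, with none of the blocking or kernel
// tuning a cuBLAS GEMM provides.
void NaiveDot(const std::vector<C64>& lhs, const std::vector<C64>& rhs,
              std::vector<C64>& out, int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      C64 acc(0.0f, 0.0f);
      for (int p = 0; p < k; ++p) {
        acc += lhs[i * k + p] * rhs[p * n + j];
      }
      out[i * n + j] = acc;
    }
  }
}
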
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index d74c1a0243..c349063c71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -54,7 +54,7 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
PrimitiveType output_primitive_type = output_shape.element_type();
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
- output_primitive_type == F64);
+ output_primitive_type == F64 || output_primitive_type == C64);
return type_is_allowed &&
IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index f11d274aab..0e9e92ed99 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -111,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
@@ -137,7 +137,7 @@ std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
return {row_major ? 1 : 0, row_major ? 0 : 1};
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
@@ -148,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
@@ -160,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(
@@ -172,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
@@ -183,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
&builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto param0 =
@@ -533,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
&builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
@@ -690,7 +690,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
{x_data.get(), y_data.get()}, this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
using T = TypeParam;
for (bool transpose_lhs : {false, true}) {
for (bool transpose_rhs : {false, true}) {
@@ -750,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
}
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64,
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
DotOfConcatOptimizationWithConstLHS) {
using T = TypeParam;
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
@@ -796,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
this->error_spec_);
}

-XLA_TYPED_TEST(DotOperationTest_F16F32F64,
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
DotOfConcatOptimizationWithConstRHS) {
using T = TypeParam;
std::unique_ptr<Array2D<T>> constant_rhs_array(
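
On the test side, each migrated test simply moves from the F16/F32/F64 typed fixture to the one whose type list also includes complex64, so the same test body now exercises the new C64 GEMM path. A rough sketch of the googletest mechanism behind XLA_TYPED_TEST follows; the empty fixture body and reduced type list are assumptions made for brevity (the real fixtures derive from ClientLibraryTestBase and the real list also carries an F16 type).

#include <complex>
#include "gtest/gtest.h"

template <typename T>
class DotOperationTest_F16F32F64CF64 : public ::testing::Test {};

// Hypothetical reduced type list for illustration.
using TestTypes = ::testing::Types<float, double, std::complex<float>>;
// TYPED_TEST_SUITE is the current googletest spelling; older releases
// contemporary with this commit used TYPED_TEST_CASE.
TYPED_TEST_SUITE(DotOperationTest_F16F32F64CF64, TestTypes);

// Each TYPED_TEST body is instantiated once per type in the list, so adding
// std::complex<float> to the list is all it takes to cover C64.
TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
  using T = TypeParam;
  T lhs = static_cast<T>(2.0f);
  T rhs = static_cast<T>(3.0f);
  EXPECT_EQ(lhs * rhs, static_cast<T>(6.0f));
}
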