diff options
author | 2018-08-03 11:34:31 -0700 | |
---|---|---|
committer | 2018-08-03 11:38:37 -0700 | |
commit | 4e4171bb6fcc08c00dbc6f6ae2dcdc502add6931 (patch) | |
tree | d19479bbcb5b833e610fee8f03e2c0b34273ccc0 | |
parent | 2da1d6ccce8bf9725e9e05a802f1228b6d9abb58 (diff) |
[XLA:GPU] cuBlas supports complex floats, use gemm instead of our O(n^3) implementation
Also increase test coverage for C64 a bit.
PiperOrigin-RevId: 207297946
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/dot_operation_test.cc | 20 |
3 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index a300d5f3fe..74282c568c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -201,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) { return &DoGemm<float>; case F64: return &DoGemm<double>; + case C64: + return &DoGemm<std::complex<float>>; default: LOG(FATAL) << "Unsupported type."; } @@ -214,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm<float>; case F64: return &DoGemmWithAlgorithm<double>; + case C64: + return &DoGemmWithAlgorithm<std::complex<float>>; default: LOG(FATAL) << "Unsupported type."; } @@ -226,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) { return &DoGemmAutotune<float>; case F64: return &DoGemmAutotune<double>; + case C64: + return &DoGemmAutotune<std::complex<float>>; default: LOG(FATAL) << "Unsupported type."; } @@ -244,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF32; case F64: return se::blas::ComputationType::kF64; + case C64: + return se::blas::ComputationType::kComplexF32; default: LOG(FATAL) << "Unsupported type."; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index d74c1a0243..c349063c71 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -54,7 +54,7 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64); + output_primitive_type == F64 || output_primitive_type == C64); return type_is_allowed && IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index f11d274aab..0e9e92ed99 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -111,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)}); @@ -137,7 +137,7 @@ std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2)); @@ -148,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2)); @@ -160,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D<T>( @@ -172,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0)); @@ -183,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = @@ -533,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -690,7 +690,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { @@ -750,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstLHS) { using T = TypeParam; auto prim_type = primitive_util::NativeToPrimitiveType<T>(); @@ -796,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstRHS) { using T = TypeParam; std::unique_ptr<Array2D<T>> constant_rhs_array( |