diff options
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 50 |
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 1 |
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_test.cc | 40 |
3 files changed, 68 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 0645fbb3ad..7b0d9e53d6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -96,15 +96,9 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); - if (std::is_same<T, float>::value) { - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else if (std::is_same<T, Eigen::half>::value) { - CHECK_EQ(F16, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else { - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } + CHECK_EQ(primitive_util::NativeToPrimitiveType<T>(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); @@ -246,21 +240,31 @@ Status RunCudnnConvolution( se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); - CHECK(output_primitive_type == F32 || output_primitive_type == F16) - << ShapeUtil::HumanString(output_shape); - if (output_primitive_type == F32) { - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, - algorithm, stream, profile_result); + switch (output_primitive_type) { + case F16: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<Eigen::half>(input_buf), + se::DeviceMemory<Eigen::half>(filter_buf), + se::DeviceMemory<Eigen::half>(output_buf), + scratch_allocator, window, dnums, algorithm, 
+ stream, profile_result); + case F32: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<float>(input_buf), + se::DeviceMemory<float>(filter_buf), + se::DeviceMemory<float>(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + case F64: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<double>(input_buf), + se::DeviceMemory<double>(filter_buf), + se::DeviceMemory<double>(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(output_shape); } - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); } } // namespace gpu diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0f8cffd466..76f2e519ae 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -813,6 +813,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 5ed8122e00..e120adccae 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -765,5 +766,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { std::move(*LiteralUtil::CreateFromArray(filter_data))}); } +class ConvolutionHloTest : public HloTestBase {}; + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[3,56,56,16] parameter(0) + %arg1 = f64[3,3,3,64] parameter(1) + ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[2,5,8,1] parameter(0) + %arg1 = f64[2,5,8,2] parameter(1) + ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %output = f64[4,5,16,16] parameter(0) + %kernel = f64[5,3,7,7] parameter(1) + %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3} + ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01 +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + } // namespace } // namespace xla |