aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-08-15 16:59:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 17:04:15 -0700
commita10219e1de775ca16281f1b597f7bf4d60d0585f (patch)
tree9a4785b611c256b46cc9929955020d3f2430f606
parentd4d93a84497a406bfaebb8176c699ae810bc5ff5 (diff)
Enable f64 convolutions for GPU backend. Currently, all layouts are NCHWs.
PiperOrigin-RevId: 208908539
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc50
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc40
3 files changed, 68 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 0645fbb3ad..7b0d9e53d6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -96,15 +96,9 @@ Status RunCudnnConvolution(
// tensorflow/python/ops/nn_ops.py).
const int effective_num_dimensions = std::max(2, num_dimensions);
- if (std::is_same<T, float>::value) {
- CHECK_EQ(F32, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else if (std::is_same<T, Eigen::half>::value) {
- CHECK_EQ(F16, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else {
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
- }
+ CHECK_EQ(primitive_util::NativeToPrimitiveType<T>(),
+ output_shape.element_type())
+ << ShapeUtil::HumanString(output_shape);
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
@@ -246,21 +240,31 @@ Status RunCudnnConvolution(
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
- CHECK(output_primitive_type == F32 || output_primitive_type == F16)
- << ShapeUtil::HumanString(output_shape);
- if (output_primitive_type == F32) {
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- algorithm, stream, profile_result);
+ switch (output_primitive_type) {
+ case F16:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F32:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F64:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ default:
+ LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 0f8cffd466..76f2e519ae 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -813,6 +813,7 @@ CONVOLUTION_TEST_DEPS = [
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 5ed8122e00..e120adccae 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -765,5 +766,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
std::move(*LiteralUtil::CreateFromArray(filter_data))});
}
+class ConvolutionHloTest : public HloTestBase {};
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %arg0 = f64[3,56,56,16] parameter(0)
+ %arg1 = f64[3,3,3,64] parameter(1)
+ ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %arg0 = f64[2,5,8,1] parameter(0)
+ %arg1 = f64[2,5,8,2] parameter(1)
+ ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) {
+ constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+ %output = f64[4,5,16,16] parameter(0)
+ %kernel = f64[5,3,7,7] parameter(1)
+ %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3}
+ ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, dim_labels=bf01_io01->bf01
+})";
+ EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
+}
+
} // namespace
} // namespace xla