author    Tim Shen <timshen@google.com>    2018-08-15 16:59:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-15 17:04:15 -0700
commit    a10219e1de775ca16281f1b597f7bf4d60d0585f (patch)
tree      9a4785b611c256b46cc9929955020d3f2430f606 /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
parent    d4d93a84497a406bfaebb8176c699ae810bc5ff5 (diff)
Enable f64 convolutions for the GPU backend. Currently, all layouts are NCHW.
PiperOrigin-RevId: 208908539
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc  50
1 file changed, 27 insertions(+), 23 deletions(-)
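The first hunk below replaces a per-type if/else chain (std::is_same checks for float and Eigen::half) with a single CHECK that compares the template parameter's primitive type, obtained via primitive_util::NativeToPrimitiveType<T>(), against the output shape's element type. As a rough standalone illustration of that pattern, here is a sketch using hypothetical names (PrimitiveType, NativeToPrimitiveType, OutputTypeMatches) rather than XLA's actual declarations:

    #include <cstdio>

    // Hypothetical enum standing in for the XLA PrimitiveType values used here.
    enum PrimitiveType { PRED, F16, F32, F64 };

    // Compile-time map from a native C++ type to the runtime enum, analogous in
    // spirit to primitive_util::NativeToPrimitiveType<T>() in the hunk below.
    template <typename T>
    constexpr PrimitiveType NativeToPrimitiveType();
    template <>
    constexpr PrimitiveType NativeToPrimitiveType<float>() { return F32; }
    template <>
    constexpr PrimitiveType NativeToPrimitiveType<double>() { return F64; }

    // One comparison now covers every supported type, so adding f64 support
    // does not require another std::is_same branch.
    template <typename T>
    bool OutputTypeMatches(PrimitiveType output_element_type) {
      return NativeToPrimitiveType<T>() == output_element_type;
    }

    int main() {
      std::printf("%d\n", OutputTypeMatches<float>(F32));   // prints 1
      std::printf("%d\n", OutputTypeMatches<double>(F32));  // prints 0
    }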
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 0645fbb3ad..7b0d9e53d6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -96,15 +96,9 @@ Status RunCudnnConvolution(
// tensorflow/python/ops/nn_ops.py).
const int effective_num_dimensions = std::max(2, num_dimensions);
- if (std::is_same<T, float>::value) {
- CHECK_EQ(F32, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else if (std::is_same<T, Eigen::half>::value) {
- CHECK_EQ(F16, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
- } else {
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
- }
+ CHECK_EQ(primitive_util::NativeToPrimitiveType<T>(),
+ output_shape.element_type())
+ << ShapeUtil::HumanString(output_shape);
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
@@ -246,21 +240,31 @@ Status RunCudnnConvolution(
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
- CHECK(output_primitive_type == F32 || output_primitive_type == F16)
- << ShapeUtil::HumanString(output_shape);
- if (output_primitive_type == F32) {
- return RunCudnnConvolution(
- kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
- algorithm, stream, profile_result);
+ switch (output_primitive_type) {
+ case F16:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F32:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ case F64:
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+ default:
+ LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
}
} // namespace gpu
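The second hunk converts the F32/F16 if/else in the untyped overload into a switch over the output element type, with a new F64 case that wraps the raw buffers in se::DeviceMemory<double> before calling the templated runner. A minimal standalone sketch of that dispatch shape, again with hypothetical names (RunTyped, RunDispatch) rather than the real RunCudnnConvolution signature:

    #include <cstdio>

    // Hypothetical enum standing in for the element types handled above.
    enum PrimitiveType { PRED, F16, F32, F64 };

    // Stand-in for the typed runner; the real code wraps the untyped buffers
    // in se::DeviceMemory<T> and forwards them to cuDNN.
    template <typename T>
    int RunTyped() {
      std::printf("element size: %zu bytes\n", sizeof(T));
      return 0;
    }

    // Untyped entry point: a switch on the runtime element type picks the
    // template instantiation, so f64 support is one extra case.
    int RunDispatch(PrimitiveType output_primitive_type) {
      switch (output_primitive_type) {
        case F16:
          return RunTyped<unsigned short>();  // placeholder for Eigen::half
        case F32:
          return RunTyped<float>();
        case F64:
          return RunTyped<double>();
        default:
          std::fprintf(stderr, "unsupported element type\n");
          return 1;
      }
    }

    int main() { return RunDispatch(F64); }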