diff options
author | Tim Shen <timshen@google.com> | 2018-08-15 16:59:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 17:04:15 -0700 |
commit | a10219e1de775ca16281f1b597f7bf4d60d0585f (patch) | |
tree | 9a4785b611c256b46cc9929955020d3f2430f606 /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | |
parent | d4d93a84497a406bfaebb8176c699ae810bc5ff5 (diff) |
Enable f64 convolutions for GPU backend. Currently, all layouts are NCHWs.
PiperOrigin-RevId: 208908539
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 50 |
1 files changed, 27 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 0645fbb3ad..7b0d9e53d6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -96,15 +96,9 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); - if (std::is_same<T, float>::value) { - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else if (std::is_same<T, Eigen::half>::value) { - CHECK_EQ(F16, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else { - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } + CHECK_EQ(primitive_util::NativeToPrimitiveType<T>(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); @@ -246,21 +240,31 @@ Status RunCudnnConvolution( se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); - CHECK(output_primitive_type == F32 || output_primitive_type == F16) - << ShapeUtil::HumanString(output_shape); - if (output_primitive_type == F32) { - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, - algorithm, stream, profile_result); + switch (output_primitive_type) { + case F16: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<Eigen::half>(input_buf), + se::DeviceMemory<Eigen::half>(filter_buf), + se::DeviceMemory<Eigen::half>(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + case F32: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<float>(input_buf), + se::DeviceMemory<float>(filter_buf), + se::DeviceMemory<float>(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + case F64: + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<double>(input_buf), + se::DeviceMemory<double>(filter_buf), + se::DeviceMemory<double>(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(output_shape); } - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); } } // namespace gpu |