aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_dnn.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc54
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3c533c7f99..63ab367086 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -149,6 +149,16 @@ cudnnDataType_t GetCudnnDataType<Eigen::half>() {
return CUDNN_DATA_HALF;
}
+template <>
+cudnnDataType_t GetCudnnDataType<int8>() {
+ return CUDNN_DATA_INT8;
+}
+
+template <>
+cudnnDataType_t GetCudnnDataType<int32>() {
+ return CUDNN_DATA_INT32;
+}
+
// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
@@ -2486,19 +2496,19 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}
-template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
if (activation_mode != dnn::ActivationMode::kRelu &&
@@ -2508,15 +2518,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
"Relu or None activation.");
}
- CudnnTensorDescriptor conv_input_nd(
- conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnTensorDescriptor output_nd(
- output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ CudnnTensorDescriptor conv_input_nd(conv_input_descriptor,
+ GetCudnnDataType<ElementType>());
+ CudnnTensorDescriptor output_nd(output_descriptor,
+ GetCudnnDataType<ElementType>());
CudnnFilterDescriptor filter(filter_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
- CudnnConvolutionDescriptor conv(
- convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+ GetCudnnDataType<ElementType>());
+ CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetCudnnDataType<AccumulatorType>());
auto cudnn = cudnn_->GetHandle(parent_, stream);
@@ -2933,8 +2943,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
+ DoFusedConvolveImpl<double>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2957,8 +2966,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2982,8 +2990,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3014,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve(
return false;
}
return IsStatusOk(
- DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
+ DoFusedConvolveImpl<int32>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,