diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 54 |
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 3c533c7f99..63ab367086 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -149,6 +149,16 @@ cudnnDataType_t GetCudnnDataType<Eigen::half>() { return CUDNN_DATA_HALF; } +template <> +cudnnDataType_t GetCudnnDataType<int8>() { + return CUDNN_DATA_INT8; +} + +template <> +cudnnDataType_t GetCudnnDataType<int32>() { + return CUDNN_DATA_INT32; +} + // RAII wrapper for all calls to cuDNN with a cuDNN handle argument. // // See CudnnAccess::GetHandle() for details. @@ -2486,19 +2496,19 @@ port::Status CudnnSupport::DoConvolveImpl( return port::Status::OK(); } -template <typename Type, typename BiasType, typename ScaleType, - int cudnn_data_type, int cudnn_compute_type> +template <typename AccumulatorType, typename ElementType, typename BiasType, + typename ScaleType> port::Status CudnnSupport::DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, - const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale, - const dnn::FilterDescriptor& filter_descriptor, - const DeviceMemory<Type>& filter_data, + const DeviceMemory<ElementType>& conv_input_data, + ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory<ElementType>& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale, - const dnn::BatchDescriptor& bias_descriptor, + const DeviceMemory<ElementType>& side_input_data, + ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor, const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator, + DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { if (activation_mode != dnn::ActivationMode::kRelu && @@ -2508,15 +2518,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl( "Relu or None activation."); } - CudnnTensorDescriptor conv_input_nd( - conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); - CudnnTensorDescriptor output_nd( - output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); + CudnnTensorDescriptor conv_input_nd(conv_input_descriptor, + GetCudnnDataType<ElementType>()); + CudnnTensorDescriptor output_nd(output_descriptor, + GetCudnnDataType<ElementType>()); CudnnFilterDescriptor filter(filter_descriptor, - static_cast<cudnnDataType_t>(cudnn_data_type)); - CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT); - CudnnConvolutionDescriptor conv( - convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type)); + GetCudnnDataType<ElementType>()); + CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>()); + CudnnConvolutionDescriptor conv(convolution_descriptor, + GetCudnnDataType<AccumulatorType>()); auto cudnn = cudnn_->GetHandle(parent_, stream); @@ -2933,8 +2943,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE, - CUDNN_DATA_DOUBLE>( + DoFusedConvolveImpl<double>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -2957,8 +2966,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT, - CUDNN_DATA_FLOAT>( + DoFusedConvolveImpl<float>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -2982,8 +2990,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF, - CUDNN_DATA_FLOAT>( + DoFusedConvolveImpl<float>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -3014,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve( return false; } return IsStatusOk( - DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4, - CUDNN_DATA_INT32>( + DoFusedConvolveImpl<int32>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, |