diff options
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 54 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.h | 16 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 38 |
3 files changed, 77 insertions, 31 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 3c533c7f99..63ab367086 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -149,6 +149,16 @@ cudnnDataType_t GetCudnnDataType<Eigen::half>() { return CUDNN_DATA_HALF; } +template <> +cudnnDataType_t GetCudnnDataType<int8>() { + return CUDNN_DATA_INT8; +} + +template <> +cudnnDataType_t GetCudnnDataType<int32>() { + return CUDNN_DATA_INT32; +} + // RAII wrapper for all calls to cuDNN with a cuDNN handle argument. // // See CudnnAccess::GetHandle() for details. @@ -2486,19 +2496,19 @@ port::Status CudnnSupport::DoConvolveImpl( return port::Status::OK(); } -template <typename Type, typename BiasType, typename ScaleType, - int cudnn_data_type, int cudnn_compute_type> +template <typename AccumulatorType, typename ElementType, typename BiasType, + typename ScaleType> port::Status CudnnSupport::DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, - const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale, - const dnn::FilterDescriptor& filter_descriptor, - const DeviceMemory<Type>& filter_data, + const DeviceMemory<ElementType>& conv_input_data, + ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory<ElementType>& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale, - const dnn::BatchDescriptor& bias_descriptor, + const DeviceMemory<ElementType>& side_input_data, + ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor, const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator, + DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { if (activation_mode != dnn::ActivationMode::kRelu && @@ -2508,15 +2518,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl( "Relu or None activation."); } - CudnnTensorDescriptor conv_input_nd( - conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); - CudnnTensorDescriptor output_nd( - output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); + CudnnTensorDescriptor conv_input_nd(conv_input_descriptor, + GetCudnnDataType<ElementType>()); + CudnnTensorDescriptor output_nd(output_descriptor, + GetCudnnDataType<ElementType>()); CudnnFilterDescriptor filter(filter_descriptor, - static_cast<cudnnDataType_t>(cudnn_data_type)); - CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT); - CudnnConvolutionDescriptor conv( - convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type)); + GetCudnnDataType<ElementType>()); + CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>()); + CudnnConvolutionDescriptor conv(convolution_descriptor, + GetCudnnDataType<AccumulatorType>()); auto cudnn = cudnn_->GetHandle(parent_, stream); @@ -2933,8 +2943,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE, - CUDNN_DATA_DOUBLE>( + DoFusedConvolveImpl<double>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -2957,8 +2966,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT, - CUDNN_DATA_FLOAT>( + DoFusedConvolveImpl<float>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -2982,8 +2990,7 @@ bool CudnnSupport::DoFusedConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return IsStatusOk( - DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF, - CUDNN_DATA_FLOAT>( + DoFusedConvolveImpl<float>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, @@ -3014,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve( return false; } return IsStatusOk( - DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4, - CUDNN_DATA_INT32>( + DoFusedConvolveImpl<int32>( stream, conv_input_descriptor, conv_input_data, conv_input_scale, filter_descriptor, filter_data, convolution_descriptor, side_input_data, side_input_scale, bias_descriptor, biases, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 9d88f971bb..74f6f935b8 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -674,19 +674,21 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result); - template <typename Type, typename BiasType, typename ScaleType, - int cudnn_data_type, int cudnn_compute_type> + template <typename AccumulatorType, typename ElementType, typename BiasType, + typename ScaleType> port::Status DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, - const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale, + const DeviceMemory<ElementType>& conv_input_data, + ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, - const DeviceMemory<Type>& filter_data, + const DeviceMemory<ElementType>& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale, - const dnn::BatchDescriptor& bias_descriptor, + const DeviceMemory<ElementType>& side_input_data, + ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor, const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator, + DeviceMemory<ElementType>* output_data, + ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result); diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 19d3b2389a..69558fd14b 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch( Stream &Stream::ThenFusedConvolveWithAlgorithm( const dnn::BatchDescriptor &conv_input_descriptor, + const DeviceMemory<double> &conv_input_data, double conv_input_scale, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<double> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<double> &side_input_data, double side_input_scale, + const dnn::BatchDescriptor &bias_descriptor, + const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), + PARAM(conv_input_scale), PARAM(filter_descriptor), + PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), + PARAM(side_input_data), PARAM(side_input_scale), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output), + PARAM(algorithm_config)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoFusedConvolve( + this, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output, scratch_allocator, + algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenFusedConvolveWithAlgorithm( + const dnn::BatchDescriptor &conv_input_descriptor, const DeviceMemory<float> &conv_input_data, float conv_input_scale, const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, |