diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-06-01 17:50:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-01 17:54:26 -0700 |
commit | 69075f3546dfc29dbef8b7c5d990f3af094cbd5f (patch) | |
tree | 2494878da9ce92431152d74419a1f984ed197d62 /tensorflow/stream_executor/stream.cc | |
parent | 7d7a40309693f01359537dce97fd6ff82e19755d (diff) |
Add functional support for cudnnConvolutionBiasActivationForward().
PiperOrigin-RevId: 157788425
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 166 |
1 files changed, 151 insertions, 15 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index a393b07703..bb586c5848 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -350,11 +350,67 @@ Stream &Stream::ThenConvolveWithScratch( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<Eigen::half> &filter_data, const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<Eigen::half> &biases, + dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &output_descriptor, - DeviceMemory<Eigen::half> *output, + DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<Eigen::half> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<Eigen::half> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output)); @@ -362,9 +418,9 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -389,9 +445,74 @@ Stream &Stream::ThenConvolveWithScratch( if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( this, input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output, - /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + convolution_descriptor, output_descriptor, output, scratch_allocator, + dnn::AlgorithmConfig(), + /*output_profile_result=*/nullptr)); + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<Eigen::half> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<Eigen::half> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<Eigen::half> &biases, + dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(biases), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, scratch_allocator, algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } } else { SetErrorAndLogNoDnnSupport(); } @@ -467,6 +588,21 @@ Stream &Stream::ThenConvolve( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> *output) { + return ThenConvolveWithScratch( + input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, biases, activation_mode, output_descriptor, + output, /*scratch_allocator=*/nullptr); +} + +Stream &Stream::ThenConvolve( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output) { return ThenConvolveWithScratch(input_descriptor, input_data, @@ -582,7 +718,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -676,7 +812,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -718,7 +854,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -779,7 +915,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), - nullptr)); + /*output_profile_result=*/nullptr)); } else { SetErrorAndLogNoDnnSupport(); } @@ -3868,7 +4004,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3900,7 +4036,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3934,7 +4070,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -3973,7 +4109,7 @@ Stream &Stream::ThenBlasGemmBatched( int batch_count) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - nullptr); + /*scratch_allocator=*/nullptr); } Stream &Stream::ThenBlasGemmBatchedWithScratch( |