diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 76 |
1 files changed, 74 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index b02df02c90..5b07f13037 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -403,7 +403,43 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( CheckError(dnn->DoConvolveBackwardData( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator)); + backward_input_data, scratch_allocator, dnn::kDefaultAlgorithm, + nullptr)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory<float> *backward_input_data, + ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(input_descriptor), + PARAM(backward_input_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardData( + this, filter_descriptor, filter_data, output_descriptor, + backward_output_data, convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, algorithm, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } } else { SetError(); LOG(WARNING) @@ -447,7 +483,43 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( CheckError(dnn->DoConvolveBackwardFilter( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator)); + backward_filter_data, scratch_allocator, dnn::kDefaultAlgorithm, + nullptr)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory<float> *backward_filter_data, + ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(filter_descriptor), + PARAM(backward_filter_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardFilter( + this, input_descriptor, input_data, output_descriptor, + backward_output_data, convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, algorithm, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } } else { SetError(); LOG(WARNING) |