aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc76
1 files changed, 74 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index b02df02c90..5b07f13037 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -403,7 +403,43 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
CheckError(dnn->DoConvolveBackwardData(
this, filter_descriptor, filter_data, output_descriptor,
backward_output_data, convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator));
+ backward_input_data, scratch_allocator, dnn::kDefaultAlgorithm,
+ nullptr));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(input_descriptor),
+ PARAM(backward_input_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolveBackwardData(
+ this, filter_descriptor, filter_data, output_descriptor,
+ backward_output_data, convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator, algorithm,
+ output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
} else {
SetError();
LOG(WARNING)
@@ -447,7 +483,43 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
CheckError(dnn->DoConvolveBackwardFilter(
this, input_descriptor, input_data, output_descriptor,
backward_output_data, convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator));
+ backward_filter_data, scratch_allocator, dnn::kDefaultAlgorithm,
+ nullptr));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data,
+ ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(filter_descriptor),
+ PARAM(backward_filter_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolveBackwardFilter(
+ this, input_descriptor, input_data, output_descriptor,
+ backward_output_data, convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator, algorithm,
+ output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
} else {
SetError();
LOG(WARNING)