aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 19d3b2389a..69558fd14b 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch(
Stream &Stream::ThenFusedConvolveWithAlgorithm(
const dnn::BatchDescriptor &conv_input_descriptor,
+ const DeviceMemory<double> &conv_input_data, double conv_input_scale,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<double> &side_input_data, double side_input_scale,
+ const dnn::BatchDescriptor &bias_descriptor,
+ const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+ PARAM(conv_input_scale), PARAM(filter_descriptor),
+ PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(side_input_data), PARAM(side_input_scale),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
+ PARAM(algorithm_config));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoFusedConvolve(
+ this, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output, scratch_allocator,
+ algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,