diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 19d3b2389a..69558fd14b 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch( Stream &Stream::ThenFusedConvolveWithAlgorithm( const dnn::BatchDescriptor &conv_input_descriptor, + const DeviceMemory<double> &conv_input_data, double conv_input_scale, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<double> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const DeviceMemory<double> &side_input_data, double side_input_scale, + const dnn::BatchDescriptor &bias_descriptor, + const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + /* double overload: traces the call with VLOG_CALL (bias_descriptor, scratch_allocator and output_profile_result are not among the traced PARAMs), then, only if ok(), forwards every argument to dnn->DoFusedConvolve when parent_->AsDnn() is non-null, or calls SetErrorAndLogNoDnnSupport() when it is null. Returns a reference to this stream in all cases. */ VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), + PARAM(conv_input_scale), PARAM(filter_descriptor), + PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), + PARAM(side_input_data), PARAM(side_input_scale), + PARAM(activation_mode), PARAM(output_descriptor), PARAM(output), + PARAM(algorithm_config)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoFusedConvolve( + this, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output, scratch_allocator, + algorithm_config, output_profile_result); + /* SetError() is reached only when status is falsy AND output_profile_result is null; with a non-null output_profile_result a failed call does not change the stream's error state here. NOTE(review): presumably the caller is expected to inspect the profile result in that case - confirm. */ if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenFusedConvolveWithAlgorithm( + const dnn::BatchDescriptor &conv_input_descriptor, const DeviceMemory<float> 
&conv_input_data, float conv_input_scale, const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, |