diff options
author | Brian Patton <bjp@google.com> | 2018-03-06 08:23:04 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-06 08:29:31 -0800 |
commit | a2ea23e91915fabd0e856f284d0af75a496a432a (patch) | |
tree | df490e6e17622d2782dab5a9e58133047af1122c /tensorflow/stream_executor/stream.cc | |
parent | f261257ab26802cf3cab7303a76db2fb729e1d01 (diff) |
StreamExecutor support for float64 convolutions and backprop.
PiperOrigin-RevId: 188025477
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index ba5001e273..4d852e6e5a 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -683,6 +683,37 @@ Stream &Stream::ThenFusedConvolveWithAlgorithm( Stream &Stream::ThenConvolveWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<double> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<double> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(filter_descriptor), PARAM(filter_data), + PARAM(convolution_descriptor), PARAM(output_descriptor), + PARAM(output), PARAM(algorithm_config)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolve( + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output, scratch_allocator, + algorithm_config, output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, @@ -892,6 +923,39 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch( Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<double> &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<double> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory<double> *backward_input_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(input_descriptor), + PARAM(backward_input_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardData( + this, filter_descriptor, filter_data, output_descriptor, + backward_output_data, convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, algorithm_config, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( + const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> backward_output_data, @@ -1028,6 +1092,39 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch( Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<double> &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<double> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory<double> *backward_filter_data, + ScratchAllocator *scratch_allocator, + const dnn::AlgorithmConfig &algorithm_config, + dnn::ProfileResult *output_profile_result) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), + PARAM(output_descriptor), PARAM(backward_output_data), + PARAM(convolution_descriptor), PARAM(filter_descriptor), + PARAM(backward_filter_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + auto status = dnn->DoConvolveBackwardFilter( + this, input_descriptor, input_data, output_descriptor, + backward_output_data, convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, algorithm_config, + output_profile_result); + if (!status && !output_profile_result) { + SetError(); + } + } else { + SetErrorAndLogNoDnnSupport(); + } + } + return *this; +} + +Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( + const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> backward_output_data, |