about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
author: Brian Patton <bjp@google.com>, 2018-03-06 08:23:04 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-03-06 08:29:31 -0800
commit a2ea23e91915fabd0e856f284d0af75a496a432a (patch)
tree df490e6e17622d2782dab5a9e58133047af1122c /tensorflow/stream_executor/stream.cc
parent f261257ab26802cf3cab7303a76db2fb729e1d01 (diff)
StreamExecutor support for float64 convolutions and backprop.
PiperOrigin-RevId: 188025477
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- tensorflow/stream_executor/stream.cc 97
1 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index ba5001e273..4d852e6e5a 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -683,6 +683,37 @@ Stream &Stream::ThenFusedConvolveWithAlgorithm(
Stream &Stream::ThenConvolveWithAlgorithm(  // Overload taking DeviceMemory<double> buffers.
const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<double> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {  // Null-checked after the DoConvolve call below.
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(output_descriptor),
+ PARAM(output), PARAM(algorithm_config));  // scratch_allocator and output_profile_result are not logged.
+
+ if (ok()) {  // The DNN call is attempted only while ok() holds.
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {  // A null AsDnn() result falls through to the else branch.
+ auto status = dnn->DoConvolve(  // Forwards this stream plus every argument unchanged.
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output, scratch_allocator,
+ algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {  // Failure is recorded only when no ProfileResult pointer was passed.
+ SetError();  // NOTE(review): presumably profiling callers read failure from the ProfileResult instead - confirm.
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();  // Reached when parent_->AsDnn() returned null.
+ }
+ }
+ return *this;  // Returned on every path, including when ok() was false on entry.
+}
+
+Stream &Stream::ThenConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
@@ -892,6 +923,39 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(  // Overload taking DeviceMemory<double> buffers.
const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<double> backward_output_data,  // Taken by value, unlike filter_data.
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<double> *backward_input_data,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {  // Null-checked after the DoConvolveBackwardData call below.
+ VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(input_descriptor),
+ PARAM(backward_input_data));  // NOTE(review): algorithm_config is not logged here, though the double ThenConvolveWithAlgorithm overload logs it - confirm intended.
+
+ if (ok()) {  // The DNN call is attempted only while ok() holds.
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {  // A null AsDnn() result falls through to the else branch.
+ auto status = dnn->DoConvolveBackwardData(  // Forwards this stream plus every argument unchanged.
+ this, filter_descriptor, filter_data, output_descriptor,
+ backward_output_data, convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator, algorithm_config,
+ output_profile_result);
+ if (!status && !output_profile_result) {  // Failure is recorded only when no ProfileResult pointer was passed.
+ SetError();  // NOTE(review): presumably profiling callers read failure from the ProfileResult instead - confirm.
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();  // Reached when parent_->AsDnn() returned null.
+ }
+ }
+ return *this;  // Returned on every path, including when ok() was false on entry.
+}
+
+Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
+ const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> backward_output_data,
@@ -1028,6 +1092,39 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(  // Overload taking DeviceMemory<double> buffers.
const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<double> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<double> backward_output_data,  // Taken by value, unlike input_data.
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<double> *backward_filter_data,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {  // Null-checked after the DoConvolveBackwardFilter call below.
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(filter_descriptor),
+ PARAM(backward_filter_data));  // NOTE(review): algorithm_config is not logged here, though the double ThenConvolveWithAlgorithm overload logs it - confirm intended.
+
+ if (ok()) {  // The DNN call is attempted only while ok() holds.
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {  // A null AsDnn() result falls through to the else branch.
+ auto status = dnn->DoConvolveBackwardFilter(  // Forwards this stream plus every argument unchanged.
+ this, input_descriptor, input_data, output_descriptor,
+ backward_output_data, convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator, algorithm_config,
+ output_profile_result);
+ if (!status && !output_profile_result) {  // Failure is recorded only when no ProfileResult pointer was passed.
+ SetError();  // NOTE(review): presumably profiling callers read failure from the ProfileResult instead - confirm.
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();  // Reached when parent_->AsDnn() returned null.
+ }
+ }
+ return *this;  // Returned on every path, including when ok() was false on entry.
+}
+
+Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> backward_output_data,