about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
author: Yangzihao Wang <yangzihao@google.com> 2017-06-01 17:50:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-01 17:54:26 -0700
commit 69075f3546dfc29dbef8b7c5d990f3af094cbd5f (patch)
tree 2494878da9ce92431152d74419a1f984ed197d62 /tensorflow/stream_executor/stream.cc
parent 7d7a40309693f01359537dce97fd6ff82e19755d (diff)
Add functional support for cudnnConvolutionBiasActivationForward().
PiperOrigin-RevId: 157788425
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- tensorflow/stream_executor/stream.cc | 166
1 file changed, 151 insertions, 15 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index a393b07703..bb586c5848 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -350,11 +350,67 @@ Stream &Stream::ThenConvolveWithScratch(
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<Eigen::half> &filter_data,
const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<Eigen::half> &biases,
+ dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> *output,
+ DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoConvolve(
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, biases, activation_mode, output_descriptor,
+ output, scratch_allocator, dnn::AlgorithmConfig(),
+ /*output_profile_result=*/nullptr));
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveWithScratch(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoConvolve(
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, biases, activation_mode, output_descriptor,
+ output, scratch_allocator, dnn::AlgorithmConfig(),
+ /*output_profile_result=*/nullptr));
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveWithScratch(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<Eigen::half> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
PARAM(convolution_descriptor), PARAM(output_descriptor),
PARAM(output));
@@ -362,9 +418,9 @@ Stream &Stream::ThenConvolveWithScratch(
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoConvolve(
this, input_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output,
- /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ convolution_descriptor, output_descriptor, output, scratch_allocator,
+ dnn::AlgorithmConfig(),
+ /*output_profile_result=*/nullptr));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -389,9 +445,74 @@ Stream &Stream::ThenConvolveWithScratch(
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoConvolve(
this, input_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output,
- /*scratch_allocator=*/scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ convolution_descriptor, output_descriptor, output, scratch_allocator,
+ dnn::AlgorithmConfig(),
+ /*output_profile_result=*/nullptr));
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolve(
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, biases, activation_mode, output_descriptor,
+ output, scratch_allocator, algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<Eigen::half> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<Eigen::half> &biases,
+ dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoConvolve(
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, biases, activation_mode, output_descriptor,
+ output, scratch_allocator, algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -467,6 +588,21 @@ Stream &Stream::ThenConvolve(
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output) {
+ return ThenConvolveWithScratch(
+ input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, biases, activation_mode, output_descriptor,
+ output, /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenConvolve(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> *output) {
return ThenConvolveWithScratch(input_descriptor, input_data,
@@ -582,7 +718,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
this, filter_descriptor, filter_data, output_descriptor,
backward_output_data, convolution_descriptor, input_descriptor,
backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ /*output_profile_result=*/nullptr));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -676,7 +812,7 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
this, filter_descriptor, filter_data, output_descriptor,
backward_output_data, convolution_descriptor, input_descriptor,
backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ /*output_profile_result=*/nullptr));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -718,7 +854,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
this, input_descriptor, input_data, output_descriptor,
backward_output_data, convolution_descriptor, filter_descriptor,
backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ /*output_profile_result=*/nullptr));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -779,7 +915,7 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
this, input_descriptor, input_data, output_descriptor,
backward_output_data, convolution_descriptor, filter_descriptor,
backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
- nullptr));
+ /*output_profile_result=*/nullptr));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -3868,7 +4004,7 @@ Stream &Stream::ThenBlasGemmBatched(
int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
- nullptr);
+ /*scratch_allocator=*/nullptr);
}
Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -3900,7 +4036,7 @@ Stream &Stream::ThenBlasGemmBatched(
int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
- nullptr);
+ /*scratch_allocator=*/nullptr);
}
Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -3934,7 +4070,7 @@ Stream &Stream::ThenBlasGemmBatched(
int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
- nullptr);
+ /*scratch_allocator=*/nullptr);
}
Stream &Stream::ThenBlasGemmBatchedWithScratch(
@@ -3973,7 +4109,7 @@ Stream &Stream::ThenBlasGemmBatched(
int batch_count) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
- nullptr);
+ /*scratch_allocator=*/nullptr);
}
Stream &Stream::ThenBlasGemmBatchedWithScratch(