aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 14:20:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 14:20:48 -0700
commit703e18752e6616cf6bfec358329bb243f0346935 (patch)
tree7904f98a03f0bf1a328ac7445699d33cafb82a20 /tensorflow/stream_executor/stream.cc
parentc1322043a853601ec9561157b23a5c86cdadc689 (diff)
parent456aaa2fdbf821296a31f5493955f4653ae119dd (diff)
Merge pull request #20706 from ROCmSoftwarePlatform:upstream-staging-stream-executor-pooling-interface
PiperOrigin-RevId: 204805678
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc63
1 files changed, 39 insertions, 24 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9369183133..ca1b8e28e6 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -1377,15 +1377,16 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<double> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<double> *output_data) {
+ DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetError();
LOG(WARNING)
@@ -1401,15 +1402,16 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<float> *output_data) {
+ DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1422,15 +1424,17 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<Eigen::half> *output_data) {
+ DeviceMemory<Eigen::half> *output_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1445,16 +1449,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<double> &output_data,
const DeviceMemory<double> &input_diff_data,
- DeviceMemory<double> *output_diff_data) {
+ DeviceMemory<double> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetError();
LOG(WARNING)
@@ -1472,16 +1479,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<float> &output_data,
const DeviceMemory<float> &input_diff_data,
- DeviceMemory<float> *output_diff_data) {
+ DeviceMemory<float> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1496,16 +1506,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<Eigen::half> &output_data,
const DeviceMemory<Eigen::half> &input_diff_data,
- DeviceMemory<Eigen::half> *output_diff_data) {
+ DeviceMemory<Eigen::half> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1552,16 +1565,18 @@ Stream &Stream::ThenNormalizeBackwardWithDimensions(
const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
const DeviceMemory<float> &normalized_data,
const DeviceMemory<float> &normalized_variable_gradient,
- DeviceMemory<float> *raw_variable_gradient) {
+ DeviceMemory<float> *raw_variable_gradient,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
PARAM(normalized_data), PARAM(normalized_variable_gradient),
- PARAM(raw_variable_gradient));
+ PARAM(raw_variable_gradient), PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoNormalizeBackwardWithDimensions(
this, normalize_descriptor, dimensions, raw_data, normalized_data,
- normalized_variable_gradient, raw_variable_gradient));
+ normalized_variable_gradient, raw_variable_gradient,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}