diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 14:20:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 14:20:48 -0700 |
commit | 703e18752e6616cf6bfec358329bb243f0346935 (patch) | |
tree | 7904f98a03f0bf1a328ac7445699d33cafb82a20 /tensorflow/stream_executor/stream.cc | |
parent | c1322043a853601ec9561157b23a5c86cdadc689 (diff) | |
parent | 456aaa2fdbf821296a31f5493955f4653ae119dd (diff) |
Merge pull request #20706 from ROCmSoftwarePlatform:upstream-staging-stream-executor-pooling-interface
PiperOrigin-RevId: 204805678
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 63 |
1 file changed, 39 insertions, 24 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 9369183133..ca1b8e28e6 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -1377,15 +1377,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<double> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<double> *output_data) { + DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1401,15 +1402,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data) { + DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1422,15 +1424,17 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<Eigen::half> 
&input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<Eigen::half> *output_data) { + DeviceMemory<Eigen::half> *output_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1445,16 +1449,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<double> &output_data, const DeviceMemory<double> &input_diff_data, - DeviceMemory<double> *output_diff_data) { + DeviceMemory<double> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1472,16 +1479,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<float> &output_data, const DeviceMemory<float> &input_diff_data, - DeviceMemory<float> *output_diff_data) { + DeviceMemory<float> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1496,16 +1506,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<Eigen::half> &output_data, const DeviceMemory<Eigen::half> &input_diff_data, - DeviceMemory<Eigen::half> *output_diff_data) { + DeviceMemory<Eigen::half> *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1552,16 +1565,18 @@ Stream &Stream::ThenNormalizeBackwardWithDimensions( const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data, const DeviceMemory<float> &normalized_data, const DeviceMemory<float> &normalized_variable_gradient, - DeviceMemory<float> *raw_variable_gradient) { + DeviceMemory<float> *raw_variable_gradient, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), PARAM(normalized_data), PARAM(normalized_variable_gradient), - 
PARAM(raw_variable_gradient)); + PARAM(raw_variable_gradient), PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoNormalizeBackwardWithDimensions( this, normalize_descriptor, dimensions, raw_data, normalized_data, - normalized_variable_gradient, raw_variable_gradient)); + normalized_variable_gradient, raw_variable_gradient, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } |