diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 76cbf0b1b6..a393b07703 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -966,6 +966,30 @@ Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data, Stream &Stream::ThenPoolForward( const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<double> &input_data, + const dnn::BatchDescriptor &output_dimensions, + DeviceMemory<double> *output_data) { + VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, + input_data, output_dimensions, + output_data)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenPoolForward( + const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, DeviceMemory<float> *output_data) { @@ -1008,6 +1032,33 @@ Stream &Stream::ThenPoolForward( Stream &Stream::ThenPoolBackward( const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<double> &input_data, + const dnn::BatchDescriptor &output_dimensions, + const DeviceMemory<double> &output_data, + const DeviceMemory<double> &input_diff_data, + DeviceMemory<double> *output_diff_data) { + VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(input_diff_data), PARAM(output_diff_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, + input_data, output_dimensions, output_data, + input_diff_data, output_diff_data)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + +Stream &Stream::ThenPoolBackward( + const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<float> &output_data, |