aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 76cbf0b1b6..a393b07703 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -966,6 +966,30 @@ Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
Stream &Stream::ThenPoolForward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<double> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<double> *output_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPoolForward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data) {
@@ -1008,6 +1032,33 @@ Stream &Stream::ThenPoolForward(
Stream &Stream::ThenPoolBackward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<double> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<double> &output_data,
+ const DeviceMemory<double> &input_diff_data,
+ DeviceMemory<double> *output_diff_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(input_diff_data), PARAM(output_diff_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions, output_data,
+ input_diff_data, output_diff_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPoolBackward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<float> &output_data,