diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_dnn.h')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.h | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index c924d41cb5..9d88f971bb 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -515,21 +515,24 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<double>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<double>* output_data) override; + DeviceMemory<double>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<float>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<float>* output_data) override; + DeviceMemory<float>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory<Eigen::half>& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory<Eigen::half>* output_data) override; + DeviceMemory<Eigen::half>* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -538,7 +541,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<double>& output_data, const DeviceMemory<double>& input_diff_data, - DeviceMemory<double>* output_diff_data) override; + DeviceMemory<double>* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -547,7 +551,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<float>& output_data, const DeviceMemory<float>& input_diff_data, - DeviceMemory<float>* output_diff_data) override; + DeviceMemory<float>* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -556,7 +561,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory<Eigen::half>& output_data, const DeviceMemory<Eigen::half>& input_diff_data, - DeviceMemory<Eigen::half>* output_diff_data) override; + DeviceMemory<Eigen::half>* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoNormalize(Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, @@ -575,7 +581,8 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory<float>& raw_data, const DeviceMemory<float>& normalized_data, const DeviceMemory<float>& normalized_variable_gradient, - DeviceMemory<float>* raw_variable_gradient) override; + DeviceMemory<float>* raw_variable_gradient, + ScratchAllocator* workspace_allocator) override; bool DoDepthConcatenate( Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions, |