aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Brian Patton <bjp@google.com>2018-03-06 08:23:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 08:29:31 -0800
commita2ea23e91915fabd0e856f284d0af75a496a432a (patch)
treedf490e6e17622d2782dab5a9e58133047af1122c /tensorflow/stream_executor/dnn.h
parentf261257ab26802cf3cab7303a76db2fb729e1d01 (diff)
StreamExecutor support for float64 convolutions and backprop.
PiperOrigin-RevId: 188025477
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h28
1 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index aa88fe770f..b41536e638 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1172,7 +1172,9 @@ class DnnSupport {
const DeviceMemory<double>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<double>* output_data) = 0;
+ DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
+ const dnn::AlgorithmConfig& algorithm_config,
+ dnn::ProfileResult* output_profile_result) = 0;
// Enqueues a half-precision convolution operation onto the stream.
// See DoConvolve above for argument details.
@@ -1275,6 +1277,18 @@ class DnnSupport {
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
+ const DeviceMemory<double>& filter_data,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<double> backward_output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& input_descriptor,
+ DeviceMemory<double>* backward_input_data,
+ ScratchAllocator* scratch_allocator,
+ const dnn::AlgorithmConfig& algorithm_config,
+ ProfileResult* output_profile_result) = 0;
+
+ virtual bool DoConvolveBackwardData(
+ Stream* stream, const FilterDescriptor& filter_descriptor,
const DeviceMemory<Eigen::half>& filter_data,
const BatchDescriptor& output_descriptor,
DeviceMemory<Eigen::half> backward_output_data,
@@ -1324,6 +1338,18 @@ class DnnSupport {
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
+ const DeviceMemory<double>& input_data,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<double> backward_output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const FilterDescriptor& filter_descriptor,
+ DeviceMemory<double>* backward_filter_data,
+ ScratchAllocator* scratch_allocator,
+ const dnn::AlgorithmConfig& algorithm_config,
+ ProfileResult* output_profile_result) = 0;
+
+ virtual bool DoConvolveBackwardFilter(
+ Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<Eigen::half>& input_data,
const BatchDescriptor& output_descriptor,
DeviceMemory<Eigen::half> backward_output_data,