diff options
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index f53ec21530..c2310c8938 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -1081,6 +1081,43 @@ class DnnSupport { const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) = 0; + // Applies local response normalization to the values from input_data and + // writes the result to output_data. + // + // Similar to DoNormalize, but normalizes across feature maps and allows for + // specifying the dimensions of the tensor. + // + // See comments on NormalizeDescriptor for a description of local response + // normalization. + virtual bool DoNormalizeWithDimensions( + Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, + const dnn::BatchDescriptor& dimensions, + const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) { + return false; + } + + // Performs backpropagation for the normalization operation + // + // Given raw data, its corresponding normalized output, and a gradient of some + // unspecified function with respect to the normalized variables, computes the + // gradient of that unspecified function with respect to the raw variables. + // + // The normalized data input array is expected to match the output that would + // be obtained by running the raw data input array through the DoNormalize + // method above. + // + // See comments on NormalizeDescriptor for a description of local response + // normalization. + virtual bool DoNormalizeBackwardWithDimensions( + Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, + const dnn::BatchDescriptor& dimensions, + const DeviceMemory<float>& raw_data, + const DeviceMemory<float>& normalized_data, + const DeviceMemory<float>& normalized_variable_gradient, + DeviceMemory<float>* raw_variable_gradient) { + return false; + } + // Applies an activation function (see ActivationMode) to all of the values // held on the device in 'input_data', whose dimensions are described by // 'dimensions'. |