path: root/tensorflow/stream_executor/dnn.h
author    Xiaoqiang Zheng <zhengxq@google.com>    2016-08-19 10:01:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-08-19 11:18:09 -0700
commit    859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch)
tree      ae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/dnn.h
parent    de8838042fb34e53a511f25b4613611fc368beeb (diff)
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--  tensorflow/stream_executor/dnn.h  264
1 file changed, 264 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index c2310c8938..1c31178526 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -23,10 +23,12 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
#include <limits>
+#include <memory>
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -76,6 +78,94 @@ enum class QuantizedActivationMode {
k32Bit = 4,
};
+// Specifies the data type used by an operation.
+enum class DataType {
+ kFloat = 0,
+ kDouble = 1,
+ kHalf = 2,
+};
+
+// A helper class to convert C/C++ types to the proper enums.
+template <typename T>
+struct ToDataType;
+template <>
+struct ToDataType<float> {
+ static constexpr DataType value = DataType::kFloat;
+};
+template <>
+struct ToDataType<double> {
+ static constexpr DataType value = DataType::kDouble;
+};
+template <>
+struct ToDataType<Eigen::half> {
+ static constexpr DataType value = DataType::kHalf;
+};
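The ToDataType trait lets templated code map a C++ scalar type to its DataType enum at compile time. A minimal usage sketch (GetRnnDataType is a hypothetical helper, not part of this change):

    // Deduce the DataType enum from a template parameter at compile time.
    template <typename T>
    dnn::DataType GetRnnDataType() {
      return dnn::ToDataType<T>::value;
    }

    static_assert(dnn::ToDataType<float>::value == dnn::DataType::kFloat,
                  "ToDataType must map float to DataType::kFloat");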
+
+// Specifies the type of an RNN model.
+enum class RnnMode {
+ kRnnRelu = 0,
+ kRnnTanh = 1,
+ kRnnLstm = 2,
+ kRnnGru = 3,
+};
+
+// Specifies the input mode of an RNN model: whether there is a linear
+// transformation between the input state and the first-layer hidden state.
+enum class RnnInputMode {
+ kRnnLinearSkip = 0,
+ kRnnSkipInput = 1,
+};
+
+// Specifies the number of directions used in an RNN model. When the
+// bidirectional mode is used, the input states and the output sequence
+// contain data for both directions.
+enum class RnnDirectionMode {
+ kRnnUnidirectional = 0,
+ kRnnBidirectional = 1,
+};
+
+// Specifies the descriptor for an RNN model.
+//
+// An example use case:
+// * The user first creates a model through createRnnDescriptor.
+// * The user queries the size of the underlying opaque parameter buffer.
+// * The user creates and initializes a parameter buffer of the proper size.
+// * The user runs forward and backward operations using this RNN descriptor.
+//   * Once in a while, the user queries the weight and bias regions from
+//     the underlying parameter buffer. These are more likely to be forward
+//     compatible and should be used when saving and restoring a model.
+// * The user releases the RNN descriptor when the model is no longer in use.
+class RnnDescriptor {
+ public:
+ struct ParamsRegion {
+ int64 offset;
+ int64 size;
+ };
+ typedef std::vector<ParamsRegion> ParamsRegions;
+ virtual ~RnnDescriptor() {}
+ virtual int64 ParamsSizeInBytes() const { return -1; }
+ virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); }
+ virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
+};
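The ParamsRegions accessors expose where the stable weight and bias blocks live inside the otherwise opaque parameter buffer. A sketch of the save step from the use case above, assuming rnn_desc is an already-created descriptor; params_buffer and SaveRegion are hypothetical:

    // Total size of the opaque parameter buffer; -1 means unknown.
    int64 params_bytes = rnn_desc->ParamsSizeInBytes();

    // Each region is an (offset, size) pair into the parameter buffer;
    // these are the forward-compatible pieces to save and restore.
    for (const dnn::RnnDescriptor::ParamsRegion& region :
         rnn_desc->ParamsWeightRegions()) {
      SaveRegion(params_buffer, region.offset, region.size);
    }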
+
+// Specifies the input or output sequence in an RNN model.
+//
+// The user is responsible for releasing this descriptor when it is no longer
+// in use. The destructor releases the underlying descriptors.
+class RnnSequenceTensorDescriptor {
+ public:
+ virtual ~RnnSequenceTensorDescriptor() {}
+};
+
+// Specifies either the input or hidden state in an RNN model.
+//
+// The user is responsible for releasing this descriptor when it is no longer
+// in use. The destructor releases the underlying descriptors.
+class RnnStateTensorDescriptor {
+ public:
+ virtual ~RnnStateTensorDescriptor() {}
+};
+
// Returns a string representation of the given quantization mode.
string QuantizedActivationModeString(QuantizedActivationMode mode);
@@ -1260,6 +1350,179 @@ class DnnSupport {
QuantizedActivationMode mode,
DeviceMemory<float>* gpu_unquantized_dst) = 0;
+
+ // Create an RNN descriptor based on model shapes and configurations.
+  // The caller takes ownership of the returned descriptor.
+ //
+ // Arguments:
+  //  num_layers: the number of layers for an RNN model.
+  //  hidden_size: the size of the hidden state.
+  //  input_size: the size of the input state.
+  //  input_mode: an enum to specify whether a linear transformation is added
+  //    after the input state. If input_size is different from hidden_size,
+  //    the linear transformation is required.
+ // direction_mode: an enum to specify whether this model is unidirectional or
+ // bidirectional.
+ // rnn_mode: an enum to specify the type of model to build.
+ // data_type: an enum to specify the data types used in this model.
+  //  dropout: the dropout threshold between layers. When it is 0.0, no
+  //    dropout is applied.
+  //  seed: a seed for initializing the dropout layers.
+  //  state_allocator: a memory allocator that will be used to store the
+  //    state for the dropout layer. The caller has to keep the memory alive
+  //    until the model is no longer in use.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
+ createRnnDescriptor(int num_layers, int hidden_size, int input_size,
+ dnn::RnnInputMode input_mode,
+ dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnDescriptor is unimplemented"};
+ }
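A sketch of calling this factory and unpacking the result, assuming a dnn pointer to a DnnSupport implementation, a state_allocator, and the usual StatusOr accessors; the base class shown here only returns UNIMPLEMENTED, so a usable descriptor requires a backend override (e.g. cuDNN):

    auto rnn_desc_or = dnn->createRnnDescriptor(
        /*num_layers=*/1, /*hidden_size=*/128, /*input_size=*/128,
        dnn::RnnInputMode::kRnnLinearSkip,
        dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
        dnn::DataType::kFloat, /*dropout=*/0.0f, /*seed=*/0, state_allocator);
    if (!rnn_desc_or.ok()) {
      return rnn_desc_or.status();  // UNIMPLEMENTED on the base class.
    }
    std::unique_ptr<dnn::RnnDescriptor> rnn_desc =
        rnn_desc_or.ConsumeValueOrDie();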
+
+  // Create an RNN sequence descriptor that specifies either the input or
+  // output sequence. The caller takes ownership of the returned descriptor.
+ //
+ // Arguments:
+ // seq_length: the length of the sequence.
+ // batch_size: the size of a minibatch.
+ // data_size: the size of the state.
+ // data_type: an enum to specify the type for the underlying data.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
+ createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
+ int data_size, dnn::DataType data_type) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnSequenceTensorDescriptor is unimplemented"};
+ }
+
+  // Create an RNN state descriptor that specifies the input or hidden state.
+  // The caller takes ownership of the returned descriptor.
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
+ createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
+ dnn::DataType data_type) {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "createRnnStateTensorDescriptor is unimplemented"};
+ }
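For an LSTM the caller would typically build one sequence descriptor (reusable for input and output when their shapes match) plus "h" and "c" state descriptors. A sketch under the same assumptions as above, with error handling elided:

    auto seq_desc = dnn->createRnnSequenceTensorDescriptor(
                           /*seq_length=*/20, /*batch_size=*/32,
                           /*data_size=*/128, dnn::DataType::kFloat)
                        .ConsumeValueOrDie();
    auto h_desc = dnn->createRnnStateTensorDescriptor(
                         /*num_layer=*/1, /*batch_size=*/32,
                         /*data_size=*/128, dnn::DataType::kFloat)
                      .ConsumeValueOrDie();
    // LSTM models also need a "c" state descriptor of the same shape.
    auto c_desc = dnn->createRnnStateTensorDescriptor(
                         /*num_layer=*/1, /*batch_size=*/32,
                         /*data_size=*/128, dnn::DataType::kFloat)
                      .ConsumeValueOrDie();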
+
+ // Enqueue a forward operation of the RNN model onto the stream.
+ //
+ // Arguments:
+  //  stream: pointer to the stream where this operation should be enqueued.
+  //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
+ // input_desc: descriptor for the input sequence.
+ // input_data: the device memory region that contains the input data.
+ // input_h_desc: descriptor for the input "h" state.
+ // input_h_data: the device memory region that contains the input "h" data.
+ // input_c_desc: descriptor for the input "c" state.
+ // input_c_data: the device memory region that contains the input "c" data.
+ // This must be specified for LSTM models.
+ // params: the device memory region that contains the parameters used in this
+ // model.
+ // output_desc: descriptor for the output sequence.
+ // output_data: the memory region that stores the output sequence data.
+ // output_h_desc: descriptor for the output "h" state.
+ // output_h_data: the memory region that stores the output "h" data.
+ // output_c_desc: descriptor for the output "c" state.
+  //  output_c_data: the memory region that stores the output "c" data. This
+  //    must be specified for LSTM models.
+  //  is_training: whether this is used in training or inference. This
+  //    decides whether reserve_space data needs to be produced.
+  //  reserve_space_allocator: if "is_training" is true, a memory allocator
+  //    to create memory that holds the produced reserve_space. The caller
+  //    retains the data and feeds it to the backward pass.
+  //  workspace_allocator: an allocator to create temporary workspace used in
+  //    this kernel. The caller is responsible for keeping the memory alive
+  //    for the lifespan of this operation, and recycles it afterwards.
+ virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<float>* output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<float>* output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<float>* output_c_data,
+ bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
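A sketch of enqueuing the forward pass with the descriptors above; the stream, device memory buffers, and allocators are assumed to already exist:

    bool launched = dnn->DoRnnForward(
        stream, *rnn_desc, *seq_desc, input_data, *h_desc, input_h_data,
        *c_desc, input_c_data, params, *seq_desc, &output_data, *h_desc,
        &output_h_data, *c_desc, &output_c_data,
        /*is_training=*/true, reserve_space_allocator, workspace_allocator);
    // A false return means the backend does not implement the fused RNN
    // path; when training, the reserve space captured by
    // reserve_space_allocator must be kept alive for DoRnnBackward.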
+
+ // Enqueue a backward operation of the RNN model onto the stream.
+ //
+ // Arguments:
+  //  stream: pointer to the stream where this operation should be enqueued.
+  //  rnn_desc: an RNN descriptor created by createRnnDescriptor.
+ // input_desc: descriptor for the input sequence.
+ // input_data: the device memory region that contains the input data.
+ // input_h_desc: descriptor for the input "h" state.
+ // input_h_data: the device memory region that contains the input "h" data.
+ // input_c_desc: descriptor for the input "c" state.
+ // input_c_data: the device memory region that contains the input "c" data.
+ // This must be specified for LSTM models.
+ // params: the device memory region that contains the parameters used in this
+ // model.
+ // output_desc: descriptor for the output sequence.
+ // output_data: the memory region that stores the output sequence data.
+ // output_h_desc: descriptor for the output "h" state.
+ // output_h_data: the memory region that stores the output "h" data.
+ // output_c_desc: descriptor for the output "c" state.
+  //  output_c_data: the memory region that stores the output "c" data. This
+  //    must be specified for LSTM models.
+ // output_backprop_data: the device memory region that contains the backprop
+ // to the output sequence.
+ // output_h_backprop_data: the device memory region that contains the
+ // backprop to the output "h" state.
+ // output_c_backprop_data: the device memory region that contains the
+ // backprop to the output "c" state.
+ // input_backprop_data: the device memory region that stores the backprop
+ // to the input sequence.
+ // input_h_backprop_data: the device memory region that stores the backprop
+ // to the input "h" state.
+ // input_c_backprop_data: the device memory region that stores the backprop
+ // to the input "c" state.
+ // params_backprop_data: the device memory region that stores the backprop
+ // to the parameters.
+ // reserve_space_data: the reserve_space data that is produced by the forward
+ // operation. This memory region could be modified by this operation.
+  //  workspace_allocator: a memory allocator that creates the temporary
+  //    workspace memory used by this operation. The caller is responsible
+  //    for keeping the memory alive long enough for this operation, and
+  //    recycles it afterwards.
+ virtual bool DoRnnBackward(
+ Stream* stream, const dnn::RnnDescriptor& rnn_desc,
+ const dnn::RnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<float>& input_data,
+ const dnn::RnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<float>& input_h_data,
+ const dnn::RnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<float>& input_c_data,
+ const DeviceMemory<float>& params,
+ const dnn::RnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<float>& output_data,
+ const dnn::RnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<float>& output_h_data,
+ const dnn::RnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<float>& output_c_data,
+ const DeviceMemory<float>& output_backprop_data,
+ const DeviceMemory<float>& output_h_backprop_data,
+ const DeviceMemory<float>& output_c_backprop_data,
+ DeviceMemory<float>* input_backprop_data,
+ DeviceMemory<float>* input_h_backprop_data,
+ DeviceMemory<float>* input_c_backprop_data,
+ DeviceMemory<float>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator) {
+ return false;
+ }
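The backward pass mirrors the forward signature, appending the backprop buffers and the reserve space produced during training. A sketch, with every buffer assumed to exist and reserve_space holding the bytes captured by the forward pass:

    bool ok = dnn->DoRnnBackward(
        stream, *rnn_desc, *seq_desc, input_data, *h_desc, input_h_data,
        *c_desc, input_c_data, params, *seq_desc, output_data, *h_desc,
        output_h_data, *c_desc, output_c_data, output_backprop,
        output_h_backprop, output_c_backprop, &input_backprop,
        &input_h_backprop, &input_c_backprop, &params_backprop,
        &reserve_space, workspace_allocator);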
+
private:
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};
@@ -1269,3 +1532,4 @@ class DnnSupport {
} // namespace perftools
#endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_
+