diff options
author | 2016-08-19 10:01:32 -0800 | |
---|---|---|
committer | 2016-08-19 11:18:09 -0700 | |
commit | 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (patch) | |
tree | ae0c4e4e690d53b92720049fb75acc119e4ea5e0 /tensorflow/stream_executor/dnn.h | |
parent | de8838042fb34e53a511f25b4613611fc368beeb (diff) |
Add stream-executor changes to enable Cudnn fused LSTM/RNN support.
Change: 130770287
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 264 |
1 files changed, 264 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index c2310c8938..1c31178526 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -23,10 +23,12 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_DNN_H_ #include <limits> +#include <memory> #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" @@ -76,6 +78,94 @@ enum class QuantizedActivationMode { k32Bit = 4, }; +// Specifies the data type used by an operation. +enum class DataType { + kFloat = 0, + kDouble = 1, + kHalf = 2, +}; + +// A helper class to convert C/C++ types to the proper enums. +template <typename T> +struct ToDataType; +template <> +struct ToDataType<float> { + static constexpr DataType value = DataType::kFloat; +}; +template <> +struct ToDataType<double> { + static constexpr DataType value = DataType::kDouble; +}; +template <> +struct ToDataType<Eigen::half> { + static constexpr DataType value = DataType::kHalf; +}; + +// Specifies the types of a RNN model. +enum class RnnMode { + kRnnRelu = 0, + kRnnTanh = 1, + kRnnLstm = 2, + kRnnGru = 3, +}; + +// Specifies the input model and whether there is a linear transformation +// between the input state and the first layer hidden state. +enum class RnnInputMode { + kRnnLinearSkip = 0, + kRnnSkipInput = 1, +}; + +// Specifies the number of directions used in a RNN model. When bidirection +// is used, the input states and output sequence contain data for both +// directions. +enum class RnnDirectionMode { + kRnnUnidirectional = 0, + kRnnBidirectional = 1, +}; + +// Specifies the descriptor for a RNN model. +// +// An example use case: +// * The user first creates a model through createRnnDescriptor. 
+// * The user queries the size of the underlying opaque parameter buffer. +// * The user creates and initializes a parameter buffer of the proper size. +// * The user runs forward and backward operations using this RNN descriptor. +// * Once in a while, the user queries maintainable weights and bias regions from +// the underlying parameter buffer. They are more likely to be forward +// compatible and should be used in saving and restoring a model. +// * The user releases the RNN descriptor when the model is no longer in use. +class RnnDescriptor { + public: + struct ParamsRegion { + int64 offset; + int64 size; + }; + typedef std::vector<ParamsRegion> ParamsRegions; + virtual ~RnnDescriptor() {} + virtual int64 ParamsSizeInBytes() const { return -1; } + virtual ParamsRegions ParamsWeightRegions() const { return ParamsRegions(); } + virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); } +}; + +// Specifies the sequence in a RNN model. +// +// The user is responsible for releasing this descriptor when it is no longer +// in use. The destructor releases the underlying descriptors. +class RnnSequenceTensorDescriptor { + public: + virtual ~RnnSequenceTensorDescriptor() {} +}; + +// Specifies either the input or hidden state in a RNN model. +// +// The user is responsible for releasing this descriptor when it is no longer +// in use. The destructor releases the underlying descriptors. +class RnnStateTensorDescriptor { + public: + virtual ~RnnStateTensorDescriptor() {} +}; + // Returns a string representation of the given quantization mode. string QuantizedActivationModeString(QuantizedActivationMode mode); @@ -1260,6 +1350,179 @@ class DnnSupport { QuantizedActivationMode mode, DeviceMemory<float>* gpu_unquantized_dst) = 0; + + // Create an RNN descriptor based on model shapes and configurations. + // The caller retains the ownership of the descriptor. + // + // Arguments: + // num_layers: the number of layers for a RNN model. 
+ // hidden_size: the size of the hidden state. + // input_size: the size of the input state. + // input_mode: an enum to specify whether a linear transformation is added + // after the input state. If input_size is different from hidden_size, this + // is required. + // direction_mode: an enum to specify whether this model is unidirectional or + // bidirectional. + // rnn_mode: an enum to specify the type of model to build. + // data_type: an enum to specify the data types used in this model. + // dropout: the dropout threshold between layers. When it is 0., no dropout + // is added. + // seed: a seed for initializing the dropout layers. + // state_allocator: a memory allocator that will be used to store the state + // for the dropout layer. The user has to maintain the memory until the model + // is no longer in use. + virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> + createRnnDescriptor(int num_layers, int hidden_size, int input_size, + dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, + dnn::RnnMode rnn_mode, dnn::DataType data_type, + float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnDescriptor is unimplemented"}; + } + + // Create a RNN sequence descriptor that specifies either the input or output + // sequence. The caller retains the ownership of the returned descriptor. + // + // Arguments: + // seq_length: the length of the sequence. + // batch_size: the size of a minibatch. + // data_size: the size of the state. + // data_type: an enum to specify the type for the underlying data. 
+ virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> + createRnnSequenceTensorDescriptor(int seq_length, int batch_size, + int data_size, dnn::DataType data_type) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnSequenceTensorDescriptor is unimplemented"}; + } + + // Create an RNN state descriptor that specifies the input or hidden state. + // The caller retains the ownership of the returned descriptor. + virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> + createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, + dnn::DataType data_type) { + return port::Status{port::error::UNIMPLEMENTED, + "createRnnStateTensorDescriptor is unimplemented"}; + } + + // Enqueue a forward operation of the RNN model onto the stream. + // + // Arguments: + // stream: pointer to the stream where this operation should be enqueued to. + // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // input_desc: descriptor for the input sequence. + // input_data: the device memory region that contains the input data. + // input_h_desc: descriptor for the input "h" state. + // input_h_data: the device memory region that contains the input "h" data. + // input_c_desc: descriptor for the input "c" state. + // input_c_data: the device memory region that contains the input "c" data. + // This must be specified for LSTM models. + // params: the device memory region that contains the parameters used in this + // model. + // output_desc: descriptor for the output sequence. + // output_data: the memory region that stores the output sequence data. + // output_h_desc: descriptor for the output "h" state. + // output_h_data: the memory region that stores the output "h" data. + // output_c_desc: descriptor for the output "c" state. + // output_c_data: the memory region that stores the output "c" data. This + // must be specified for LSTM models. + // is_training: whether this is used in training or inference. 
That decides + // whether reserve_space data needs to be produced. + // reserve_space_allocator: if "is_training" is true, a memory allocator + // to create memory that holds the produced reserve_space. The caller + // retains the data and feeds it to the backward pass. + // workspace_allocator: an allocator to create temporary workspace used in + // this kernel. The caller is responsible for retaining the memory long + // enough for the lifespan of this operation, and recycles afterwards. + virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + DeviceMemory<float>* output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + DeviceMemory<float>* output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + DeviceMemory<float>* output_c_data, + bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { + return false; + } + + // Enqueue a backward operation of the RNN model onto the stream. + // + // Arguments: + // stream: pointer to the stream where this operation should be enqueued to. + // rnn_desc: a RNN descriptor created by createRnnDescriptor. + // input_desc: descriptor for the input sequence. + // input_data: the device memory region that contains the input data. + // input_h_desc: descriptor for the input "h" state. + // input_h_data: the device memory region that contains the input "h" data. + // input_c_desc: descriptor for the input "c" state. + // input_c_data: the device memory region that contains the input "c" data. + // This must be specified for LSTM models. 
+ // params: the device memory region that contains the parameters used in this + // model. + // output_desc: descriptor for the output sequence. + // output_data: the memory region that stores the output sequence data. + // output_h_desc: descriptor for the output "h" state. + // output_h_data: the memory region that stores the output "h" data. + // output_c_desc: descriptor for the output "c" state. + // output_c_data: the memory region that stores the output "c" data. This + // must be specified for LSTM models. + // output_backprop_data: the device memory region that contains the backprop + // to the output sequence. + // output_h_backprop_data: the device memory region that contains the + // backprop to the output "h" state. + // output_c_backprop_data: the device memory region that contains the + // backprop to the output "c" state. + // input_backprop_data: the device memory region that stores the backprop + // to the input sequence. + // input_h_backprop_data: the device memory region that stores the backprop + // to the input "h" state. + // input_c_backprop_data: the device memory region that stores the backprop + // to the input "c" state. + // params_backprop_data: the device memory region that stores the backprop + // to the parameters. + // reserve_space_data: the reserve_space data that is produced by the forward + // operation. This memory region could be modified by this operation. + // workspace_allocator: a memory allocator that creates the temporary + // workspace memory used by this operation. The caller is responsible for + // keeping the memory alive long enough for this operation, and recycles + // afterwards. 
+ virtual bool DoRnnBackward( + Stream* stream, const dnn::RnnDescriptor& rnn_desc, + const dnn::RnnSequenceTensorDescriptor& input_desc, + const DeviceMemory<float>& input_data, + const dnn::RnnStateTensorDescriptor& input_h_desc, + const DeviceMemory<float>& input_h_data, + const dnn::RnnStateTensorDescriptor& input_c_desc, + const DeviceMemory<float>& input_c_data, + const DeviceMemory<float>& params, + const dnn::RnnSequenceTensorDescriptor& output_desc, + const DeviceMemory<float>& output_data, + const dnn::RnnStateTensorDescriptor& output_h_desc, + const DeviceMemory<float>& output_h_data, + const dnn::RnnStateTensorDescriptor& output_c_desc, + const DeviceMemory<float>& output_c_data, + const DeviceMemory<float>& output_backprop_data, + const DeviceMemory<float>& output_h_backprop_data, + const DeviceMemory<float>& output_c_backprop_data, + DeviceMemory<float>* input_backprop_data, + DeviceMemory<float>* input_h_backprop_data, + DeviceMemory<float>* input_c_backprop_data, + DeviceMemory<float>* params_backprop_data, + DeviceMemory<uint8>* reserve_space_data, + ScratchAllocator* workspace_allocator) { + return false; + } + private: SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport); }; @@ -1269,3 +1532,4 @@ class DnnSupport { } // namespace perftools #endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_ + |