path: root/tensorflow/contrib/cudnn_rnn
author James Qin <jamesqin@google.com> 2017-08-30 17:36:00 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-08-30 17:39:43 -0700
commit 4807158b3df7fdfd1a597c581e7ba9f894b1b256 (patch)
tree 009b56fcb7c8751a6e00ab31b66c598fd812747c /tensorflow/contrib/cudnn_rnn
parent d22ed610a0c7395a0e3ea8349b0da0e8afa62fe1 (diff)
Add functional cudnn_rnn_ops
Will deprecate the class-style cudnn_rnn_ops soon. As a transition plan, the class-style ops now depend on the new functional ops.
PiperOrigin-RevId: 167075637
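As a rough illustration (not part of the commit), the new functional ops might be driven as in the sketch below. The import path follows the file touched in this change, the shapes are made up, and a CUDA-enabled build is assumed since these ops only have GPU kernels.

import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

# Hypothetical sizes, chosen only for illustration.
time_len, batch_size, num_layers, num_units, input_size = 10, 4, 1, 32, 32

# Size of the opaque cudnn parameter buffer (a scalar int32 tensor).
params_size_t = cudnn_rnn_ops.cudnn_opaque_params_size(
    rnn_mode='lstm', num_layers=num_layers, num_units=num_units,
    input_size=input_size)
# The buffer shape is only known at run time, hence validate_shape=False.
params = tf.Variable(
    tf.random_uniform([params_size_t]), validate_shape=False)

inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
input_h = tf.zeros([num_layers, batch_size, num_units])
input_c = tf.zeros([num_layers, batch_size, num_units])  # LSTM only

outputs, output_h, output_c = cudnn_rnn_ops.cudnn_lstm(
    inputs, input_h, input_c, params, is_training=True)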
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 555
1 file changed, 509 insertions, 46 deletions
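The commit also adds functional conversions between the opaque buffer and the per-layer canonical weights and biases, which the class methods in the diff below now delegate to. A minimal sketch, reusing the import and the `params` buffer from the sketch above (sizes are again made up):

from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

# Split the opaque buffer into per-layer weight and bias tensors ...
weights, biases = cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
    rnn_mode='lstm', num_layers=1, num_units=32, input_size=32,
    params=params)
# ... and pack them back into an opaque cudnn buffer.
params_again = cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
    rnn_mode='lstm', num_layers=1, num_units=32, input_size=32,
    weights=weights, biases=biases)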
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 694bd507d9..bc4fd10cac 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -716,6 +716,482 @@ _cudnn_rnn_common_doc_string = """
"""
+def _check_direction(direction):
+ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
+ raise ValueError("Invalid direction: %s, expect %s or %s" %
+ (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION))
+
+
+def _check_rnn_mode(rnn_mode):
+ if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU):
+ raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" %
+ (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH,
+ CUDNN_RNN_RELU))
+
+
+def _get_seed(seed):
+ seed, seed2 = random_seed.get_seed(seed)
+ if seed is None and seed2 is None:
+ seed, seed2 = 0, 0
+ return seed, seed2
+
+
+def _get_num_params(rnn_mode, num_layers, direction):
+ """Return num params for given Cudnn config."""
+ if rnn_mode == CUDNN_LSTM:
+ num_params_per_layer = 8
+ elif rnn_mode == CUDNN_GRU:
+ num_params_per_layer = 6
+ elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH):
+ num_params_per_layer = 2
+ else:
+ raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode)
+ num_params = num_layers * num_params_per_layer
+ if direction != CUDNN_RNN_UNIDIRECTION:
+ num_params *= 2
+ return num_params
+
+
+def _cudnn_rnn(inputs,
+ input_h,
+ input_c,
+ params,
+ is_training,
+ rnn_mode,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn RNN.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ input_c: the initial hidden state for c. This is only relevant for LSTM.
+ A Tensor of the same shape as input_h.
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h, output_c
+ """
+ _check_rnn_mode(rnn_mode)
+ _check_direction(direction)
+ seed, seed2 = random_seed.get_seed(seed)
+ outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
+ input=inputs,
+ input_h=input_h,
+ input_c=input_c,
+ params=params,
+ is_training=is_training,
+ rnn_mode=rnn_mode,
+ input_mode=input_mode,
+ direction=direction,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2,
+ name=name)
+ return (outputs, output_h, output_c)
+
+
+def cudnn_lstm(inputs,
+ input_h,
+ input_c,
+ params,
+ is_training,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn LSTM.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ input_c: the initial hidden state for c. This is only relevant for LSTM.
+ A Tensor of the same shape as input_h.
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h, output_c
+ """
+ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
+ input_mode, direction, dropout, seed, name)
+
+
+def _cudnn_rnn_no_input_c(inputs,
+ input_h,
+ params,
+ is_training,
+ rnn_mode,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn RNN w/o input_c.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h
+ """
+ input_c = array_ops.constant([], dtype=input_h.dtype)
+ outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
+ is_training, rnn_mode, input_mode,
+ direction, dropout, seed, name)
+ return outputs, output_h
+
+
+def cudnn_gru(inputs,
+ input_h,
+ params,
+ is_training,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn GRU.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h
+ """
+ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
+ input_mode, direction, dropout, seed, name)
+
+
+def cudnn_rnn_relu(inputs,
+ input_h,
+ params,
+ is_training,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn RNN Relu.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h
+ """
+ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
+ CUDNN_RNN_RELU, input_mode, direction, dropout,
+ seed, name)
+
+
+def cudnn_rnn_tanh(inputs,
+ input_h,
+ params,
+ is_training,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0.,
+ seed=0,
+ name=None):
+ """Cudnn RNN Tanh.
+
+ Args:
+ inputs: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ params: the parameter buffer created for this model.
+ is_training: whether this operation will be used in training or inference.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ outputs, output_h
+ """
+ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
+ CUDNN_RNN_TANH, input_mode, direction, dropout,
+ seed, name)
+
+
+def cudnn_rnn_params_to_canonical(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ params,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0,
+ seed=0,
+ name=None):
+ """Convert cudnn opaque params to canonical.
+
+ Args:
+ rnn_mode: a string that specifies the mode under which this RNN model runs.
+ Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+ input_size: the size of the input; it can be different from
+ num_units.
+ params: opaque cudnn params var.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ weights list and bias list
+ Raises:
+ ValueError: if rnn_mode or direction is invalid.
+ """
+
+ _check_rnn_mode(rnn_mode)
+ _check_direction(direction)
+ num_params = _get_num_params(rnn_mode, num_layers, direction)
+ seed, seed2 = random_seed.get_seed(seed)
+ weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
+ rnn_mode=rnn_mode,
+ num_layers=num_layers,
+ num_units=num_units,
+ input_size=input_size,
+ params=params,
+ input_mode=input_mode,
+ direction=direction,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2,
+ num_params=num_params,
+ name=name)
+ return weights, biases
+
+
+def cudnn_rnn_canonical_to_params(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ weights,
+ biases,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dropout=0,
+ seed=0,
+ name=None):
+ """Converts params from the canonical format to a specific format of cuDNN.
+
+ Args:
+ rnn_mode: a string that specifies the mode under which this RNN model runs.
+ Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+ input_size: the size of the input; it can be different from
+ num_units.
+ weights: a Tensor for weight parameters.
+ biases: a Tensor for bias parameters.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ an opaque Cudnn param.
+ Raises:
+ ValueError: if rnn_mode or direction is invalid.
+ """
+ _check_rnn_mode(rnn_mode)
+ _check_direction(direction)
+ seed, seed2 = random_seed.get_seed(seed)
+ return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+ rnn_mode=rnn_mode,
+ num_layers=num_layers,
+ num_units=num_units,
+ input_size=input_size,
+ weights=weights,
+ biases=biases,
+ input_mode=input_mode,
+ direction=direction,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2,
+ name=name)
+
+
+def cudnn_opaque_params_size(rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode=CUDNN_INPUT_LINEAR_MODE,
+ direction=CUDNN_RNN_UNIDIRECTION,
+ dtype=dtypes.float32,
+ dropout=0,
+ seed=0,
+ name=None):
+ """Returns opaque params size for specific Cudnn config.
+
+ Args:
+ rnn_mode: a string that specifies the mode under which this RNN model runs.
+ Can be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'.
+ num_layers: the number of layers for the RNN model.
+ num_units: the number of units within the RNN model.
+ input_size: the size of the input; it can be different from
+ num_units.
+ input_mode: indicate whether there is a linear projection between the
+ input and the actual computation before the first layer. It could be
+ 'linear_input', 'skip_input' or 'auto_select'.
+ 'linear_input' (default) always applies a linear projection of input
+ onto RNN hidden state. (standard RNN behavior).
+ 'skip_input' is only allowed when input_size == num_units;
+ 'auto_select' implies 'skip_input' when input_size == num_units;
+ otherwise, it implies 'linear_input'.
+ direction: the direction model that the RNN operates in. Can be either
+ 'unidirectional' or 'bidirectional'.
+ dtype: one of tf.float32 or tf.float64.
+ dropout: whether to enable dropout. When it is 0, dropout is disabled.
+ seed: the op seed used for initializing dropout. See @{tf.set_random_seed}
+ for behavior.
+ name: name of the operation.
+ Returns:
+ an int, the size of the Cudnn opaque params.
+ Raises:
+ ValueError: if rnn_mode or direction is invalid.
+ """
+ _check_rnn_mode(rnn_mode)
+ _check_direction(direction)
+ seed, seed2 = random_seed.get_seed(seed)
+ return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
+ rnn_mode=rnn_mode,
+ num_layers=num_layers,
+ num_units=num_units,
+ input_size=input_size,
+ T=dtype,
+ S=dtypes.int32,
+ dropout=dropout,
+ seed=seed,
+ seed2=seed2,
+ input_mode=input_mode,
+ direction=direction,
+ name=name)[0]
+
+
class _CudnnRNN(object):
"""Creates an RNN model using the underlying Cudnn implementation.
@@ -761,9 +1237,6 @@ class _CudnnRNN(object):
Raises:
ValueError: if direction is invalid.
"""
- if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION):
- raise ValueError("Invalid direction: %s, expect %s or %s",
- direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)
self._num_layers = num_layers
self._num_units = num_units
self._input_size = input_size
@@ -772,10 +1245,7 @@ class _CudnnRNN(object):
self._direction = direction
self._dtype = dtype
self._dropout = dropout
- # get graph and op seed.
- self._seed, self._seed2 = random_seed.get_seed(seed)
- if self._seed is None and self._seed2 is None:
- self._seed, self._seed2 = 0, 0
+ self._seed = seed
@property
def input_mode(self):
@@ -807,18 +1277,16 @@ class _CudnnRNN(object):
Returns:
The calculated parameter buffer size.
"""
- return gen_cudnn_rnn_ops.cudnn_rnn_params_size(
+ return cudnn_opaque_params_size(
+ rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
- T=self._dtype,
- S=dtypes.int32,
+ dtype=self._dtype,
dropout=self._dropout,
seed=self._seed,
- seed2=self._seed2,
- rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
- direction=self._direction)[0]
+ direction=self._direction)
def __call__(self, input_data, input_h, input_c, params, is_training=True):
"""Runs the forward step for the RNN model.
@@ -837,22 +1305,17 @@ class _CudnnRNN(object):
output_h: the final state for h.
output_c: the final state for c. This is only relevant for LSTM.
"""
- if self._rnn_mode != CUDNN_LSTM:
- # For model that doesn't take input_c, replace with a dummy tensor.
- input_c = array_ops.constant([], dtype=self._dtype)
- output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
- input=input_data,
- input_h=input_h,
- input_c=input_c,
- params=params,
- rnn_mode=self._rnn_mode,
+ return _cudnn_rnn(
+ input_data,
+ input_h,
+ input_c,
+ params,
+ is_training,
+ self._rnn_mode,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
- seed=self._seed,
- seed2=self._seed2,
- is_training=is_training)
- return (output, output_h, output_c)
+ seed=self._seed)
def params_to_canonical(self, params):
"""Converts params from a specific format of cuDNN to the canonical format.
@@ -863,22 +1326,16 @@ class _CudnnRNN(object):
Returns:
A function for the specific-to-canonical conversion.
"""
- num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER
- if self._direction != CUDNN_RNN_UNIDIRECTION:
- num_params *= 2
- weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
+ return cudnn_rnn_params_to_canonical(
+ rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
params=params,
- dropout=self._dropout,
- seed=self._seed,
- seed2=self._seed2,
- num_params=num_params,
- rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
- direction=self._direction)
- return weights, biases
+ direction=self._direction,
+ dropout=self._dropout,
+ seed=self._seed)
def canonical_to_params(self, weights, biases):
"""Converts params from the canonical format to a specific format of cuDNN.
@@ -890,18 +1347,17 @@ class _CudnnRNN(object):
Returns:
A function for the canonical-to-params-to-specific conversion.
"""
- return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+ return cudnn_rnn_canonical_to_params(
+ rnn_mode=self._rnn_mode,
num_layers=self._num_layers,
num_units=self._num_units,
input_size=self._input_size,
weights=weights,
biases=biases,
- dropout=self._dropout,
- seed=self._seed,
- seed2=self._seed2,
- rnn_mode=self._rnn_mode,
input_mode=self._input_mode,
- direction=self._direction)
+ direction=self._direction,
+ dropout=self._dropout,
+ seed=self._seed)
class CudnnLSTM(_CudnnRNN):
@@ -1036,9 +1492,16 @@ class _CudnnRNNNoInputC(_CudnnRNN):
output: the output sequence.
output_h: the final state for h.
"""
- output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__(
- input_data, input_h, None, params, is_training=is_training)
- return (output, output_h)
+ return _cudnn_rnn_no_input_c(
+ input_data,
+ input_h,
+ params,
+ is_training,
+ self._rnn_mode,
+ input_mode=self._input_mode,
+ direction=self._direction,
+ dropout=self._dropout,
+ seed=self._seed)
class CudnnGRU(_CudnnRNNNoInputC):