# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN Cells.

This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib
import numbers

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"


def assert_like_rnncell(cell_name, cell):
  """Raises a TypeError if cell is not like an RNNCell.

  NOTE: Do not rely on the error message (in particular in tests) which can be
  subject to change to increase readability.  Use
  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.

  Args:
    cell_name: A string to give a meaningful error referencing the name of the
      function argument.
    cell: The object which should behave like an RNNCell.

  Raises:
    TypeError: A human-friendly exception.
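# --- Illustrative sketch (editor addition, not part of the original file) ----
# `assert_like_rnncell` duck-types its argument instead of checking
# `isinstance`: any object exposing `output_size`, `state_size`, either
# `zero_state` or `get_initial_state`, and `__call__` is accepted.  The helper
# below is hypothetical and exists only to demonstrate that contract.


def _example_rnncell_duck_typing():
  """Shows a minimal object that passes the RNNCell duck-type check."""

  class MinimalCell(object):
    """Identity 'cell' with just enough attributes to look like an RNNCell."""

    output_size = 4
    state_size = 4

    def zero_state(self, batch_size, dtype):
      return array_ops.zeros([batch_size, self.state_size], dtype=dtype)

    def __call__(self, inputs, state):
      return inputs, state

  # Passes silently because all four duck-type conditions hold.
  assert_like_rnncell("cell", MinimalCell())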
""" conditions = [ hasattr(cell, "output_size"), hasattr(cell, "state_size"), hasattr(cell, "get_initial_state") or hasattr(cell, "zero_state"), callable(cell), ] errors = [ "'output_size' property is missing", "'state_size' property is missing", "either 'zero_state' or 'get_initial_state' method is required", "is not callable" ] if not all(conditions): errors = [error for error, cond in zip(errors, conditions) if not cond] raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format( cell_name, cell, ", ".join(errors))) def _concat(prefix, suffix, static=False): """Concat that enables int, Tensor, or TensorShape values. This function takes a size specification, which can be an integer, a TensorShape, or a Tensor, and converts it into a concatenated Tensor (if static = False) or a list of integers (if static = True). Args: prefix: The prefix; usually the batch size (and/or time step size). (TensorShape, int, or Tensor.) suffix: TensorShape, int, or Tensor. static: If `True`, return a python list with possibly unknown dimensions. Otherwise return a `Tensor`. Returns: shape: the concatenation of prefix and suffix. Raises: ValueError: if `suffix` is not a scalar or vector (or TensorShape). ValueError: if prefix or suffix was `None` and asked for dynamic Tensors out. """ if isinstance(prefix, ops.Tensor): p = prefix p_static = tensor_util.constant_value(prefix) if p.shape.ndims == 0: p = array_ops.expand_dims(p, 0) elif p.shape.ndims != 1: raise ValueError("prefix tensor must be either a scalar or vector, " "but saw tensor: %s" % p) else: p = tensor_shape.as_shape(prefix) p_static = p.as_list() if p.ndims is not None else None p = (constant_op.constant(p.as_list(), dtype=dtypes.int32) if p.is_fully_defined() else None) if isinstance(suffix, ops.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) if s.shape.ndims == 0: s = array_ops.expand_dims(s, 0) elif s.shape.ndims != 1: raise ValueError("suffix tensor must be either a scalar or vector, " "but saw tensor: %s" % s) else: s = tensor_shape.as_shape(suffix) s_static = s.as_list() if s.ndims is not None else None s = (constant_op.constant(s.as_list(), dtype=dtypes.int32) if s.is_fully_defined() else None) if static: shape = tensor_shape.as_shape(p_static).concatenate(s_static) shape = shape.as_list() if shape.ndims is not None else None else: if p is None or s is None: raise ValueError("Provided a prefix or suffix of None: %s and %s" % (prefix, suffix)) shape = array_ops.concat((p, s), 0) return shape def _zero_state_tensors(state_size, batch_size, dtype): """Create tensors of zeros based on state_size, batch_size, and dtype.""" def get_state_shape(s): """Combine s with batch_size to get a proper tensor shape.""" c = _concat(batch_size, s) size = array_ops.zeros(c, dtype=dtype) if not context.executing_eagerly(): c_static = _concat(batch_size, s, static=True) size.set_shape(c_static) return size return nest.map_structure(get_state_shape, state_size) @tf_export("nn.rnn_cell.RNNCell") class RNNCell(base_layer.Layer): """Abstract object representing an RNN cell. Every `RNNCell` must have the properties below and implement `call` with the signature `(output, next_state) = call(input, state)`. The optional third input argument, `scope`, is allowed for backwards compatibility purposes; but should be left off for new subclasses. This definition of cell differs from the definition used in the literature. In the literature, 'cell' refers to an object with a single scalar output. 
This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns.  If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
  """

  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    super(RNNCell, self).__init__(
        trainable=trainable, name=name, dtype=dtype, **kwargs)
    # Attribute that indicates whether the cell is a TF RNN cell, due to the
    # slight difference between TF and Keras RNN cells.
    self._is_tf_rnn_cell = True

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`.  Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple
        with shapes `[batch_size, s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    if scope is not None:
      with vs.variable_scope(
          scope, custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      scope_attrname = "rnncell_scope"
      scope = getattr(self, scope_attrname, None)
      if scope is None:
        scope = vs.variable_scope(
            vs.get_variable_scope(), custom_getter=self._rnn_get_variable)
        setattr(self, scope_attrname, scope)
      with scope:
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if context.executing_eagerly():
      trainable = variable._trainable  # pylint: disable=protected-access
    else:
      trainable = (
          variable in tf_variables.trainable_variables() or
          (isinstance(variable, tf_variables.PartitionedVariable) and
           list(variable)[0] in tf_variables.trainable_variables()))
    if trainable and variable not in self._trainable_weights:
      self._trainable_weights.append(variable)
    elif not trainable and variable not in self._non_trainable_weights:
      self._non_trainable_weights.append(variable)
    return variable

  @property
  def state_size(self):
    """size(s) of state(s) used by this cell.

    It can be represented by an integer, a TensorShape, or a tuple of integers
    or TensorShapes.
    """
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    """Integer or TensorShape: size of outputs produced by this cell."""
    raise NotImplementedError("Abstract method")

  def build(self, _):
    # This tells the parent Layer object that it's OK to call
    # self.add_variable() inside the call() method.
    pass

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    if inputs is not None:
      # Validate the given batch_size and dtype against inputs if provided.
inputs = ops.convert_to_tensor(inputs, name="inputs") if batch_size is not None: if tensor_util.is_tensor(batch_size): static_batch_size = tensor_util.constant_value( batch_size, partial=True) else: static_batch_size = batch_size if inputs.shape[0].value != static_batch_size: raise ValueError( "batch size from input tensor is different from the " "input param. Input tensor batch: {}, batch_size: {}".format( inputs.shape[0].value, batch_size)) if dtype is not None and inputs.dtype != dtype: raise ValueError( "dtype from input tensor is different from the " "input param. Input tensor dtype: {}, dtype: {}".format( inputs.dtype, dtype)) batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] dtype = inputs.dtype if None in [batch_size, dtype]: raise ValueError( "batch_size and dtype cannot be None while constructing initial " "state: batch_size={}, dtype={}".format(batch_size, dtype)) return self.zero_state(batch_size, dtype) def zero_state(self, batch_size, dtype): """Return zero-filled state tensor(s). Args: batch_size: int, float, or unit Tensor representing the batch size. dtype: the data type to use for the state. Returns: If `state_size` is an int or TensorShape, then the return value is a `N-D` tensor of shape `[batch_size, state_size]` filled with zeros. If `state_size` is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of `2-D` tensors with the shapes `[batch_size, s]` for each s in `state_size`. """ # Try to use the last cached zero_state. This is done to avoid recreating # zeros, especially when eager execution is enabled. state_size = self.state_size is_eager = context.executing_eagerly() if is_eager and hasattr(self, "_last_zero_state"): (last_state_size, last_batch_size, last_dtype, last_output) = getattr(self, "_last_zero_state") if (last_batch_size == batch_size and last_dtype == dtype and last_state_size == state_size): return last_output with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): output = _zero_state_tensors(state_size, batch_size, dtype) if is_eager: self._last_zero_state = (state_size, batch_size, dtype, output) return output class LayerRNNCell(RNNCell): """Subclass of RNNCells that act like proper `tf.Layer` objects. For backwards compatibility purposes, most `RNNCell` instances allow their `call` methods to instantiate variables via `tf.get_variable`. The underlying variable scope thus keeps track of any variables, and returning cached versions. This is atypical of `tf.layer` objects, which separate this part of layer building into a `build` method that is only called once. Here we provide a subclass for `RNNCell` objects that act exactly as `Layer` objects do. They must provide a `build` method and their `call` methods do not access Variables `tf.get_variable`. """ def __call__(self, inputs, state, scope=None, *args, **kwargs): """Run this RNN cell on inputs, starting from the given state. Args: inputs: `2-D` tensor with shape `[batch_size, input_size]`. state: if `self.state_size` is an integer, this should be a `2-D Tensor` with shape `[batch_size, self.state_size]`. Otherwise, if `self.state_size` is a tuple of integers, this should be a tuple with shapes `[batch_size, s] for s in self.state_size`. scope: optional cell scope. *args: Additional positional arguments. **kwargs: Additional keyword arguments. Returns: A pair containing: - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`. 
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
    # Instead, it is up to subclasses to provide a proper build
    # method.  See the class docstring for more details.
    return base_layer.Layer.__call__(
        self, inputs, state, scope=scope, *args, **kwargs)


@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
  """The most basic RNN cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`.  It can also be a string
      that names a Keras activation function.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable` etc when constructing the cell from configs of get_config().
  """

  @deprecated(None, "This class is equivalent to tf.keras.layers.SimpleRNNCell,"
              " and will be replaced by that in TensorFlow 2.0.")
  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               name=None,
               dtype=None,
               **kwargs):
    super(BasicRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)
    if context.executing_eagerly() and context.num_gpus() > 0:
      logging.warn("%s: Note that this cell is not optimized for performance. "
                   "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
                   "performance on GPU.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
      self._activation = activations.get(activation)
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @tf_utils.shape_type_conversion
  def build(self, inputs_shape):
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % str(inputs_shape))

    input_depth = inputs_shape[-1]
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": activations.serialize(self._activation),
        "reuse": self._reuse,
    }
    base_config = super(BasicRNNCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@tf_export("nn.rnn_cell.GRUCell")
class GRUCell(LayerRNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  Note that this cell is not optimized for performance.
Please use `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU. Args: num_units: int, The number of units in the GRU cell. activation: Nonlinearity to use. Default: `tanh`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. kernel_initializer: (optional) The initializer to use for the weight and projection matrices. bias_initializer: (optional) The initializer to use for the bias. name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. **kwargs: Dict, keyword named properties for common layer attributes, like `trainable` etc when constructing the cell from configs of get_config(). """ def __init__(self, num_units, activation=None, reuse=None, kernel_initializer=None, bias_initializer=None, name=None, dtype=None, **kwargs): super(GRUCell, self).__init__( _reuse=reuse, name=name, dtype=dtype, **kwargs) if context.executing_eagerly() and context.num_gpus() > 0: logging.warn("%s: Note that this cell is not optimized for performance. " "Please use tf.contrib.cudnn_rnn.CudnnGRU for better " "performance on GPU.", self) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units if activation: self._activation = activations.get(activation) else: self._activation = math_ops.tanh self._kernel_initializer = initializers.get(kernel_initializer) self._bias_initializer = initializers.get(bias_initializer) @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units @tf_utils.shape_type_conversion def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % str(inputs_shape)) input_depth = inputs_shape[-1] self._gate_kernel = self.add_variable( "gates/%s" % _WEIGHTS_VARIABLE_NAME, shape=[input_depth + self._num_units, 2 * self._num_units], initializer=self._kernel_initializer) self._gate_bias = self.add_variable( "gates/%s" % _BIAS_VARIABLE_NAME, shape=[2 * self._num_units], initializer=( self._bias_initializer if self._bias_initializer is not None else init_ops.constant_initializer(1.0, dtype=self.dtype))) self._candidate_kernel = self.add_variable( "candidate/%s" % _WEIGHTS_VARIABLE_NAME, shape=[input_depth + self._num_units, self._num_units], initializer=self._kernel_initializer) self._candidate_bias = self.add_variable( "candidate/%s" % _BIAS_VARIABLE_NAME, shape=[self._num_units], initializer=( self._bias_initializer if self._bias_initializer is not None else init_ops.zeros_initializer(dtype=self.dtype))) self.built = True def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" gate_inputs = math_ops.matmul( array_ops.concat([inputs, state], 1), self._gate_kernel) gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) value = math_ops.sigmoid(gate_inputs) r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) r_state = r * state candidate = math_ops.matmul( array_ops.concat([inputs, r_state], 1), self._candidate_kernel) candidate = nn_ops.bias_add(candidate, self._candidate_bias) c = self._activation(candidate) new_h = u * state + (1 - u) * c return new_h, new_h 
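# --- Usage sketch (editor addition, not part of the original file) -----------
# A minimal TF 1.x graph-mode example of driving a `GRUCell` over a batch of
# sequences with `tf.nn.dynamic_rnn`, using `zero_state` for the initial state.
# All sizes and tensor names are made up for illustration; a `BasicRNNCell`
# could be substituted for the `GRUCell` without other changes.


def _example_gru_usage():
  """Builds and runs a small graph that steps a GRUCell over a sequence."""
  import numpy as np
  import tensorflow as tf

  batch_size, max_time, input_depth, num_units = 2, 5, 3, 4
  inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_depth])

  cell = tf.nn.rnn_cell.GRUCell(num_units)
  initial_state = cell.zero_state(batch_size, tf.float32)  # [batch, num_units]
  outputs, final_state = tf.nn.dynamic_rnn(
      cell, inputs, initial_state=initial_state)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out, _ = sess.run(
        [outputs, final_state],
        feed_dict={inputs: np.zeros([batch_size, max_time, input_depth])})
  return out  # shape: [batch_size, max_time, num_units]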
def get_config(self): config = { "num_units": self._num_units, "kernel_initializer": initializers.serialize(self._kernel_initializer), "bias_initializer": initializers.serialize(self._bias_initializer), "activation": activations.serialize(self._activation), "reuse": self._reuse, } base_config = super(GRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) _LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) @tf_export("nn.rnn_cell.LSTMStateTuple") class LSTMStateTuple(_LSTMStateTuple): """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state and `h` is the output. Only used when `state_is_tuple=True`. """ __slots__ = () @property def dtype(self): (c, h) = self if c.dtype != h.dtype: raise TypeError("Inconsistent internal state: %s vs %s" % (str(c.dtype), str(h.dtype))) return c.dtype # TODO(scottzhu): Stop exporting this class in TF 2.0. @tf_export("nn.rnn_cell.BasicLSTMCell") class BasicLSTMCell(LayerRNNCell): """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead. Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training. It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` that follows. Note that this cell is not optimized for performance. Please use `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for better performance on CPU. """ @deprecated(None, "This class is deprecated, please use " "tf.nn.rnn_cell.LSTMCell, which supports all the feature " "this cell currently has. Please replace the existing code " "with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').") def __init__(self, num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None, name=None, dtype=None, **kwargs): """Initialize the basic LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. activation: Activation function of the inner states. Default: `tanh`. It could also be string that is within Keras activation function names. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. **kwargs: Dict, keyword named properties for common layer attributes, like `trainable` etc when constructing the cell from configs of get_config(). When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead. 
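# --- Usage sketch (editor addition, not part of the original file) -----------
# A minimal example of a single `BasicLSTMCell` step.  With the default
# `state_is_tuple=True` the state is an `LSTMStateTuple(c, h)` and the first
# element of the returned pair is the new `h`.  Sizes are arbitrary example
# values.


def _example_basic_lstm_step():
  """Runs one BasicLSTMCell step and inspects the LSTMStateTuple state."""
  import tensorflow as tf

  batch_size, input_depth, num_units = 2, 3, 4
  cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)  # state_is_tuple=True

  inputs = tf.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, tf.float32)  # LSTMStateTuple(c=., h=.)
  output, new_state = cell(inputs, state)

  # The output is the new hidden state h; new_state carries both c and h.
  assert output.shape.as_list() == [batch_size, num_units]
  assert new_state.c.shape.as_list() == [batch_size, num_units]
  assert new_state.h.shape.as_list() == [batch_size, num_units]
  return output, new_state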
""" super(BasicLSTMCell, self).__init__( _reuse=reuse, name=name, dtype=dtype, **kwargs) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) if context.executing_eagerly() and context.num_gpus() > 0: logging.warn("%s: Note that this cell is not optimized for performance. " "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " "performance on GPU.", self) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple if activation: self._activation = activations.get(activation) else: self._activation = math_ops.tanh @property def state_size(self): return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units) @property def output_size(self): return self._num_units @tf_utils.shape_type_conversion def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % str(inputs_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units self._kernel = self.add_variable( _WEIGHTS_VARIABLE_NAME, shape=[input_depth + h_depth, 4 * self._num_units]) self._bias = self.add_variable( _BIAS_VARIABLE_NAME, shape=[4 * self._num_units], initializer=init_ops.zeros_initializer(dtype=self.dtype)) self.built = True def call(self, inputs, state): """Long short-term memory cell (LSTM). Args: inputs: `2-D` tensor with shape `[batch_size, input_size]`. state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a `Tensor` shaped `[batch_size, 2 * num_units]`. Returns: A pair containing the new hidden state, and the new state (either a `LSTMStateTuple` or a concatenated state, depending on `state_is_tuple`). """ sigmoid = math_ops.sigmoid one = constant_op.constant(1, dtype=dtypes.int32) # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state else: c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) gate_inputs = math_ops.matmul( array_ops.concat([inputs, h], 1), self._kernel) gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split( value=gate_inputs, num_or_size_splits=4, axis=one) forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) # Note that using `add` and `multiply` instead of `+` and `*` gives a # performance improvement. So using those at the cost of readability. add = math_ops.add multiply = math_ops.multiply new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))), multiply(sigmoid(i), self._activation(j))) new_h = multiply(self._activation(new_c), sigmoid(o)) if self._state_is_tuple: new_state = LSTMStateTuple(new_c, new_h) else: new_state = array_ops.concat([new_c, new_h], 1) return new_h, new_state def get_config(self): config = { "num_units": self._num_units, "forget_bias": self._forget_bias, "state_is_tuple": self._state_is_tuple, "activation": activations.serialize(self._activation), "reuse": self._reuse, } base_config = super(BasicLSTMCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) @tf_export("nn.rnn_cell.LSTMCell") class LSTMCell(LayerRNNCell): """Long short-term memory unit (LSTM) recurrent network cell. 
The default non-peephole implementation is based on: https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf Felix Gers, Jurgen Schmidhuber, and Fred Cummins. "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: https://research.google.com/pubs/archive/43905.pdf Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. The class uses optional peep-hole connections, optional cell clipping, and an optional projection layer. Note that this cell is not optimized for performance. Please use `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for better performance on CPU. """ def __init__(self, num_units, use_peepholes=False, cell_clip=None, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=None, num_proj_shards=None, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None, name=None, dtype=None, **kwargs): """Initialize the parameters for an LSTM cell. Args: num_units: int, The number of units in the LSTM cell. use_peepholes: bool, set True to enable diagonal/peephole connections. cell_clip: (optional) A float value, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight and projection matrices. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is provided, then the projected values are clipped elementwise to within `[-proj_clip, proj_clip]`. num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a variable_scope partitioner instead. num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a variable_scope partitioner instead. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. Must set it manually to `0.0` when restoring from CudnnLSTM trained checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. This latter behavior will soon be deprecated. activation: Activation function of the inner states. Default: `tanh`. It could also be string that is within Keras activation function names. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. **kwargs: Dict, keyword named properties for common layer attributes, like `trainable` etc when constructing the cell from configs of get_config(). When restoring from CudnnLSTM-trained checkpoints, use `CudnnCompatibleLSTMCell` instead. """ super(LSTMCell, self).__init__( _reuse=reuse, name=name, dtype=dtype, **kwargs) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. 
Use state_is_tuple=True.", self) if num_unit_shards is not None or num_proj_shards is not None: logging.warn( "%s: The num_unit_shards and proj_unit_shards parameters are " "deprecated and will be removed in Jan 2017. " "Use a variable scope with a partitioner instead.", self) if context.executing_eagerly() and context.num_gpus() > 0: logging.warn("%s: Note that this cell is not optimized for performance. " "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better " "performance on GPU.", self) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip self._initializer = initializers.get(initializer) self._num_proj = num_proj self._proj_clip = proj_clip self._num_unit_shards = num_unit_shards self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple if activation: self._activation = activations.get(activation) else: self._activation = math_ops.tanh if num_proj: self._state_size = ( LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + num_proj) self._output_size = num_proj else: self._state_size = ( LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * num_units) self._output_size = num_units @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size @tf_utils.shape_type_conversion def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % str(inputs_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj maybe_partitioner = ( partitioned_variables.fixed_size_partitioner(self._num_unit_shards) if self._num_unit_shards is not None else None) self._kernel = self.add_variable( _WEIGHTS_VARIABLE_NAME, shape=[input_depth + h_depth, 4 * self._num_units], initializer=self._initializer, partitioner=maybe_partitioner) if self.dtype is None: initializer = init_ops.zeros_initializer else: initializer = init_ops.zeros_initializer(dtype=self.dtype) self._bias = self.add_variable( _BIAS_VARIABLE_NAME, shape=[4 * self._num_units], initializer=initializer) if self._use_peepholes: self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units], initializer=self._initializer) self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units], initializer=self._initializer) self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units], initializer=self._initializer) if self._num_proj is not None: maybe_proj_partitioner = ( partitioned_variables.fixed_size_partitioner(self._num_proj_shards) if self._num_proj_shards is not None else None) self._proj_kernel = self.add_variable( "projection/%s" % _WEIGHTS_VARIABLE_NAME, shape=[self._num_units, self._num_proj], initializer=self._initializer, partitioner=maybe_proj_partitioner) self.built = True def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, must be 2-D, `[batch, input_size]`. state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch, output_dim]`, Tensor representing the output of the LSTM after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. 
- Tensor(s) representing the new state of LSTM after reading `inputs` when the previous state was `state`. Same type and shape(s) as `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. """ num_proj = self._num_units if self._num_proj is None else self._num_proj sigmoid = math_ops.sigmoid if self._state_is_tuple: (c_prev, m_prev) = state else: c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") # i = input_gate, j = new_input, f = forget_gate, o = output_gate lstm_matrix = math_ops.matmul( array_ops.concat([inputs, m_prev], 1), self._kernel) lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) i, j, f, o = array_ops.split( value=lstm_matrix, num_or_size_splits=4, axis=1) # Diagonal connections if self._use_peepholes: c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) else: c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * self._activation(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type if self._use_peepholes: m = sigmoid(o + self._w_o_diag * c) * self._activation(c) else: m = sigmoid(o) * self._activation(c) if self._num_proj is not None: m = math_ops.matmul(m, self._proj_kernel) if self._proj_clip is not None: # pylint: disable=invalid-unary-operand-type m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) # pylint: enable=invalid-unary-operand-type new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else array_ops.concat([c, m], 1)) return m, new_state def get_config(self): config = { "num_units": self._num_units, "use_peepholes": self._use_peepholes, "cell_clip": self._cell_clip, "initializer": initializers.serialize(self._initializer), "num_proj": self._num_proj, "proj_clip": self._proj_clip, "num_unit_shards": self._num_unit_shards, "num_proj_shards": self._num_proj_shards, "forget_bias": self._forget_bias, "state_is_tuple": self._state_is_tuple, "activation": activations.serialize(self._activation), "reuse": self._reuse, } base_config = super(LSTMCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): ix = [0] def enumerated_fn(*inner_args, **inner_kwargs): r = map_fn(ix[0], *inner_args, **inner_kwargs) ix[0] += 1 return r return nest.map_structure_up_to(shallow_structure, enumerated_fn, *args, **kwargs) def _default_dropout_state_filter_visitor(substate): if isinstance(substate, LSTMStateTuple): # Do not perform dropout on the memory state. return LSTMStateTuple(c=False, h=True) elif isinstance(substate, tensor_array_ops.TensorArray): return False return True @tf_export("nn.rnn_cell.DropoutWrapper") class DropoutWrapper(RNNCell): """Operator adding dropout to inputs and outputs of the given cell.""" def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, state_keep_prob=1.0, variational_recurrent=False, input_size=None, dtype=None, seed=None, dropout_state_filter_visitor=None): """Create a cell with added input, state, and/or output dropout. 
    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
    then the same dropout mask is applied at every step, as described in:

    Y. Gal, Z. Ghahramani.  "A Theoretically Grounded Application of Dropout in
    Recurrent Neural Networks".  https://arxiv.org/abs/1512.05287

    Otherwise a different dropout mask is applied at every time step.

    Note, by default (unless a custom `dropout_state_filter_visitor` is
    provided), the memory state (`c` component of any `LSTMStateTuple`) passing
    through a `DropoutWrapper` is never modified.  This behavior is described
    in the above article.

    Args:
      cell: an RNNCell, the cell whose inputs, state and/or outputs will have
        dropout applied.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
        probability; if it is constant and 1, no input dropout will be added.
      output_keep_prob: unit Tensor or float between 0 and 1, output keep
        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell.  **Note**
        the state components to which dropout is applied when `state_keep_prob`
        is in `(0, 1)` are also determined by the argument
        `dropout_state_filter_visitor` (e.g. by default dropout is never applied
        to the `c` component of an `LSTMStateTuple`).
      variational_recurrent: Python bool.  If `True`, then the same
        dropout pattern is applied across all time steps per run call.
        If this parameter is set, `input_size` **must** be provided.
      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
        containing the depth(s) of the input tensors expected to be passed in
        to the `DropoutWrapper`.  Required and used **iff**
        `variational_recurrent = True` and `input_keep_prob < 1`.
      dtype: (optional) The `dtype` of the input, state, and output tensors.
        Required and used **iff** `variational_recurrent = True`.
      seed: (optional) integer, the randomness seed.
      dropout_state_filter_visitor: (optional), default: (see below).  Function
        that takes any hierarchical level of the state and returns
        a scalar or depth=1 structure of Python booleans describing
        which terms in the state should be dropped out.  In addition, if the
        function returns `True`, dropout is applied across this sublevel.  If
        the function returns `False`, dropout is not applied across this entire
        sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMStateTuple` objects, and don't try to apply dropout to
        `TensorArray` objects:

        ```
        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMStateTuple):
            # Never perform dropout on the c state.
            return LSTMStateTuple(c=False, h=True)
          elif isinstance(s, TensorArray):
            return False
          return True
        ```

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or if
        `dropout_state_filter_visitor` is provided but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
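# --- Usage sketch (editor addition, not part of the original file) -----------
# An illustrative sketch of wrapping LSTM cells in a `DropoutWrapper`.  The
# second wrapper uses variational recurrent dropout, which requires both
# `input_size` and `dtype` up front.  All sizes are example values.


def _example_dropout_wrapper():
  """Builds per-step and variational DropoutWrapper cells around LSTMCells."""
  import tensorflow as tf

  num_units, input_depth = 8, 5

  # Per-step dropout masks (the default).  The memory state `c` of the
  # LSTMStateTuple is left untouched by the default state filter.
  cell = tf.nn.rnn_cell.DropoutWrapper(
      tf.nn.rnn_cell.LSTMCell(num_units),
      input_keep_prob=0.9, output_keep_prob=0.9, state_keep_prob=0.9)

  # Variational recurrent dropout: the same mask is reused at every time step,
  # so the wrapper needs the input depth and dtype when it is constructed.
  variational_cell = tf.nn.rnn_cell.DropoutWrapper(
      tf.nn.rnn_cell.LSTMCell(num_units),
      input_keep_prob=0.9, output_keep_prob=0.9, state_keep_prob=0.9,
      variational_recurrent=True,
      input_size=tf.TensorShape([input_depth]),
      dtype=tf.float32)
  return cell, variational_cell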
""" super(DropoutWrapper, self).__init__() assert_like_rnncell("cell", cell) if (dropout_state_filter_visitor is not None and not callable(dropout_state_filter_visitor)): raise TypeError("dropout_state_filter_visitor must be callable") self._dropout_state_filter = ( dropout_state_filter_visitor or _default_dropout_state_filter_visitor) with ops.name_scope("DropoutWrapperInit"): def tensor_and_const_value(v): tensor_value = ops.convert_to_tensor(v) const_value = tensor_util.constant_value(tensor_value) return (tensor_value, const_value) for prob, attr in [(input_keep_prob, "input_keep_prob"), (state_keep_prob, "state_keep_prob"), (output_keep_prob, "output_keep_prob")]: tensor_prob, const_prob = tensor_and_const_value(prob) if const_prob is not None: if const_prob < 0 or const_prob > 1: raise ValueError("Parameter %s must be between 0 and 1: %d" % (attr, const_prob)) setattr(self, "_%s" % attr, float(const_prob)) else: setattr(self, "_%s" % attr, tensor_prob) # Set cell, variational_recurrent, seed before running the code below self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") self._variational_recurrent = variational_recurrent self._seed = seed self._recurrent_input_noise = None self._recurrent_state_noise = None self._recurrent_output_noise = None if variational_recurrent: if dtype is None: raise ValueError( "When variational_recurrent=True, dtype must be provided") def convert_to_batch_shape(s): # Prepend a 1 for the batch dimension; for recurrent # variational dropout we use the same dropout mask for all # batch elements. return array_ops.concat( ([1], tensor_shape.TensorShape(s).as_list()), 0) def batch_noise(s, inner_seed): shape = convert_to_batch_shape(s) return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype) if (not isinstance(self._input_keep_prob, numbers.Real) or self._input_keep_prob < 1.0): if input_size is None: raise ValueError( "When variational_recurrent=True and input_keep_prob < 1.0 or " "is unknown, input_size must be provided") self._recurrent_input_noise = _enumerated_map_structure_up_to( input_size, lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)), input_size) self._recurrent_state_noise = _enumerated_map_structure_up_to( cell.state_size, lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)), cell.state_size) self._recurrent_output_noise = _enumerated_map_structure_up_to( cell.output_size, lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)), cell.output_size) def _gen_seed(self, salt_prefix, index): if self._seed is None: return None salt = "%s_%d" % (salt_prefix, index) string = (str(self._seed) + salt).encode("utf-8") return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF @property def wrapped_cell(self): return self._cell @property def state_size(self): return self._cell.state_size @property def output_size(self): return self._cell.output_size def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) def _variational_recurrent_dropout_value( self, index, value, noise, keep_prob): """Performs dropout given the pre-calculated noise tensor.""" # uniform [keep_prob, 1.0 + keep_prob) random_tensor = keep_prob + noise # 0. if [keep_prob, 1.0) and 1. 
if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(value, keep_prob) * binary_tensor ret.set_shape(value.get_shape()) return ret def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob, shallow_filtered_substructure=None): """Decides whether to perform standard dropout or recurrent dropout.""" if shallow_filtered_substructure is None: # Put something so we traverse the entire structure; inside the # dropout function we check to see if leafs of this are bool or not. shallow_filtered_substructure = values if not self._variational_recurrent: def dropout(i, do_dropout, v): if not isinstance(do_dropout, bool) or do_dropout: return nn_ops.dropout( v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i)) else: return v return _enumerated_map_structure_up_to( shallow_filtered_substructure, dropout, *[shallow_filtered_substructure, values]) else: def dropout(i, do_dropout, v, n): if not isinstance(do_dropout, bool) or do_dropout: return self._variational_recurrent_dropout_value(i, v, n, keep_prob) else: return v return _enumerated_map_structure_up_to( shallow_filtered_substructure, dropout, *[shallow_filtered_substructure, values, recurrent_noise]) def __call__(self, inputs, state, scope=None): """Run the cell with the declared dropouts.""" def _should_dropout(p): return (not isinstance(p, float)) or p < 1 if _should_dropout(self._input_keep_prob): inputs = self._dropout(inputs, "input", self._recurrent_input_noise, self._input_keep_prob) output, new_state = self._cell(inputs, state, scope=scope) if _should_dropout(self._state_keep_prob): # Identify which subsets of the state to perform dropout on and # which ones to keep. shallow_filtered_substructure = nest.get_traverse_shallow_structure( self._dropout_state_filter, new_state) new_state = self._dropout(new_state, "state", self._recurrent_state_noise, self._state_keep_prob, shallow_filtered_substructure) if _should_dropout(self._output_keep_prob): output = self._dropout(output, "output", self._recurrent_output_noise, self._output_keep_prob) return output, new_state @tf_export("nn.rnn_cell.ResidualWrapper") class ResidualWrapper(RNNCell): """RNNCell wrapper that ensures cell inputs are added to the outputs.""" def __init__(self, cell, residual_fn=None): """Constructs a `ResidualWrapper` for `cell`. Args: cell: An instance of `RNNCell`. residual_fn: (Optional) The function to map raw cell inputs and raw cell outputs to the actual cell outputs of the residual network. Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs and outputs. """ super(ResidualWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") self._residual_fn = residual_fn @property def state_size(self): return self._cell.state_size @property def output_size(self): return self._cell.output_size def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) def __call__(self, inputs, state, scope=None): """Run the cell and then apply the residual_fn on its inputs to its outputs. Args: inputs: cell inputs. state: cell state. scope: optional cell scope. Returns: Tuple of cell outputs and new state. Raises: TypeError: If cell inputs and outputs have different structure (type). ValueError: If cell inputs and outputs have different structure (value). 
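# --- Usage sketch (editor addition, not part of the original file) -----------
# An illustrative ResidualWrapper example.  The default `residual_fn` adds the
# cell inputs to the cell outputs, so the input depth must equal the wrapped
# cell's `output_size` (`num_units` here).  Sizes are example values.


def _example_residual_wrapper():
  """Adds a residual connection around a GRUCell."""
  import tensorflow as tf

  batch_size, num_units = 2, 6
  cell = tf.nn.rnn_cell.ResidualWrapper(tf.nn.rnn_cell.GRUCell(num_units))

  # Input depth matches output_size so that `output + input` is well defined.
  inputs = tf.zeros([batch_size, num_units])
  state = cell.zero_state(batch_size, tf.float32)
  output, new_state = cell(inputs, state)
  return output, new_state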
""" outputs, new_state = self._cell(inputs, state, scope=scope) # Ensure shapes match def assert_shape_match(inp, out): inp.get_shape().assert_is_compatible_with(out.get_shape()) def default_residual_fn(inputs, outputs): nest.assert_same_structure(inputs, outputs) nest.map_structure(assert_shape_match, inputs, outputs) return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) return (res_outputs, new_state) @tf_export("nn.rnn_cell.DeviceWrapper") class DeviceWrapper(RNNCell): """Operator that ensures an RNNCell runs on a particular device.""" def __init__(self, cell, device): """Construct a `DeviceWrapper` for `cell` with device `device`. Ensures the wrapped `cell` is called with `tf.device(device)`. Args: cell: An instance of `RNNCell`. device: A device string or function, for passing to `tf.device`. """ super(DeviceWrapper, self).__init__() self._cell = cell if isinstance(cell, checkpointable.CheckpointableBase): self._track_checkpointable(self._cell, name="cell") self._device = device @property def state_size(self): return self._cell.state_size @property def output_size(self): return self._cell.output_size def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): with ops.device(self._device): return self._cell.zero_state(batch_size, dtype) def __call__(self, inputs, state, scope=None): """Run the cell on specified device.""" with ops.device(self._device): return self._cell(inputs, state, scope=scope) @tf_export("nn.rnn_cell.MultiRNNCell") class MultiRNNCell(RNNCell): """RNN cell composed sequentially of multiple simple cells. Example: ```python num_units = [128, 64] cells = [BasicLSTMCell(num_units=n) for n in num_units] stacked_rnn_cell = MultiRNNCell(cells) ``` """ def __init__(self, cells, state_is_tuple=True): """Create a RNN cell composed sequentially of a number of RNNCells. Args: cells: list of RNNCells that will be composed in this order. state_is_tuple: If True, accepted and returned states are n-tuples, where `n = len(cells)`. If False, the states are all concatenated along the column axis. This latter behavior will soon be deprecated. Raises: ValueError: if cells is empty (not allowed), or at least one of the cells returns a state tuple but the flag `state_is_tuple` is `False`. """ super(MultiRNNCell, self).__init__() if not cells: raise ValueError("Must specify at least one cell for MultiRNNCell.") if not nest.is_sequence(cells): raise TypeError( "cells must be a list or tuple, but saw: %s." % cells) if len(set([id(cell) for cell in cells])) < len(cells): logging.log_first_n(logging.WARN, "At least two cells provided to MultiRNNCell " "are the same object and will share weights.", 1) self._cells = cells for cell_number, cell in enumerate(self._cells): # Add Checkpointable dependencies on these cells so their variables get # saved with this object when using object-based saving. if isinstance(cell, checkpointable.CheckpointableBase): # TODO(allenl): Track down non-Checkpointable callers. self._track_checkpointable(cell, name="cell-%d" % (cell_number,)) self._state_is_tuple = state_is_tuple if not state_is_tuple: if any(nest.is_sequence(c.state_size) for c in self._cells): raise ValueError("Some cells return tuples of states, but the flag " "state_is_tuple is not set. 
State sizes are: %s" % str([c.state_size for c in self._cells])) @property def state_size(self): if self._state_is_tuple: return tuple(cell.state_size for cell in self._cells) else: return sum([cell.state_size for cell in self._cells]) @property def output_size(self): return self._cells[-1].output_size def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): if self._state_is_tuple: return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells) else: # We know here that state_size of each cell is not a tuple and # presumably does not contain TensorArrays or anything else fancy return super(MultiRNNCell, self).zero_state(batch_size, dtype) def call(self, inputs, state): """Run this multi-layer cell on inputs, starting from state.""" cur_state_pos = 0 cur_inp = inputs new_states = [] for i, cell in enumerate(self._cells): with vs.variable_scope("cell_%d" % i): if self._state_is_tuple: if not nest.is_sequence(state): raise ValueError( "Expected state to be a tuple of length %d, but received: %s" % (len(self.state_size), state)) cur_state = state[i] else: cur_state = array_ops.slice(state, [0, cur_state_pos], [-1, cell.state_size]) cur_state_pos += cell.state_size cur_inp, new_state = cell(cur_inp, cur_state) new_states.append(new_state) new_states = (tuple(new_states) if self._state_is_tuple else array_ops.concat(new_states, 1)) return cur_inp, new_states
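# --- Usage sketch (editor addition, not part of the original file) -----------
# A sketch that combines the pieces above: two LSTM layers, each pinned to a
# device with `DeviceWrapper`, stacked with `MultiRNNCell` and unrolled by
# `tf.nn.dynamic_rnn`.  The device string and all sizes are placeholders.


def _example_multi_rnn_stack():
  """Stacks two LSTM layers and runs them over a batch of sequences."""
  import tensorflow as tf

  batch_size, max_time, input_depth = 2, 7, 3
  num_units = [128, 64]

  # One DeviceWrapper per layer; "/cpu:0" is only a placeholder device.
  cells = [
      tf.nn.rnn_cell.DeviceWrapper(tf.nn.rnn_cell.LSTMCell(n), "/cpu:0")
      for n in num_units
  ]
  stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

  inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_depth])
  outputs, final_state = tf.nn.dynamic_rnn(stacked_cell, inputs,
                                           dtype=tf.float32)

  # `outputs` has the width of the last layer (64); `final_state` is a tuple
  # with one LSTMStateTuple per layer.
  return outputs, final_state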