diff options
Diffstat (limited to 'tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 497 |
1 files changed, 341 insertions, 156 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 269b224581..252788140f 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -25,6 +25,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs + +from tensorflow.python.platform import tf_logging as logging from tensorflow.contrib import layers from tensorflow.contrib import rnn @@ -53,7 +55,9 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_dims=None, tied=False, cell_fn=None, - non_recurrent_fn=None): + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): """Initialize the parameters of a Grid RNN cell Args: @@ -68,26 +72,47 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_dims: int or list, List of dimensions that are not recurrent. The transfer function for non-recurrent dimensions is specified - via `non_recurrent_fn`, - which is default to be `tensorflow.nn.relu`. + via `non_recurrent_fn`, which is + default to be `tensorflow.nn.relu`. tied: bool, Whether to share the weights among the dimensions of this GridRNN cell. If there are non-recurrent dimensions in the grid, weights are - shared between each - group of recurrent and non-recurrent dimensions. - cell_fn: function, a function which returns the recurrent cell object. Has - to be in the following signature: - def cell_func(num_units, input_size): + shared between each group of recurrent and non-recurrent + dimensions. + cell_fn: function, a function which returns the recurrent cell object. + Has to be in the following signature: + ``` + def cell_func(num_units): # ... - + ``` and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used. + Note that if you use a custom RNNCell (with `cell_fn`), it is your + responsibility to make sure the inner cell use `state_is_tuple=True`. + non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions + state_is_tuple: If True, accepted and returned states are tuples of the + states of the recurrent dimensions. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + + Note that if you use a custom RNNCell (with `cell_fn`), it is your + responsibility to make sure the inner cell use `state_is_tuple=True`. + + output_is_tuple: If True, the output is a tuple of the outputs of the + recurrent dimensions. If False, they are concatenated along the + column axis. The later behavior will soon be deprecated. Raises: TypeError: if cell_fn does not return an RNNCell instance. """ + if not state_is_tuple: + logging.warning('%s: Using a concatenated state is slower and will ' + 'soon be deprecated. Use state_is_tuple=True.', self) + if not output_is_tuple: + logging.warning('%s: Using a concatenated output is slower and will' + 'soon be deprecated. Use output_is_tuple=True.', self) + if num_dims < 1: raise ValueError('dims must be >= 1: {}'.format(num_dims)) @@ -96,37 +121,41 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_fn or nn.relu, tied, num_units) - cell_input_size = (self._config.num_dims - 1) * num_units + self._state_is_tuple = state_is_tuple + self._output_is_tuple = output_is_tuple + if cell_fn is None: my_cell_fn = functools.partial( - rnn.LSTMCell, - num_units=num_units, input_size=cell_input_size, - state_is_tuple=False) + rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple) else: - my_cell_fn = lambda: cell_fn(num_units, cell_input_size) + my_cell_fn = lambda: cell_fn(num_units) if tied: self._cells = [my_cell_fn()] * num_dims else: self._cells = [my_cell_fn() for _ in range(num_dims)] if not isinstance(self._cells[0], rnn.RNNCell): - raise TypeError( - 'cell_fn must return an RNNCell instance, saw: %s' - % type(self._cells[0])) + raise TypeError('cell_fn must return an RNNCell instance, saw: %s' % + type(self._cells[0])) - @property - def input_size(self): - # temporarily using num_units as the input_size of each dimension. - # The actual input size only determined when this cell get invoked, - # so this information can be considered unreliable. - return self._config.num_units * len(self._config.inputs) + if self._output_is_tuple: + self._output_size = tuple(self._cells[0].output_size + for _ in self._config.outputs) + else: + self._output_size = self._cells[0].output_size * len(self._config.outputs) + + if self._state_is_tuple: + self._state_size = tuple(self._cells[0].state_size + for _ in self._config.recurrents) + else: + self._state_size = self._cell_state_size() * len(self._config.recurrents) @property def output_size(self): - return self._cells[0].output_size * len(self._config.outputs) + return self._output_size @property def state_size(self): - return self._cells[0].state_size * len(self._config.recurrents) + return self._state_size def __call__(self, inputs, state, scope=None): """Run one step of GridRNN. @@ -145,76 +174,148 @@ class GridRNNCell(rnn.RNNCell): - A 2D, batch x state_size, Tensor representing the new state of the cell after reading "inputs" when previous state was "state". """ - state_sz = state.get_shape().as_list()[1] - if self.state_size != state_sz: - raise ValueError( - 'Actual state size not same as specified: {} vs {}.'.format( - state_sz, self.state_size)) - conf = self._config - dtype = inputs.dtype if inputs is not None else state.dtype + dtype = inputs.dtype - # c_prev is `m`, and m_prev is `h` in the paper. - # Keep c and m here for consistency with the codebase - c_prev = [None] * self._config.num_dims - m_prev = [None] * self._config.num_dims - cell_output_size = self._cells[0].state_size - conf.num_units - - # for LSTM : state = memory cell + output, hence cell_output_size > 0 - # for GRU/RNN: state = output (whose size is equal to _num_units), - # hence cell_output_size = 0 - for recurrent_dim, start_idx in zip(self._config.recurrents, range( - 0, self.state_size, self._cells[0].state_size)): - if cell_output_size > 0: - c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], - [-1, conf.num_units]) - m_prev[recurrent_dim] = array_ops.slice( - state, [0, start_idx + conf.num_units], [-1, cell_output_size]) - else: - m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], - [-1, conf.num_units]) + c_prev, m_prev, cell_output_size = self._extract_states(state) new_output = [None] * conf.num_dims new_state = [None] * conf.num_dims with vs.variable_scope(scope or type(self).__name__): # GridRNNCell + # project input, populate c_prev and m_prev + self._project_input(inputs, c_prev, m_prev, cell_output_size > 0) - # project input - if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len( - conf.inputs) > 0: - input_splits = array_ops.split( - value=inputs, num_or_size_splits=len(conf.inputs), axis=1) - input_sz = input_splits[0].get_shape().as_list()[1] - - for i, j in enumerate(conf.inputs): - input_project_m = vs.get_variable( - 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) - m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) - - if cell_output_size > 0: - input_project_c = vs.get_variable( - 'project_c_{}'.format(j), [input_sz, conf.num_units], - dtype=dtype) - c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) - + # propagate along dimensions, first for non-priority dimensions + # then priority dimensions _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev, new_output, new_state, True) _propagate(conf.priority, conf, self._cells, c_prev, m_prev, new_output, new_state, False) + # collect outputs and states output_tensors = [new_output[i] for i in self._config.outputs] - output = array_ops.zeros( - [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat( - output_tensors, 1) + if self._output_is_tuple: + output = tuple(output_tensors) + else: + if output_tensors: + output = array_ops.concat(output_tensors, 1) + else: + output = array_ops.zeros([0, 0], dtype) - state_tensors = [new_state[i] for i in self._config.recurrents] - states = array_ops.zeros( - [0, 0], - dtype) if len(state_tensors) == 0 else array_ops.concat(state_tensors, - 1) + if self._state_is_tuple: + states = tuple(new_state[i] for i in self._config.recurrents) + else: + # concat each state first, then flatten the whole thing + state_tensors = [ + x for i in self._config.recurrents for x in new_state[i] + ] + if state_tensors: + states = array_ops.concat(state_tensors, 1) + else: + states = array_ops.zeros([0, 0], dtype) return output, states + def _extract_states(self, state): + """Extract the cell and previous output tensors from the given state. + + Args: + state: The RNN state. + + Returns: + Tuple of the cell value, previous output, and cell_output_size. + + Raises: + ValueError: If len(self._config.recurrents) != len(state). + """ + conf = self._config + + # c_prev is `m` (cell value), and + # m_prev is `h` (previous output) in the paper. + # Keeping c and m here for consistency with the codebase + c_prev = [None] * conf.num_dims + m_prev = [None] * conf.num_dims + + # for LSTM : state = memory cell + output, hence cell_output_size > 0 + # for GRU/RNN: state = output (whose size is equal to _num_units), + # hence cell_output_size = 0 + total_cell_state_size = self._cell_state_size() + cell_output_size = total_cell_state_size - conf.num_units + + if self._state_is_tuple: + if len(conf.recurrents) != len(state): + raise ValueError('Expected state as a tuple of {} ' + 'element'.format(len(conf.recurrents))) + + for recurrent_dim, recurrent_state in zip(conf.recurrents, state): + if cell_output_size > 0: + c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state + else: + m_prev[recurrent_dim] = recurrent_state + else: + for recurrent_dim, start_idx in zip(conf.recurrents, + range(0, self.state_size, + total_cell_state_size)): + if cell_output_size > 0: + c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) + m_prev[recurrent_dim] = array_ops.slice( + state, [0, start_idx + conf.num_units], [-1, cell_output_size]) + else: + m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) + return c_prev, m_prev, cell_output_size + + def _project_input(self, inputs, c_prev, m_prev, with_c): + """Fills in c_prev and m_prev with projected input, for input dimensions. + + Args: + inputs: inputs tensor + c_prev: cell value + m_prev: previous output + with_c: boolean; whether to include project_c. + + Raises: + ValueError: if len(self._config.input) != len(inputs) + """ + conf = self._config + + if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and + conf.inputs): + if isinstance(inputs, tuple): + if len(conf.inputs) != len(inputs): + raise ValueError('Expect inputs as a tuple of {} ' + 'tensors'.format(len(conf.inputs))) + input_splits = inputs + else: + input_splits = array_ops.split( + value=inputs, num_or_size_splits=len(conf.inputs), axis=1) + input_sz = input_splits[0].get_shape().with_rank(2)[1].value + + for i, j in enumerate(conf.inputs): + input_project_m = vs.get_variable( + 'project_m_{}'.format(j), [input_sz, conf.num_units], + dtype=inputs.dtype) + m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) + + if with_c: + input_project_c = vs.get_variable( + 'project_c_{}'.format(j), [input_sz, conf.num_units], + dtype=inputs.dtype) + c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) + + def _cell_state_size(self): + """Total size of the state of the inner cell used in this grid. + + Returns: + Total size of the state of the inner cell. + """ + state_sizes = self._cells[0].state_size + if isinstance(state_sizes, tuple): + return sum(state_sizes) + return state_sizes + """Specialized cells, for convenience """ @@ -223,11 +324,17 @@ class GridRNNCell(rnn.RNNCell): class Grid1BasicRNNCell(GridRNNCell): """1D BasicRNN cell""" - def __init__(self, num_units): + def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True): super(Grid1BasicRNNCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=False, + cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2BasicRNNCell(GridRNNCell): @@ -240,71 +347,112 @@ class Grid2BasicRNNCell(GridRNNCell): specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None): + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): super(Grid2BasicRNNCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid1BasicLSTMCell(GridRNNCell): - """1D BasicLSTM cell""" + """1D BasicLSTM cell.""" - def __init__(self, num_units, forget_bias=1): + def __init__(self, + num_units, + forget_bias=1, + state_is_tuple=True, + output_is_tuple=True): + def cell_fn(n): + return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) super(Grid1BasicLSTMCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn.BasicLSTMCell( - num_units=n, - forget_bias=forget_bias, input_size=i, - state_is_tuple=False)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=False, + cell_fn=cell_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2BasicLSTMCell(GridRNNCell): - """2D BasicLSTM cell + """2D BasicLSTM cell. - This creates a 2D cell which receives input and gives output in the first - dimension. + This creates a 2D cell which receives input and gives output in the first + dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is - specified. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ def __init__(self, num_units, tied=False, non_recurrent_fn=None, - forget_bias=1): + forget_bias=1, + state_is_tuple=True, + output_is_tuple=True): + def cell_fn(n): + return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) super(Grid2BasicLSTMCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.BasicLSTMCell( - num_units=n, forget_bias=forget_bias, input_size=i, - state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid1LSTMCell(GridRNNCell): - """1D LSTM cell + """1D LSTM cell. - This is different from Grid1BasicLSTMCell because it gives options to - specify the forget bias and enabling peepholes + This is different from Grid1BasicLSTMCell because it gives options to + specify the forget bias and enabling peepholes. """ - def __init__(self, num_units, use_peepholes=False, forget_bias=1.0): + def __init__(self, + num_units, + use_peepholes=False, + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid1LSTMCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, use_peepholes=use_peepholes, - forget_bias=forget_bias, state_is_tuple=False)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + cell_fn=cell_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2LSTMCell(GridRNNCell): - """2D LSTM cell + """2D LSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -317,19 +465,30 @@ class Grid2LSTMCell(GridRNNCell): tied=False, non_recurrent_fn=None, use_peepholes=False, - forget_bias=1.0): + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid2LSTMCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes, state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid3LSTMCell(GridRNNCell): - """3D BasicLSTM cell + """3D BasicLSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -343,19 +502,30 @@ class Grid3LSTMCell(GridRNNCell): tied=False, non_recurrent_fn=None, use_peepholes=False, - forget_bias=1.0): + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid3LSTMCell, self).__init__( - num_units=num_units, num_dims=3, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=3, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes, state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2GRUCell(GridRNNCell): - """2D LSTM cell + """2D LSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -363,21 +533,31 @@ class Grid2GRUCell(GridRNNCell): specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None): + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): super(Grid2GRUCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + cell_fn=lambda n: rnn.GRUCell(num_units=n), + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) -"""Helpers -""" +# Helpers -_GridRNNDimension = namedtuple( - '_GridRNNDimension', - ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) +_GridRNNDimension = namedtuple('_GridRNNDimension', [ + 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn' +]) _GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims', 'inputs', 'outputs', @@ -387,7 +567,6 @@ _GridRNNConfig = namedtuple('_GridRNNConfig', def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims, non_recurrent_fn, tied, num_units): - def check_dim_list(ls): if ls is None: ls = [] @@ -412,8 +591,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, is_input=(i in input_dims), is_output=(i in output_dims), is_priority=(i in priority_dims), - non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else - None)) + non_recurrent_fn=non_recurrent_fn + if i in non_recurrent_dims else None)) return _GridRNNConfig( num_dims=num_dims, dims=rnn_dims, @@ -440,34 +619,40 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state, if conf.num_dims > 1: ls_cell_inputs = [None] * (conf.num_dims - 1) for d in conf.dims[:-1]: - ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[ - d.idx] is not None else m_prev[d.idx] + if new_output[d.idx] is None: + ls_cell_inputs[d.idx] = m_prev[d.idx] + else: + ls_cell_inputs[d.idx] = new_output[d.idx] cell_inputs = array_ops.concat(ls_cell_inputs, 1) else: cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype) - last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1] + last_dim_output = (new_output[-1] + if new_output[-1] is not None else m_prev[-1]) for i in dim_indices: d = conf.dims[i] if d.non_recurrent_fn: - linear_args = array_ops.concat( - [cell_inputs, last_dim_output], - 1) if conf.num_dims > 1 else last_dim_output + if conf.num_dims > 1: + linear_args = array_ops.concat([cell_inputs, last_dim_output], 1) + else: + linear_args = last_dim_output with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)): if conf.tied and not (first_call and i == dim_indices[0]): vs.get_variable_scope().reuse_variables() - new_output[d.idx] = layers.legacy_fully_connected( + + new_output[d.idx] = layers.fully_connected( linear_args, - num_output_units=conf.num_units, + num_outputs=conf.num_units, activation_fn=d.non_recurrent_fn, - weight_init=vs.get_variable_scope().initializer or - layers.initializers.xavier_initializer) + weights_initializer=(vs.get_variable_scope().initializer or + layers.initializers.xavier_initializer), + weights_regularizer=vs.get_variable_scope().regularizer) else: if c_prev[i] is not None: - cell_state = array_ops.concat([c_prev[i], last_dim_output], 1) + cell_state = (c_prev[i], last_dim_output) else: # for GRU/RNN, the state is just the previous output cell_state = last_dim_output |