author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-03 10:11:58 -0800
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-03 10:28:32 -0800
commit | 1809c7a35794bbbc467a179d1fe15f01687bd506 |
tree | 238ebd9768c6195580481766c8fd696d7fd23ef1 /tensorflow/contrib/grid_rnn |
parent | b4d5181fe674a9ab3a9decd38db08314ddf6b5a0 |
GridRNN creates multiple internal cell instances, one for each dimension.
As we start moving RNNCells to layer-like objects that keep state, each
layer needs its own cell instance. When tied=True, the same instance is
reused across the cells list.
Change: 146486991
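
The tied/untied distinction above comes down to how the internal cells list is built. Below is a minimal plain-Python sketch of those semantics; `FakeCell` is an illustrative stand-in, not the TensorFlow API:

```python
# Sketch of the tied vs. untied list construction described above, with
# a plain-Python stand-in for an RNNCell (FakeCell is illustrative only).
class FakeCell(object):
  def __init__(self):
    self.weights = []  # stand-in for per-layer state

num_dims = 3

# tied=True: one instance repeated, so every dimension shares its state.
tied_cells = [FakeCell()] * num_dims
assert all(c is tied_cells[0] for c in tied_cells)

# tied=False: a fresh instance per dimension, each with its own state.
untied_cells = [FakeCell() for _ in range(num_dims)]
assert len(set(id(c) for c in untied_cells)) == num_dims
```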
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 39
1 file changed, 26 insertions, 13 deletions
```diff
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index b35d0df98c..269b224581 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import namedtuple
+import functools
 
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -83,6 +84,9 @@ class GridRNNCell(rnn.RNNCell):
         default parameters will be used.
       non_recurrent_fn: a tensorflow Op that will be the transfer function of
         the non-recurrent dimensions
+
+    Raises:
+      TypeError: if cell_fn does not return an RNNCell instance.
     """
     if num_dims < 1:
       raise ValueError('dims must be >= 1: {}'.format(num_dims))
@@ -94,12 +98,20 @@ class GridRNNCell(rnn.RNNCell):
     cell_input_size = (self._config.num_dims - 1) * num_units
     if cell_fn is None:
-      self._cell = rnn.LSTMCell(
-          num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
+      my_cell_fn = functools.partial(
+          rnn.LSTMCell,
+          num_units=num_units, input_size=cell_input_size,
+          state_is_tuple=False)
+    else:
+      my_cell_fn = lambda: cell_fn(num_units, cell_input_size)
+    if tied:
+      self._cells = [my_cell_fn()] * num_dims
     else:
-      self._cell = cell_fn(num_units, cell_input_size)
-      if not isinstance(self._cell, rnn.RNNCell):
-        raise ValueError('cell_fn must return an object of type RNNCell')
+      self._cells = [my_cell_fn() for _ in range(num_dims)]
+    if not isinstance(self._cells[0], rnn.RNNCell):
+      raise TypeError(
+          'cell_fn must return an RNNCell instance, saw: %s'
+          % type(self._cells[0]))
 
   @property
   def input_size(self):
@@ -110,11 +122,11 @@ class GridRNNCell(rnn.RNNCell):
 
   @property
   def output_size(self):
-    return self._cell.output_size * len(self._config.outputs)
+    return self._cells[0].output_size * len(self._config.outputs)
 
   @property
   def state_size(self):
-    return self._cell.state_size * len(self._config.recurrents)
+    return self._cells[0].state_size * len(self._config.recurrents)
 
   def __call__(self, inputs, state, scope=None):
     """Run one step of GridRNN.
@@ -146,13 +158,13 @@ class GridRNNCell(rnn.RNNCell):
     # Keep c and m here for consistency with the codebase
     c_prev = [None] * self._config.num_dims
     m_prev = [None] * self._config.num_dims
-    cell_output_size = self._cell.state_size - conf.num_units
+    cell_output_size = self._cells[0].state_size - conf.num_units
     # for LSTM   : state = memory cell + output, hence cell_output_size > 0
     # for GRU/RNN: state = output (whose size is equal to _num_units),
     #              hence cell_output_size = 0
     for recurrent_dim, start_idx in zip(self._config.recurrents, range(
-        0, self.state_size, self._cell.state_size)):
+        0, self.state_size, self._cells[0].state_size)):
       if cell_output_size > 0:
         c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                 [-1, conf.num_units])
@@ -185,10 +197,10 @@ class GridRNNCell(rnn.RNNCell):
                                             dtype=dtype)
           c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
 
-    _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev,
+    _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
                new_output, new_state, True)
-    _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output,
-               new_state, False)
+    _propagate(conf.priority, conf, self._cells,
+               c_prev, m_prev, new_output, new_state, False)
 
     output_tensors = [new_output[i] for i in self._config.outputs]
     output = array_ops.zeros(
@@ -414,7 +426,7 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims,
                        ls_priority_dims, num_units=num_units)
 
 
-def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
+def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
                first_call):
   """Propagates through all the cells in dim_indices dimensions.
   """
@@ -464,4 +476,5 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
                        'recurrent/cell_{}'.format(i)):
         if conf.tied and not (first_call and i == dim_indices[0]):
           vs.get_variable_scope().reuse_variables()
+        cell = cells[i]
         new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
```
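
For readers skimming the last two hunks: the behavioral change reduces to `_propagate` receiving the whole `cells` list and picking `cells[i]` per dimension. The following is a minimal sketch of that dispatch using plain-Python stand-ins; `fake_cell_factory`, `propagate`, and all names below are illustrative, not the patch's real signatures:

```python
# Minimal sketch of the per-dimension dispatch added in the last hunk:
# each dimension i looks up its own cell (cells[i]) instead of every
# dimension closing over one shared `cell`. All names here are
# illustrative stand-ins, not the patch's real signatures.
def fake_cell_factory(tag):
  def cell(inputs, state):
    # A real RNNCell call returns (output, new_state); tag the output
    # so it is visible which instance ran.
    return ('out_%s_%s' % (tag, inputs), state + 1)
  return cell

num_dims = 3
# Untied: one distinct cell per dimension, as in the patch's list build.
cells = [fake_cell_factory(i) for i in range(num_dims)]

def propagate(dim_indices, cells, inputs, states):
  outputs, new_states = {}, {}
  for i in dim_indices:
    cell = cells[i]  # mirrors the `cell = cells[i]` line in the patch
    outputs[i], new_states[i] = cell(inputs[i], states[i])
  return outputs, new_states

outs, states = propagate(range(num_dims), cells, ['a', 'b', 'c'], [0, 0, 0])
assert outs[1] == 'out_1_b' and states[2] == 1
```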