path: root/tensorflow/contrib/grid_rnn
author    Eugene Brevdo <ebrevdo@google.com>    2017-02-03 10:11:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-02-03 10:28:32 -0800
commit    1809c7a35794bbbc467a179d1fe15f01687bd506 (patch)
tree      238ebd9768c6195580481766c8fd696d7fd23ef1 /tensorflow/contrib/grid_rnn
parent    b4d5181fe674a9ab3a9decd38db08314ddf6b5a0 (diff)
GridRNN creates multiple internal cell instances, one for each dimension.
As we start moving RNNCells to layer-like objects that keep state, each layer needs its own instance. When tied=True, the same instance is reused across the cells list. Change: 146486991
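To illustrate the construction pattern this change introduces, here is a minimal, self-contained sketch in plain Python (no TensorFlow); the names FakeCell and build_cells are hypothetical, not part of the patch:

import functools

class FakeCell(object):
  """Stand-in for an RNNCell; only here to make the sketch runnable."""

  def __init__(self, num_units, input_size):
    self.num_units = num_units
    self.input_size = input_size

def build_cells(num_dims, num_units, cell_input_size, cell_fn=None, tied=False):
  # Wrap the constructor in a zero-argument factory, mirroring the
  # functools.partial / lambda pattern in the patch below.
  if cell_fn is None:
    make_cell = functools.partial(FakeCell, num_units, cell_input_size)
  else:
    make_cell = lambda: cell_fn(num_units, cell_input_size)
  if tied:
    # tied=True: every list slot aliases the same instance, so its state
    # (and eventually its variables) is shared across dimensions.
    return [make_cell()] * num_dims
  # tied=False: one independent instance per dimension.
  return [make_cell() for _ in range(num_dims)]

cells = build_cells(num_dims=3, num_units=8, cell_input_size=16, tied=True)
assert cells[0] is cells[1] is cells[2]  # tied: one shared cell instance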
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--  tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py  39
1 file changed, 26 insertions(+), 13 deletions(-)
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index b35d0df98c..269b224581 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import namedtuple
+import functools
 
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -83,6 +84,9 @@ class GridRNNCell(rnn.RNNCell):
         default parameters will be used.
       non_recurrent_fn: a tensorflow Op that will be the transfer function of
         the non-recurrent dimensions
+
+    Raises:
+      TypeError: if cell_fn does not return an RNNCell instance.
     """
     if num_dims < 1:
       raise ValueError('dims must be >= 1: {}'.format(num_dims))
@@ -94,12 +98,20 @@ class GridRNNCell(rnn.RNNCell):
 
     cell_input_size = (self._config.num_dims - 1) * num_units
     if cell_fn is None:
-      self._cell = rnn.LSTMCell(
-          num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
+      my_cell_fn = functools.partial(
+          rnn.LSTMCell,
+          num_units=num_units, input_size=cell_input_size,
+          state_is_tuple=False)
+    else:
+      my_cell_fn = lambda: cell_fn(num_units, cell_input_size)
+    if tied:
+      self._cells = [my_cell_fn()] * num_dims
     else:
-      self._cell = cell_fn(num_units, cell_input_size)
-    if not isinstance(self._cell, rnn.RNNCell):
-      raise ValueError('cell_fn must return an object of type RNNCell')
+      self._cells = [my_cell_fn() for _ in range(num_dims)]
+    if not isinstance(self._cells[0], rnn.RNNCell):
+      raise TypeError(
+          'cell_fn must return an RNNCell instance, saw: %s'
+          % type(self._cells[0]))
 
   @property
   def input_size(self):
@@ -110,11 +122,11 @@
 
   @property
   def output_size(self):
-    return self._cell.output_size * len(self._config.outputs)
+    return self._cells[0].output_size * len(self._config.outputs)
 
   @property
   def state_size(self):
-    return self._cell.state_size * len(self._config.recurrents)
+    return self._cells[0].state_size * len(self._config.recurrents)
 
   def __call__(self, inputs, state, scope=None):
     """Run one step of GridRNN.
@@ -146,13 +158,13 @@ class GridRNNCell(rnn.RNNCell):
     # Keep c and m here for consistency with the codebase
     c_prev = [None] * self._config.num_dims
     m_prev = [None] * self._config.num_dims
-    cell_output_size = self._cell.state_size - conf.num_units
+    cell_output_size = self._cells[0].state_size - conf.num_units
 
     # for LSTM   : state = memory cell + output, hence cell_output_size > 0
     # for GRU/RNN: state = output (whose size is equal to _num_units),
     #              hence cell_output_size = 0
     for recurrent_dim, start_idx in zip(self._config.recurrents, range(
-        0, self.state_size, self._cell.state_size)):
+        0, self.state_size, self._cells[0].state_size)):
       if cell_output_size > 0:
         c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                 [-1, conf.num_units])
@@ -185,10 +197,10 @@ class GridRNNCell(rnn.RNNCell):
               dtype=dtype)
           c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
 
-    _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev,
+    _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
                new_output, new_state, True)
-    _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output,
-               new_state, False)
+    _propagate(conf.priority, conf, self._cells,
+               c_prev, m_prev, new_output, new_state, False)
 
     output_tensors = [new_output[i] for i in self._config.outputs]
     output = array_ops.zeros(
@@ -414,7 +426,7 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
       num_units=num_units)
 
 
-def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
+def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
                first_call):
   """Propagates through all the cells in dim_indices dimensions.
   """
@@ -464,4 +476,5 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
                            'recurrent/cell_{}'.format(i)):
       if conf.tied and not (first_call and i == dim_indices[0]):
         vs.get_variable_scope().reuse_variables()
+      cell = cells[i]
       new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
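
For context, a hypothetical usage sketch against the TF 1.x contrib API of that era, showing how a custom cell_fn interacts with the new TypeError check (module paths and signatures as recalled; treat them as an assumption, not part of the patch):

from tensorflow.contrib import rnn
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell

# cell_fn is invoked as cell_fn(num_units, cell_input_size) and must
# return an rnn.RNNCell instance; anything else now raises TypeError.
def gru_cell_fn(num_units, input_size):
  return rnn.GRUCell(num_units)  # GRU: state size equals num_units

# tied=False (the default): two independent GRU instances, one per dim.
grid = grid_rnn_cell.GridRNNCell(num_units=64, num_dims=2,
                                 cell_fn=gru_cell_fn)

# A cell_fn that returns a non-RNNCell trips the new check:
#   grid_rnn_cell.GridRNNCell(num_units=64, num_dims=2,
#                             cell_fn=lambda n, i: object())
#   -> TypeError: cell_fn must return an RNNCell instance, saw: ...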