Diffstat (limited to 'tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py')
-rw-r--r--  tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py  497
1 file changed, 341 insertions(+), 156 deletions(-)
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 269b224581..252788140f 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -25,6 +25,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
+
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib import layers
from tensorflow.contrib import rnn
@@ -53,7 +55,9 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_dims=None,
tied=False,
cell_fn=None,
- non_recurrent_fn=None):
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
"""Initialize the parameters of a Grid RNN cell
Args:
@@ -68,26 +72,47 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_dims: int or list, List of dimensions that are not
recurrent.
The transfer function for non-recurrent dimensions is specified
- via `non_recurrent_fn`,
- which is default to be `tensorflow.nn.relu`.
+ via `non_recurrent_fn`, which
+ defaults to `tensorflow.nn.relu`.
tied: bool, Whether to share the weights among the dimensions of this
GridRNN cell.
If there are non-recurrent dimensions in the grid, weights are
- shared between each
- group of recurrent and non-recurrent dimensions.
- cell_fn: function, a function which returns the recurrent cell object. Has
- to be in the following signature:
- def cell_func(num_units, input_size):
+ shared between each group of recurrent and non-recurrent
+ dimensions.
+ cell_fn: function, a function which returns the recurrent cell object.
+ Has to have the following signature:
+ ```
+ def cell_func(num_units):
# ...
-
+ ```
and returns an object of type `RNNCell`. If None, LSTMCell with
default parameters will be used.
+ Note that if you use a custom RNNCell (with `cell_fn`), it is your
+ responsibility to make sure the inner cell uses `state_is_tuple=True`.
+
non_recurrent_fn: a tensorflow Op that will be the transfer function of
the non-recurrent dimensions
+ state_is_tuple: If True, accepted and returned states are tuples of the
+ states of the recurrent dimensions. If False, they are concatenated
+ along the column axis. The latter behavior will soon be deprecated.
+
+ Note that if you use a custom RNNCell (with `cell_fn`), it is your
+ responsibility to make sure the inner cell uses `state_is_tuple=True`.
+
+ output_is_tuple: If True, the output is a tuple of the outputs of the
+ recurrent dimensions. If False, they are concatenated along the
+ column axis. The latter behavior will soon be deprecated.
Raises:
TypeError: if cell_fn does not return an RNNCell instance.
"""
+ if not state_is_tuple:
+ logging.warning('%s: Using a concatenated state is slower and will '
+ 'soon be deprecated. Use state_is_tuple=True.', self)
+ if not output_is_tuple:
+ logging.warning('%s: Using a concatenated output is slower and will '
+ 'soon be deprecated. Use output_is_tuple=True.', self)
+
if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims))
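
With this change `cell_fn` takes only `num_units`, and the grid defaults to tuple states and outputs. A minimal construction sketch, assuming the `tensorflow.contrib` layout of this branch (the inner cell and sizes are illustrative, not part of the patch):

```python
from tensorflow.contrib import rnn
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell

def my_cell_fn(num_units):
  # The inner cell must itself use state_is_tuple=True when the grid does.
  return rnn.LSTMCell(num_units=num_units, use_peepholes=True, state_is_tuple=True)

cell = grid_rnn_cell.GridRNNCell(
    num_units=64,
    num_dims=2,
    input_dims=0,
    output_dims=0,
    priority_dims=0,
    cell_fn=my_cell_fn,
    state_is_tuple=True,
    output_is_tuple=True)
```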
@@ -96,37 +121,41 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_fn or nn.relu, tied,
num_units)
- cell_input_size = (self._config.num_dims - 1) * num_units
+ self._state_is_tuple = state_is_tuple
+ self._output_is_tuple = output_is_tuple
+
if cell_fn is None:
my_cell_fn = functools.partial(
- rnn.LSTMCell,
- num_units=num_units, input_size=cell_input_size,
- state_is_tuple=False)
+ rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
else:
- my_cell_fn = lambda: cell_fn(num_units, cell_input_size)
+ my_cell_fn = lambda: cell_fn(num_units)
if tied:
self._cells = [my_cell_fn()] * num_dims
else:
self._cells = [my_cell_fn() for _ in range(num_dims)]
if not isinstance(self._cells[0], rnn.RNNCell):
- raise TypeError(
- 'cell_fn must return an RNNCell instance, saw: %s'
- % type(self._cells[0]))
+ raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
+ type(self._cells[0]))
- @property
- def input_size(self):
- # temporarily using num_units as the input_size of each dimension.
- # The actual input size only determined when this cell get invoked,
- # so this information can be considered unreliable.
- return self._config.num_units * len(self._config.inputs)
+ if self._output_is_tuple:
+ self._output_size = tuple(self._cells[0].output_size
+ for _ in self._config.outputs)
+ else:
+ self._output_size = self._cells[0].output_size * len(self._config.outputs)
+
+ if self._state_is_tuple:
+ self._state_size = tuple(self._cells[0].state_size
+ for _ in self._config.recurrents)
+ else:
+ self._state_size = self._cell_state_size() * len(self._config.recurrents)
@property
def output_size(self):
- return self._cells[0].output_size * len(self._config.outputs)
+ return self._output_size
@property
def state_size(self):
- return self._cells[0].state_size * len(self._config.recurrents)
+ return self._state_size
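
For example, with the tuple defaults a `Grid2LSTMCell` (two recurrent dimensions, one output dimension) now reports per-dimension sizes rather than a single flat integer. A rough illustration, not taken from the patch:

```python
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell

cell = grid_rnn_cell.Grid2LSTMCell(num_units=8)
# One entry per recurrent dimension, each the inner cell's state size,
# e.g. (LSTMStateTuple(c=8, h=8), LSTMStateTuple(c=8, h=8)).
print(cell.state_size)
# One entry per output dimension, e.g. (8,).
print(cell.output_size)
```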
def __call__(self, inputs, state, scope=None):
"""Run one step of GridRNN.
@@ -145,76 +174,148 @@ class GridRNNCell(rnn.RNNCell):
- A 2D, batch x state_size, Tensor representing the new state of the cell
after reading "inputs" when previous state was "state".
"""
- state_sz = state.get_shape().as_list()[1]
- if self.state_size != state_sz:
- raise ValueError(
- 'Actual state size not same as specified: {} vs {}.'.format(
- state_sz, self.state_size))
-
conf = self._config
- dtype = inputs.dtype if inputs is not None else state.dtype
+ dtype = inputs.dtype
- # c_prev is `m`, and m_prev is `h` in the paper.
- # Keep c and m here for consistency with the codebase
- c_prev = [None] * self._config.num_dims
- m_prev = [None] * self._config.num_dims
- cell_output_size = self._cells[0].state_size - conf.num_units
-
- # for LSTM : state = memory cell + output, hence cell_output_size > 0
- # for GRU/RNN: state = output (whose size is equal to _num_units),
- # hence cell_output_size = 0
- for recurrent_dim, start_idx in zip(self._config.recurrents, range(
- 0, self.state_size, self._cells[0].state_size)):
- if cell_output_size > 0:
- c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
- [-1, conf.num_units])
- m_prev[recurrent_dim] = array_ops.slice(
- state, [0, start_idx + conf.num_units], [-1, cell_output_size])
- else:
- m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
- [-1, conf.num_units])
+ c_prev, m_prev, cell_output_size = self._extract_states(state)
new_output = [None] * conf.num_dims
new_state = [None] * conf.num_dims
with vs.variable_scope(scope or type(self).__name__): # GridRNNCell
+ # project input, populate c_prev and m_prev
+ self._project_input(inputs, c_prev, m_prev, cell_output_size > 0)
- # project input
- if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
- conf.inputs) > 0:
- input_splits = array_ops.split(
- value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
- input_sz = input_splits[0].get_shape().as_list()[1]
-
- for i, j in enumerate(conf.inputs):
- input_project_m = vs.get_variable(
- 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
- m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
-
- if cell_output_size > 0:
- input_project_c = vs.get_variable(
- 'project_c_{}'.format(j), [input_sz, conf.num_units],
- dtype=dtype)
- c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
-
+ # propagate along dimensions, first for non-priority dimensions
+ # then priority dimensions
_propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
new_output, new_state, True)
_propagate(conf.priority, conf, self._cells,
c_prev, m_prev, new_output, new_state, False)
+ # collect outputs and states
output_tensors = [new_output[i] for i in self._config.outputs]
- output = array_ops.zeros(
- [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(
- output_tensors, 1)
+ if self._output_is_tuple:
+ output = tuple(output_tensors)
+ else:
+ if output_tensors:
+ output = array_ops.concat(output_tensors, 1)
+ else:
+ output = array_ops.zeros([0, 0], dtype)
- state_tensors = [new_state[i] for i in self._config.recurrents]
- states = array_ops.zeros(
- [0, 0],
- dtype) if len(state_tensors) == 0 else array_ops.concat(state_tensors,
- 1)
+ if self._state_is_tuple:
+ states = tuple(new_state[i] for i in self._config.recurrents)
+ else:
+ # concat each state first, then flatten the whole thing
+ state_tensors = [
+ x for i in self._config.recurrents for x in new_state[i]
+ ]
+ if state_tensors:
+ states = array_ops.concat(state_tensors, 1)
+ else:
+ states = array_ops.zeros([0, 0], dtype)
return output, states
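
A small end-to-end sketch of one step through `__call__` with the tuple-based interface (placeholder shape and names are illustrative only):

```python
import tensorflow as tf
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell

cell = grid_rnn_cell.Grid2LSTMCell(num_units=16)
x = tf.placeholder(tf.float32, [None, 10])           # single input dimension
state = cell.zero_state(tf.shape(x)[0], tf.float32)  # tuple of per-dimension states
output, new_state = cell(x, state)                    # tuple output under output_is_tuple=True
```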
+ def _extract_states(self, state):
+ """Extract the cell and previous output tensors from the given state.
+
+ Args:
+ state: The RNN state.
+
+ Returns:
+ Tuple of the cell value, previous output, and cell_output_size.
+
+ Raises:
+ ValueError: If len(self._config.recurrents) != len(state).
+ """
+ conf = self._config
+
+ # c_prev is `m` (cell value), and
+ # m_prev is `h` (previous output) in the paper.
+ # Keeping c and m here for consistency with the codebase
+ c_prev = [None] * conf.num_dims
+ m_prev = [None] * conf.num_dims
+
+ # for LSTM : state = memory cell + output, hence cell_output_size > 0
+ # for GRU/RNN: state = output (whose size is equal to _num_units),
+ # hence cell_output_size = 0
+ total_cell_state_size = self._cell_state_size()
+ cell_output_size = total_cell_state_size - conf.num_units
+
+ if self._state_is_tuple:
+ if len(conf.recurrents) != len(state):
+ raise ValueError('Expected state as a tuple of {} '
+ 'elements'.format(len(conf.recurrents)))
+
+ for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
+ if cell_output_size > 0:
+ c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state
+ else:
+ m_prev[recurrent_dim] = recurrent_state
+ else:
+ for recurrent_dim, start_idx in zip(conf.recurrents,
+ range(0, self.state_size,
+ total_cell_state_size)):
+ if cell_output_size > 0:
+ c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
+ m_prev[recurrent_dim] = array_ops.slice(
+ state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+ else:
+ m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
+ return c_prev, m_prev, cell_output_size
+
+ def _project_input(self, inputs, c_prev, m_prev, with_c):
+ """Fills in c_prev and m_prev with projected input, for input dimensions.
+
+ Args:
+ inputs: inputs tensor
+ c_prev: cell value
+ m_prev: previous output
+ with_c: boolean; whether to include project_c.
+
+ Raises:
+ ValueError: if len(self._config.inputs) != len(inputs)
+ """
+ conf = self._config
+
+ if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
+ conf.inputs):
+ if isinstance(inputs, tuple):
+ if len(conf.inputs) != len(inputs):
+ raise ValueError('Expected inputs as a tuple of {} '
+ 'tensors'.format(len(conf.inputs)))
+ input_splits = inputs
+ else:
+ input_splits = array_ops.split(
+ value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
+ input_sz = input_splits[0].get_shape().with_rank(2)[1].value
+
+ for i, j in enumerate(conf.inputs):
+ input_project_m = vs.get_variable(
+ 'project_m_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
+ m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
+
+ if with_c:
+ input_project_c = vs.get_variable(
+ 'project_c_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
+ c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
+
+ def _cell_state_size(self):
+ """Total size of the state of the inner cell used in this grid.
+
+ Returns:
+ Total size of the state of the inner cell.
+ """
+ state_sizes = self._cells[0].state_size
+ if isinstance(state_sizes, tuple):
+ return sum(state_sizes)
+ return state_sizes
+
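
The arithmetic that `_extract_states` relies on can be spelled out as a tiny standalone sketch (illustration only, not part of the patch):

```python
def cell_state_size(state_size):
  # Mirrors GridRNNCell._cell_state_size: sum tuple state sizes, pass ints through.
  return sum(state_size) if isinstance(state_size, tuple) else state_size

num_units = 8
# LSTM: state = (c, h) -> cell_output_size == num_units > 0.
assert cell_state_size((num_units, num_units)) - num_units == num_units
# GRU/RNN: state = output only -> cell_output_size == 0.
assert cell_state_size(num_units) - num_units == 0
```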
"""Specialized cells, for convenience
"""
@@ -223,11 +324,17 @@ class GridRNNCell(rnn.RNNCell):
class Grid1BasicRNNCell(GridRNNCell):
"""1D BasicRNN cell"""
- def __init__(self, num_units):
+ def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
super(Grid1BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicRNNCell(GridRNNCell):
@@ -240,71 +347,112 @@ class Grid2BasicRNNCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1BasicLSTMCell(GridRNNCell):
- """1D BasicLSTM cell"""
+ """1D BasicLSTM cell."""
- def __init__(self, num_units, forget_bias=1):
+ def __init__(self,
+ num_units,
+ forget_bias=1,
+ state_is_tuple=True,
+ output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid1BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn.BasicLSTMCell(
- num_units=n,
- forget_bias=forget_bias, input_size=i,
- state_is_tuple=False))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicLSTMCell(GridRNNCell):
- """2D BasicLSTM cell
+ """2D BasicLSTM cell.
- This creates a 2D cell which receives input and gives output in the first
- dimension.
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
- specified.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
def __init__(self,
num_units,
tied=False,
non_recurrent_fn=None,
- forget_bias=1):
+ forget_bias=1,
+ state_is_tuple=True,
+ output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid2BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.BasicLSTMCell(
- num_units=n, forget_bias=forget_bias, input_size=i,
- state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1LSTMCell(GridRNNCell):
- """1D LSTM cell
+ """1D LSTM cell.
- This is different from Grid1BasicLSTMCell because it gives options to
- specify the forget bias and enabling peepholes
+ This is different from Grid1BasicLSTMCell because it gives options to
+ specify the forget bias and enable peepholes.
"""
- def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid1LSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, use_peepholes=use_peepholes,
- forget_bias=forget_bias, state_is_tuple=False))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2LSTMCell(GridRNNCell):
- """2D LSTM cell
+ """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -317,19 +465,30 @@ class Grid2LSTMCell(GridRNNCell):
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
- forget_bias=1.0):
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid2LSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes, state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid3LSTMCell(GridRNNCell):
- """3D BasicLSTM cell
+ """3D BasicLSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -343,19 +502,30 @@ class Grid3LSTMCell(GridRNNCell):
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
- forget_bias=1.0):
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid3LSTMCell, self).__init__(
- num_units=num_units, num_dims=3,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=3,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes, state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2GRUCell(GridRNNCell):
- """2D LSTM cell
+ """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -363,21 +533,31 @@ class Grid2GRUCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2GRUCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=lambda n: rnn.GRUCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
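
All of the specialized cells above simply forward the new keyword arguments to `GridRNNCell`; a hedged usage sketch (argument values are illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.grid_rnn.python.ops import grid_rnn_cell

# 2D GRU grid whose first dimension is non-recurrent, using a ReLU transfer function.
cell = grid_rnn_cell.Grid2GRUCell(
    num_units=32,
    non_recurrent_fn=tf.nn.relu,
    state_is_tuple=True,
    output_is_tuple=True)
```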
-"""Helpers
-"""
+# Helpers
-_GridRNNDimension = namedtuple(
- '_GridRNNDimension',
- ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+_GridRNNDimension = namedtuple('_GridRNNDimension', [
+ 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
+])
_GridRNNConfig = namedtuple('_GridRNNConfig',
['num_dims', 'dims', 'inputs', 'outputs',
@@ -387,7 +567,6 @@ _GridRNNConfig = namedtuple('_GridRNNConfig',
def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
-
def check_dim_list(ls):
if ls is None:
ls = []
@@ -412,8 +591,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
is_input=(i in input_dims),
is_output=(i in output_dims),
is_priority=(i in priority_dims),
- non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
- None))
+ non_recurrent_fn=non_recurrent_fn
+ if i in non_recurrent_dims else None))
return _GridRNNConfig(
num_dims=num_dims,
dims=rnn_dims,
@@ -440,34 +619,40 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
if conf.num_dims > 1:
ls_cell_inputs = [None] * (conf.num_dims - 1)
for d in conf.dims[:-1]:
- ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[
- d.idx] is not None else m_prev[d.idx]
+ if new_output[d.idx] is None:
+ ls_cell_inputs[d.idx] = m_prev[d.idx]
+ else:
+ ls_cell_inputs[d.idx] = new_output[d.idx]
cell_inputs = array_ops.concat(ls_cell_inputs, 1)
else:
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
m_prev[0].dtype)
- last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
+ last_dim_output = (new_output[-1]
+ if new_output[-1] is not None else m_prev[-1])
for i in dim_indices:
d = conf.dims[i]
if d.non_recurrent_fn:
- linear_args = array_ops.concat(
- [cell_inputs, last_dim_output],
- 1) if conf.num_dims > 1 else last_dim_output
+ if conf.num_dims > 1:
+ linear_args = array_ops.concat([cell_inputs, last_dim_output], 1)
+ else:
+ linear_args = last_dim_output
with vs.variable_scope('non_recurrent' if conf.tied else
'non_recurrent/cell_{}'.format(i)):
if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
- new_output[d.idx] = layers.legacy_fully_connected(
+
+ new_output[d.idx] = layers.fully_connected(
linear_args,
- num_output_units=conf.num_units,
+ num_outputs=conf.num_units,
activation_fn=d.non_recurrent_fn,
- weight_init=vs.get_variable_scope().initializer or
- layers.initializers.xavier_initializer)
+ weights_initializer=(vs.get_variable_scope().initializer or
+ layers.initializers.xavier_initializer),
+ weights_regularizer=vs.get_variable_scope().regularizer)
else:
if c_prev[i] is not None:
- cell_state = array_ops.concat([c_prev[i], last_dim_output], 1)
+ cell_state = (c_prev[i], last_dim_output)
else:
# for GRU/RNN, the state is just the previous output
cell_state = last_dim_output