diff options
author | Scott Zhu <scottzhu@google.com> | 2018-08-09 15:08:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 15:13:14 -0700 |
commit | 9e5157c9340def527ba5eeff293c5f183cdbbfc0 (patch) | |
tree | 94a7d0abb1bd1307643c343fbf008a2a862a65e5 /tensorflow/python/keras/layers/recurrent.py | |
parent | 0980e844c115bdffbbdd7d993355633c48a6100e (diff) |
Consolidate the RNN cell interface between Keras and TF RNN.
PiperOrigin-RevId: 208118371
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r-- | tensorflow/python/keras/layers/recurrent.py | 43 |
1 files changed, 30 insertions, 13 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index acc4ba37c0..66c68e2085 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -93,6 +93,13 @@ class StackedRNNCells(Layer): state_size.append(cell.state_size) return tuple(state_size) + @property + def output_size(self): + if hasattr(self.cells[-1], 'output_size'): + return self.cells[-1].output_size + else: + return self.state_size[0] + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] @@ -244,13 +251,16 @@ class RNN(Layer): cell can also take the optional argument `constants`, see section "Note on passing external constants" below. - a `state_size` attribute. This can be a single integer - (single state) in which case it is the size of the recurrent state - (which should be the same as the size of the cell output). - This can also be a list/tuple of integers (one size per state). - In this case, the first entry (`state_size[0]`) should be the same - as the size of the cell output. + (single state) in which case it is the size of the recurrent + state. This can also be a list/tuple of integers (one size per + state). The `state_size` can also be TensorShape or tuple/list of TensorShape, to represent high dimension state. + - a `output_size` attribute. This can be a single integer or a + TensorShape, which represent the shape of the output. For backward + compatible reason, if this attribute is not available for the + cell, the value will be inferred by the first element of the + `state_size`. In the case that `cell` is a list of RNN cell instances, the cells will be stacked on after the other in the RNN, implementing an efficient stacked RNN. @@ -289,13 +299,13 @@ class RNN(Layer): Output shape: - if `return_state`: a list of tensors. The first tensor is the output. The remaining tensors are the last states, - each with shape `(batch_size, ...)`, where `...` is in the shape of - `state_size`. + each with shape `(batch_size, state_size)`, where `state_size` could + be a high dimension tensor shape. - if `return_sequences`: N-D tensor with shape - `(batch_size, timesteps, ...)`, where `...` is in the shape of output - size. - - else, N-D tensor with shape `(batch_size, ...)`, where `...` is in the - shape of output size. + `(batch_size, timesteps, output_size)`, where `output_size` could + be a high dimension tensor shape. + - else, N-D tensor with shape `(batch_size, output_size)`, where + `output_size` could be a high dimension tensor shape. # Masking This layer supports masking for input data with a variable number @@ -442,8 +452,12 @@ class RNN(Layer): state_size = self.cell.state_size else: state_size = [self.cell.state_size] - # Note that state_size[0] could be a tensor_shape or int. - output_dim = tensor_shape.as_shape(state_size[0]).as_list() + + if hasattr(self.cell, 'output_size'): + output_dim = tensor_shape.as_shape(self.cell.output_size).as_list() + else: + # Note that state_size[0] could be a tensor_shape or int. + output_dim = tensor_shape.as_shape(state_size[0]).as_list() if self.return_sequences: output_shape = tuple([input_shape[0], input_shape[1]] + output_dim) @@ -893,6 +907,7 @@ class SimpleRNNCell(Layer): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1296,6 +1311,7 @@ class GRUCell(Layer): self.implementation = implementation self.reset_after = reset_after self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1841,6 +1857,7 @@ class LSTMCell(Layer): self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.implementation = implementation self.state_size = (self.units, self.units) + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None |