path: root/tensorflow/python/keras/layers/recurrent.py
author    Scott Zhu <scottzhu@google.com>  2018-08-09 15:08:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-09 15:13:14 -0700
commit    9e5157c9340def527ba5eeff293c5f183cdbbfc0 (patch)
tree      94a7d0abb1bd1307643c343fbf008a2a862a65e5 /tensorflow/python/keras/layers/recurrent.py
parent    0980e844c115bdffbbdd7d993355633c48a6100e (diff)
Consolidate the RNN cell interface between Keras and TF RNN.
PiperOrigin-RevId: 208118371
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r--  tensorflow/python/keras/layers/recurrent.py | 43
1 file changed, 30 insertions(+), 13 deletions(-)
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index acc4ba37c0..66c68e2085 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -93,6 +93,13 @@ class StackedRNNCells(Layer):
      state_size.append(cell.state_size)
    return tuple(state_size)

+  @property
+  def output_size(self):
+    if hasattr(self.cells[-1], 'output_size'):
+      return self.cells[-1].output_size
+    else:
+      return self.state_size[0]
+
  def call(self, inputs, states, constants=None, **kwargs):
    # Recover per-cell states.
    nested_states = []
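A minimal sketch of the new property in use, assuming the public tf.keras
API of this release (the snippet is illustrative and not part of the patch):

    import tensorflow as tf

    # Stack two built-in cells; after this change both define `output_size`.
    cells = [tf.keras.layers.SimpleRNNCell(32), tf.keras.layers.LSTMCell(64)]
    stacked = tf.keras.layers.StackedRNNCells(cells)

    # The stacked output comes from the last cell, so its size wins.
    print(stacked.output_size)  # 64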
@@ -244,13 +251,16 @@ class RNN(Layer):
      cell can also take the optional argument `constants`, see
      section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
-       (single state) in which case it is the size of the recurrent state
-       (which should be the same as the size of the cell output).
-       This can also be a list/tuple of integers (one size per state).
-       In this case, the first entry (`state_size[0]`) should be the same
-       as the size of the cell output.
+       (single state) in which case it is the size of the recurrent
+       state. This can also be a list/tuple of integers (one size per
+       state).
        The `state_size` can also be a TensorShape or a tuple/list of
        TensorShapes, to represent a high-dimensional state.
+     - an `output_size` attribute. This can be a single integer or a
+       TensorShape, which represents the shape of the output. For
+       backwards compatibility, if this attribute is not available for
+       the cell, the value will be inferred from the first element of
+       `state_size`.
    In the case that `cell` is a list of RNN cell instances, the cells
    will be stacked one after the other in the RNN, implementing an
    efficient stacked RNN.
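To make the documented contract concrete, here is a hypothetical minimal
cell (class name and weights are illustrative, not from the patch) that
provides `call`, `state_size`, and the new optional `output_size`:

    import tensorflow as tf
    from tensorflow.keras import backend as K

    class MinimalRNNCell(tf.keras.layers.Layer):
      """Sketch of a cell whose output equals its single recurrent state."""

      def __init__(self, units, **kwargs):
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units
        self.state_size = units   # single integer: one recurrent state
        self.output_size = units  # optional; inferred from state_size[0] if absent

      def build(self, input_shape):
        self.kernel = self.add_weight(
            name='kernel', shape=(input_shape[-1], self.units),
            initializer='uniform')
        self.recurrent_kernel = self.add_weight(
            name='recurrent_kernel', shape=(self.units, self.units),
            initializer='uniform')
        self.built = True

      def call(self, inputs, states):
        prev_output = states[0]
        output = (K.dot(inputs, self.kernel) +
                  K.dot(prev_output, self.recurrent_kernel))
        return output, [output]

Such a cell can then be wrapped directly, e.g. `tf.keras.layers.RNN(MinimalRNNCell(32))`.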
@@ -289,13 +299,13 @@ class RNN(Layer):
    Output shape:
        - if `return_state`: a list of tensors. The first tensor is
          the output. The remaining tensors are the last states,
-         each with shape `(batch_size, ...)`, where `...` is in the shape of
-         `state_size`.
+         each with shape `(batch_size, state_size)`, where `state_size`
+         could be a high-dimensional tensor shape.
        - if `return_sequences`: N-D tensor with shape
-         `(batch_size, timesteps, ...)`, where `...` is in the shape of output
-         size.
-       - else, N-D tensor with shape `(batch_size, ...)`, where `...` is in the
-         shape of output size.
+         `(batch_size, timesteps, output_size)`, where `output_size`
+         could be a high-dimensional tensor shape.
+       - else, N-D tensor with shape `(batch_size, output_size)`, where
+         `output_size` could be a high-dimensional tensor shape.

    # Masking
    This layer supports masking for input data with a variable number
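A short usage sketch of these shapes (batch, timestep, and unit values are
arbitrary; assumes the public tf.keras API):

    import numpy as np
    import tensorflow as tf

    x = np.zeros((8, 5, 16), dtype=np.float32)  # (batch, timesteps, features)

    # return_sequences: one output per timestep.
    seq = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(32),
                              return_sequences=True)
    print(seq(x).shape)  # (8, 5, 32) == (batch_size, timesteps, output_size)

    # return_state: the output followed by the last states (h and c here).
    st = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(32), return_state=True)
    output, state_h, state_c = st(x)
    print(output.shape, state_h.shape, state_c.shape)  # (8, 32) each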
@@ -442,8 +452,12 @@ class RNN(Layer):
      state_size = self.cell.state_size
    else:
      state_size = [self.cell.state_size]
-    # Note that state_size[0] could be a tensor_shape or int.
-    output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+
+    if hasattr(self.cell, 'output_size'):
+      output_dim = tensor_shape.as_shape(self.cell.output_size).as_list()
+    else:
+      # Note that state_size[0] could be a tensor_shape or int.
+      output_dim = tensor_shape.as_shape(state_size[0]).as_list()

    if self.return_sequences:
      output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
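Both branches funnel through `tensor_shape.as_shape`, which normalizes an
int, a list, or a TensorShape into one form; a quick illustration (assumed
usage, not from the patch):

    from tensorflow.python.framework import tensor_shape

    print(tensor_shape.as_shape(32).as_list())      # [32]
    print(tensor_shape.as_shape([4, 8]).as_list())  # [4, 8]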
@@ -893,6 +907,7 @@ class SimpleRNNCell(Layer):
    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.state_size = self.units
+   self.output_size = self.units
    self._dropout_mask = None
    self._recurrent_dropout_mask = None
@@ -1296,6 +1311,7 @@ class GRUCell(Layer):
    self.implementation = implementation
    self.reset_after = reset_after
    self.state_size = self.units
+   self.output_size = self.units
    self._dropout_mask = None
    self._recurrent_dropout_mask = None
@@ -1841,6 +1857,7 @@ class LSTMCell(Layer):
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.implementation = implementation
    self.state_size = (self.units, self.units)
+   self.output_size = self.units
    self._dropout_mask = None
    self._recurrent_dropout_mask = None
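Taken together, a quick check of what the built-in cells now report (the
units value is arbitrary): only LSTMCell carries two states, and its output
size matches the hidden state h rather than the full state tuple.

    import tensorflow as tf

    cell = tf.keras.layers.LSTMCell(32)
    print(cell.state_size)   # (32, 32): one entry each for h and c
    print(cell.output_size)  # 32: the cell output is just h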