diff options
author | Scott Zhu <scottzhu@google.com> | 2018-08-09 15:08:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 15:13:14 -0700 |
commit | 9e5157c9340def527ba5eeff293c5f183cdbbfc0 (patch) | |
tree | 94a7d0abb1bd1307643c343fbf008a2a862a65e5 | |
parent | 0980e844c115bdffbbdd7d993355633c48a6100e (diff) |
Consolidate the RNN cell interface between Keras and TF RNN.
PiperOrigin-RevId: 208118371
3 files changed, 75 insertions, 13 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index acc4ba37c0..66c68e2085 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -93,6 +93,13 @@ class StackedRNNCells(Layer): state_size.append(cell.state_size) return tuple(state_size) + @property + def output_size(self): + if hasattr(self.cells[-1], 'output_size'): + return self.cells[-1].output_size + else: + return self.state_size[0] + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] @@ -244,13 +251,16 @@ class RNN(Layer): cell can also take the optional argument `constants`, see section "Note on passing external constants" below. - a `state_size` attribute. This can be a single integer - (single state) in which case it is the size of the recurrent state - (which should be the same as the size of the cell output). - This can also be a list/tuple of integers (one size per state). - In this case, the first entry (`state_size[0]`) should be the same - as the size of the cell output. + (single state) in which case it is the size of the recurrent + state. This can also be a list/tuple of integers (one size per + state). The `state_size` can also be TensorShape or tuple/list of TensorShape, to represent high dimension state. + - an `output_size` attribute. This can be a single integer or a + TensorShape, which represents the shape of the output. For backward + compatibility reasons, if this attribute is not available for the + cell, the value will be inferred from the first element of the + `state_size`. In the case that `cell` is a list of RNN cell instances, the cells will be stacked on after the other in the RNN, implementing an efficient stacked RNN. @@ -289,13 +299,13 @@ class RNN(Layer): Output shape: - if `return_state`: a list of tensors. The first tensor is the output. 
The remaining tensors are the last states, - each with shape `(batch_size, ...)`, where `...` is in the shape of - `state_size`. + each with shape `(batch_size, state_size)`, where `state_size` could + be a high-dimensional tensor shape. - if `return_sequences`: N-D tensor with shape - `(batch_size, timesteps, ...)`, where `...` is in the shape of output - size. - - else, N-D tensor with shape `(batch_size, ...)`, where `...` is in the - shape of output size. + `(batch_size, timesteps, output_size)`, where `output_size` could + be a high-dimensional tensor shape. + - else, N-D tensor with shape `(batch_size, output_size)`, where + `output_size` could be a high-dimensional tensor shape. # Masking This layer supports masking for input data with a variable number @@ -442,8 +452,12 @@ class RNN(Layer): state_size = self.cell.state_size else: state_size = [self.cell.state_size] - # Note that state_size[0] could be a tensor_shape or int. - output_dim = tensor_shape.as_shape(state_size[0]).as_list() + + if hasattr(self.cell, 'output_size'): + output_dim = tensor_shape.as_shape(self.cell.output_size).as_list() + else: + # Note that state_size[0] could be a tensor_shape or int. 
+ output_dim = tensor_shape.as_shape(state_size[0]).as_list() if self.return_sequences: output_shape = tuple([input_shape[0], input_shape[1]] + output_dim) @@ -893,6 +907,7 @@ class SimpleRNNCell(Layer): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1296,6 +1311,7 @@ class GRUCell(Layer): self.implementation = implementation self.reset_after = reset_after self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1841,6 +1857,7 @@ class LSTMCell(Layer): self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.implementation = implementation self.state_size = (self.units, self.units) + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index 9be439ea14..13bd070528 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -654,6 +654,30 @@ class RNNTest(test.TestCase): 'however `cell.state_size` is'): layer(x, initial_state=s) + def test_inconsistent_output_state_size(self): + with self.test_session(): + batch = 32 + time_step = 4 + state_size = 5 + input_size = 6 + cell = PlusOneRNNCell(state_size) + x = keras.Input((None, input_size)) + layer = keras.layers.RNN(cell) + y = layer(x) + + self.assertEqual(cell.state_size, state_size) + init_state = layer.get_initial_state(x) + self.assertEqual(len(init_state), 1) + self.assertEqual(init_state[0].get_shape().as_list(), + [None, state_size]) + + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + np.zeros((batch, time_step, input_size)), + np.zeros((batch, input_size))) + 
self.assertEqual(model.output_shape, (None, input_size)) + class Minimal2DRNNCell(keras.layers.Layer): """The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell. @@ -666,6 +690,7 @@ class Minimal2DRNNCell(keras.layers.Layer): self.unit_a = unit_a self.unit_b = unit_b self.state_size = tensor_shape.as_shape([unit_a, unit_b]) + self.output_size = tensor_shape.as_shape([unit_a, unit_b]) super(Minimal2DRNNCell, self).__init__(**kwargs) def build(self, input_shape): @@ -692,5 +717,21 @@ class Minimal2DRNNCell(keras.layers.Layer): return output, [output] +class PlusOneRNNCell(keras.layers.Layer): + """Add one to the input and state. + + This cell is used for testing state_size and output_size.""" + + def __init__(self, num_unit, **kwargs): + self.state_size = num_unit + super(PlusOneRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.output_size = input_shape[-1] + + def call(self, inputs, states): + return inputs + 1, [states[0] + 1] + + if __name__ == '__main__': test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 1160d2840f..6718e36dc6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -61,6 +61,10 @@ tf_class { mtype: "<type \'property\'>" } member { + name: "output_size" + mtype: "<type \'property\'>" + } + member { name: "state_size" mtype: "<type \'property\'>" } |