aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-08-09 15:08:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 15:13:14 -0700
commit9e5157c9340def527ba5eeff293c5f183cdbbfc0 (patch)
tree94a7d0abb1bd1307643c343fbf008a2a862a65e5
parent0980e844c115bdffbbdd7d993355633c48a6100e (diff)
Consolidate the RNN cell interface between Keras and TF RNN.
PiperOrigin-RevId: 208118371
-rw-r--r--tensorflow/python/keras/layers/recurrent.py43
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py41
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt4
3 files changed, 75 insertions, 13 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index acc4ba37c0..66c68e2085 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -93,6 +93,13 @@ class StackedRNNCells(Layer):
state_size.append(cell.state_size)
return tuple(state_size)
+ @property
+ def output_size(self):
+ if hasattr(self.cells[-1], 'output_size'):
+ return self.cells[-1].output_size
+ else:
+ return self.state_size[0]
+
def call(self, inputs, states, constants=None, **kwargs):
# Recover per-cell states.
nested_states = []
@@ -244,13 +251,16 @@ class RNN(Layer):
cell can also take the optional argument `constants`, see
section "Note on passing external constants" below.
- a `state_size` attribute. This can be a single integer
- (single state) in which case it is the size of the recurrent state
- (which should be the same as the size of the cell output).
- This can also be a list/tuple of integers (one size per state).
- In this case, the first entry (`state_size[0]`) should be the same
- as the size of the cell output.
+ (single state) in which case it is the size of the recurrent
+ state. This can also be a list/tuple of integers (one size per
+ state).
The `state_size` can also be TensorShape or tuple/list of
TensorShape, to represent high dimension state.
+ - a `output_size` attribute. This can be a single integer or a
+ TensorShape, which represent the shape of the output. For backward
+ compatible reason, if this attribute is not available for the
+ cell, the value will be inferred by the first element of the
+ `state_size`.
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked on after the other in the RNN, implementing an
efficient stacked RNN.
@@ -289,13 +299,13 @@ class RNN(Layer):
Output shape:
- if `return_state`: a list of tensors. The first tensor is
the output. The remaining tensors are the last states,
- each with shape `(batch_size, ...)`, where `...` is in the shape of
- `state_size`.
+ each with shape `(batch_size, state_size)`, where `state_size` could
+ be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
- `(batch_size, timesteps, ...)`, where `...` is in the shape of output
- size.
- - else, N-D tensor with shape `(batch_size, ...)`, where `...` is in the
- shape of output size.
+ `(batch_size, timesteps, output_size)`, where `output_size` could
+ be a high dimension tensor shape.
+ - else, N-D tensor with shape `(batch_size, output_size)`, where
+ `output_size` could be a high dimension tensor shape.
# Masking
This layer supports masking for input data with a variable number
@@ -442,8 +452,12 @@ class RNN(Layer):
state_size = self.cell.state_size
else:
state_size = [self.cell.state_size]
- # Note that state_size[0] could be a tensor_shape or int.
- output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+
+ if hasattr(self.cell, 'output_size'):
+ output_dim = tensor_shape.as_shape(self.cell.output_size).as_list()
+ else:
+ # Note that state_size[0] could be a tensor_shape or int.
+ output_dim = tensor_shape.as_shape(state_size[0]).as_list()
if self.return_sequences:
output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
@@ -893,6 +907,7 @@ class SimpleRNNCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_size = self.units
+ self.output_size = self.units
self._dropout_mask = None
self._recurrent_dropout_mask = None
@@ -1296,6 +1311,7 @@ class GRUCell(Layer):
self.implementation = implementation
self.reset_after = reset_after
self.state_size = self.units
+ self.output_size = self.units
self._dropout_mask = None
self._recurrent_dropout_mask = None
@@ -1841,6 +1857,7 @@ class LSTMCell(Layer):
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.implementation = implementation
self.state_size = (self.units, self.units)
+ self.output_size = self.units
self._dropout_mask = None
self._recurrent_dropout_mask = None
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 9be439ea14..13bd070528 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -654,6 +654,30 @@ class RNNTest(test.TestCase):
'however `cell.state_size` is'):
layer(x, initial_state=s)
+ def test_inconsistent_output_state_size(self):
+ with self.test_session():
+ batch = 32
+ time_step = 4
+ state_size = 5
+ input_size = 6
+ cell = PlusOneRNNCell(state_size)
+ x = keras.Input((None, input_size))
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+
+ self.assertEqual(cell.state_size, state_size)
+ init_state = layer.get_initial_state(x)
+ self.assertEqual(len(init_state), 1)
+ self.assertEqual(init_state[0].get_shape().as_list(),
+ [None, state_size])
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, input_size)),
+ np.zeros((batch, input_size)))
+ self.assertEqual(model.output_shape, (None, input_size))
+
class Minimal2DRNNCell(keras.layers.Layer):
"""The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.
@@ -666,6 +690,7 @@ class Minimal2DRNNCell(keras.layers.Layer):
self.unit_a = unit_a
self.unit_b = unit_b
self.state_size = tensor_shape.as_shape([unit_a, unit_b])
+ self.output_size = tensor_shape.as_shape([unit_a, unit_b])
super(Minimal2DRNNCell, self).__init__(**kwargs)
def build(self, input_shape):
@@ -692,5 +717,21 @@ class Minimal2DRNNCell(keras.layers.Layer):
return output, [output]
+class PlusOneRNNCell(keras.layers.Layer):
+ """Add one to the input and state.
+
+ This cell is used for testing state_size and output_size."""
+
+ def __init__(self, num_unit, **kwargs):
+ self.state_size = num_unit
+ super(PlusOneRNNCell, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ self.output_size = input_shape[-1]
+
+ def call(self, inputs, states):
+ return inputs + 1, [states[0] + 1]
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 1160d2840f..6718e36dc6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -61,6 +61,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "output_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "state_size"
mtype: "<type \'property\'>"
}