Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r-- tensorflow/python/keras/layers/recurrent.py | 65
1 file changed, 49 insertions(+), 16 deletions(-)
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index ba7498e7e6..b07ec71178 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -336,9 +336,18 @@ class RNN(Layer):
in your model, you would need to specify the input length
at the level of the first layer
(e.g. via the `input_shape` argument)
+ time_major: The shape format of the `inputs` and `outputs` tensors.
+ If True, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
default this layer accepts input and emits output in batch-major
+ form.
Input shape:
- N-D tensor with shape `(batch_size, timesteps, ...)`.
+ N-D tensor with shape `(batch_size, timesteps, ...)` or
`(timesteps, batch_size, ...)` when `time_major` is True.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -347,7 +356,8 @@ class RNN(Layer):
be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
`(batch_size, timesteps, output_size)`, where `output_size` could
- be a high dimension tensor shape.
+ be a high dimension tensor shape, or
+ `(timesteps, batch_size, output_size)` when `time_major` is True.
- else, N-D tensor with shape `(batch_size, output_size)`, where
`output_size` could be a high dimension tensor shape.
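
To make the new flag concrete, here is a minimal usage sketch (assuming the
public `tf.keras.layers.RNN` wrapper that this file backs; shapes illustrative):

    import numpy as np
    import tensorflow as tf

    # Time-major input: (timesteps, batch, features) rather than
    # (batch, timesteps, features).
    x = np.random.rand(10, 4, 8).astype('float32')

    cell = tf.keras.layers.SimpleRNNCell(16)
    layer = tf.keras.layers.RNN(cell, return_sequences=True, time_major=True)

    y = layer(x)
    # With return_sequences=True and time_major=True, the output stays
    # time-major: (timesteps, batch, units) == (10, 4, 16).
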
@@ -448,6 +458,7 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
+ time_major=False,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
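
As the constructor shows, a list of cells is wrapped in `StackedRNNCells`, so a
single `time_major` flag governs the whole stack, e.g.:

    import tensorflow as tf

    cells = [tf.keras.layers.SimpleRNNCell(32), tf.keras.layers.SimpleRNNCell(16)]
    stacked = tf.keras.layers.RNN(cells, time_major=True)  # one flag for the stack
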
@@ -468,6 +479,7 @@ class RNN(Layer):
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
+ self.time_major = time_major
self.supports_masking = True
self.input_spec = [None] # The input shape is unknown yet, at least rank 3.
@@ -503,14 +515,21 @@ class RNN(Layer):
# Note that state_size[0] could be a tensor_shape or int.
output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+ batch = input_shape[0]
+ time_step = input_shape[1]
+ if self.time_major:
+ batch, time_step = time_step, batch
if self.return_sequences:
- output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
+ if self.time_major:
+ output_shape = tuple([time_step, batch] + output_dim)
+ else:
+ output_shape = tuple([batch, time_step] + output_dim)
else:
- output_shape = tuple([input_shape[0]] + output_dim)
+ output_shape = tuple([batch] + output_dim)
if self.return_state:
state_shape = [
- tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list())
+ tuple([batch] + tensor_shape.as_shape(dim).as_list())
for dim in state_size
]
return [output_shape] + state_shape
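
The shape rule implemented above can be stated standalone; this sketch
(function name hypothetical) mirrors the hunk's logic:

    def rnn_output_shape(input_shape, output_dim, time_major, return_sequences):
        # input_shape is (batch, time, ...), or (time, batch, ...) if time_major.
        batch, time_step = input_shape[0], input_shape[1]
        if time_major:
            batch, time_step = time_step, batch
        if return_sequences:
            # The leading two dims keep the input's ordering.
            leading = (time_step, batch) if time_major else (batch, time_step)
            return leading + tuple(output_dim)
        return (batch,) + tuple(output_dim)

    assert rnn_output_shape((5, 10, 8), [16], True, True) == (5, 10, 16)
    assert rnn_output_shape((10, 5, 8), [16], False, False) == (10, 16)
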
@@ -539,13 +558,18 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- batch_size = input_shape[0] if self.stateful else None
- input_dim = input_shape[2:]
- self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim)
+ input_spec_shape = list(input_shape)
+ batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
+ if not self.stateful:
+ input_spec_shape[batch_index] = None
+ input_spec_shape[time_step_index] = None
+ self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape))
+ batch = input_shape[batch_index]
+ input_dim = input_shape[2:]
+ step_input_shape = (batch,) + input_dim
# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
- step_input_shape = (input_shape[0],) + input_dim
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
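
Net effect of the build() changes on the input spec: the time axis is always
relaxed to None, and the batch axis is relaxed too unless the layer is
stateful. A standalone mirror of that rule (function name hypothetical):

    def rnn_input_spec_shape(input_shape, time_major, stateful):
        spec = list(input_shape)
        batch_idx, time_idx = (1, 0) if time_major else (0, 1)
        if not stateful:
            spec[batch_idx] = None  # batch is only pinned for stateful layers
        spec[time_idx] = None       # time is never pinned
        return tuple(spec)

    assert rnn_input_spec_shape((10, 4, 8), time_major=True, stateful=True) == (None, 4, 8)
    assert rnn_input_spec_shape((4, 10, 8), time_major=False, stateful=False) == (None, None, 8)
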
@@ -598,12 +622,16 @@ class RNN(Layer):
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+
+ input_shape = array_ops.shape(inputs)
+ batch_size = input_shape[1] if self.time_major else input_shape[0]
+ dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
- inputs=inputs, batch_size=None, dtype=None)
+ inputs=None, batch_size=batch_size, dtype=dtype)
else:
- init_state = _generate_zero_filled_state(
- array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
+ dtype)
# Keras RNN expect the states in a list, even if it's a single state tensor.
if not nest.is_sequence(init_state):
init_state = [init_state]
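
The key change in get_initial_state is reading the batch dimension from the
correct axis before building zero-filled states. A self-contained sketch of the
same idea in plain TF ops (helper name hypothetical; integer state size assumed):

    import tensorflow as tf

    def zero_state(inputs, state_size, time_major):
        # Batch lives on axis 1 for time-major inputs, axis 0 otherwise.
        batch = tf.shape(inputs)[1] if time_major else tf.shape(inputs)[0]
        return tf.zeros([batch, state_size], dtype=inputs.dtype)
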
@@ -696,7 +724,7 @@ class RNN(Layer):
'Layer has ' + str(len(self.states)) + ' states but was passed ' +
str(len(initial_state)) + ' initial states.')
input_shape = K.int_shape(inputs)
- timesteps = input_shape[1]
+ timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps in [None, 1]:
raise ValueError('Cannot unroll a RNN if the '
'time dimension is undefined or equal to 1. \n'
@@ -747,7 +775,8 @@ class RNN(Layer):
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
- input_length=timesteps)
+ input_length=timesteps,
+ time_major=self.time_major)
if self.stateful:
updates = []
for i in range(len(states)):
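
A quick way to sanity-check the plumbing into K.rnn: a time-major run should
match transposing to batch-major, running with `time_major=False`, and
transposing back (sketch; both layers share one cell, so the weights match):

    import numpy as np
    import tensorflow as tf

    cell = tf.keras.layers.SimpleRNNCell(16)
    tm = tf.keras.layers.RNN(cell, return_sequences=True, time_major=True)
    bm = tf.keras.layers.RNN(cell, return_sequences=True, time_major=False)

    x = np.random.rand(10, 4, 8).astype('float32')  # (time, batch, features)
    y_tm = tm(x)                                     # (time, batch, units)
    y_bm = tf.transpose(bm(tf.transpose(x, [1, 0, 2])), [1, 0, 2])
    # Same recurrence either way; time_major=True just skips the two transposes.
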
@@ -777,7 +806,10 @@ class RNN(Layer):
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
+ if self.time_major:
+ batch_size = self.input_spec[0].shape[1]
+ else:
+ batch_size = self.input_spec[0].shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
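
For stateful time-major layers this means the fixed batch size sits at axis 1
of `batch_input_shape` (sketch; values illustrative):

    import tensorflow as tf

    layer = tf.keras.layers.RNN(
        tf.keras.layers.SimpleRNNCell(16),
        stateful=True, time_major=True,
        batch_input_shape=(10, 4, 8))   # (timesteps, batch=4, features)
    model = tf.keras.Sequential([layer])
    model.reset_states()                # batch size is read from axis 1
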
@@ -839,7 +871,8 @@ class RNN(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll
+ 'unroll': self.unroll,
+ 'time_major': self.time_major
}
if self._num_constants is not None:
config['num_constants'] = self._num_constants
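
With `time_major` included in the config, the flag should survive a
serialization round trip (sketch):

    import tensorflow as tf

    layer = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(8), time_major=True)
    config = layer.get_config()
    assert config['time_major'] is True

    clone = tf.keras.layers.RNN.from_config(config)
    assert clone.time_major
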