diff options
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r-- | tensorflow/python/keras/layers/recurrent.py | 65 |
1 files changed, 49 insertions, 16 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index ba7498e7e6..b07ec71178 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -336,9 +336,18 @@ class RNN(Layer): in your model, you would need to specify the input length at the level of the first layer (e.g. via the `input_shape` argument) + time_major: The shape format of the `inputs` and `outputs` tensors. + If True, the inputs and outputs will be in shape + `(timesteps, batch, ...)`, whereas in the False case, it will be + `(batch, timesteps, ...)`. Using `time_major = True` is a bit more + efficient because it avoids transposes at the beginning and end of the + RNN calculation. However, most TensorFlow data is batch-major, so by + default this function accepts input and emits output in batch-major + form. Input shape: - N-D tensor with shape `(batch_size, timesteps, ...)`. + N-D tensor with shape `(batch_size, timesteps, ...)` or + `(timesteps, batch_size, ...)` when time_major is True. Output shape: - if `return_state`: a list of tensors. The first tensor is @@ -347,7 +356,8 @@ class RNN(Layer): be a high dimension tensor shape. - if `return_sequences`: N-D tensor with shape `(batch_size, timesteps, output_size)`, where `output_size` could - be a high dimension tensor shape. + be a high dimension tensor shape, or + `(timesteps, batch_size, output_size)` when `time_major` is True. - else, N-D tensor with shape `(batch_size, output_size)`, where `output_size` could be a high dimension tensor shape. @@ -448,6 +458,7 @@ class RNN(Layer): go_backwards=False, stateful=False, unroll=False, + time_major=False, **kwargs): if isinstance(cell, (list, tuple)): cell = StackedRNNCells(cell) @@ -468,6 +479,7 @@ class RNN(Layer): self.go_backwards = go_backwards self.stateful = stateful self.unroll = unroll + self.time_major = time_major self.supports_masking = True self.input_spec = [None] # The input shape is unknown yet, at least rank 3. @@ -503,14 +515,21 @@ class RNN(Layer): # Note that state_size[0] could be a tensor_shape or int. output_dim = tensor_shape.as_shape(state_size[0]).as_list() + batch = input_shape[0] + time_step = input_shape[1] + if self.time_major: + batch, time_step = time_step, batch if self.return_sequences: - output_shape = tuple([input_shape[0], input_shape[1]] + output_dim) + if self.time_major: + output_shape = tuple([time_step, batch] + output_dim) + else: + output_shape = tuple([batch, time_step] + output_dim) else: - output_shape = tuple([input_shape[0]] + output_dim) + output_shape = tuple([batch] + output_dim) if self.return_state: state_shape = [ - tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list()) + tuple([batch] + tensor_shape.as_shape(dim).as_list()) for dim in state_size ] return [output_shape] + state_shape @@ -539,13 +558,18 @@ class RNN(Layer): if isinstance(input_shape, list): input_shape = input_shape[0] - batch_size = input_shape[0] if self.stateful else None - input_dim = input_shape[2:] - self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim) + input_spec_shape = list(input_shape) + batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) + if not self.stateful: + input_spec_shape[batch_index] = None + input_spec_shape[time_step_index] = None + self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape)) + batch = input_shape[batch_index] + input_dim = input_shape[2:] + step_input_shape = (batch,) + input_dim # allow cell (if layer) to build before we set or validate state_spec if isinstance(self.cell, Layer): - step_input_shape = (input_shape[0],) + input_dim if constants_shape is not None: self.cell.build([step_input_shape] + constants_shape) else: @@ -598,12 +622,16 @@ class RNN(Layer): def get_initial_state(self, inputs): get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) + + input_shape = array_ops.shape(inputs) + batch_size = input_shape[1] if self.time_major else input_shape[0] + dtype = inputs.dtype if get_initial_state_fn: init_state = get_initial_state_fn( - inputs=inputs, batch_size=None, dtype=None) + inputs=None, batch_size=batch_size, dtype=dtype) else: - init_state = _generate_zero_filled_state( - array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype) + init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, + dtype) # Keras RNN expect the states in a list, even if it's a single state tensor. if not nest.is_sequence(init_state): init_state = [init_state] @@ -696,7 +724,7 @@ class RNN(Layer): 'Layer has ' + str(len(self.states)) + ' states but was passed ' + str(len(initial_state)) + ' initial states.') input_shape = K.int_shape(inputs) - timesteps = input_shape[1] + timesteps = input_shape[0] if self.time_major else input_shape[1] if self.unroll and timesteps in [None, 1]: raise ValueError('Cannot unroll a RNN if the ' 'time dimension is undefined or equal to 1. \n' @@ -747,7 +775,8 @@ class RNN(Layer): go_backwards=self.go_backwards, mask=mask, unroll=self.unroll, - input_length=timesteps) + input_length=timesteps, + time_major=self.time_major) if self.stateful: updates = [] for i in range(len(states)): @@ -777,7 +806,10 @@ class RNN(Layer): def reset_states(self, states=None): if not self.stateful: raise AttributeError('Layer must be stateful.') - batch_size = self.input_spec[0].shape[0] + if self.time_major: + batch_size = self.input_spec[0].shape[1] + else: + batch_size = self.input_spec[0].shape[0] if not batch_size: raise ValueError('If a RNN is stateful, it needs to know ' 'its batch size. Specify the batch size ' @@ -839,7 +871,8 @@ class RNN(Layer): 'return_state': self.return_state, 'go_backwards': self.go_backwards, 'stateful': self.stateful, - 'unroll': self.unroll + 'unroll': self.unroll, + 'time_major': self.time_major } if self._num_constants is not None: config['num_constants'] = self._num_constants |