author | Scott Zhu <scottzhu@google.com> | 2018-10-02 16:27:57 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 16:32:19 -0700
commit | 6663959a8a2dd93a4dab9b049767d64761a00adc (patch) |
tree | bbc84022e57498347247647be27fe19d82118282 |
parent | 7c0c0abab5b07528bae982d69257ebf4a8c077cb (diff) |
Update Keras RNN layer to support time major input.
PiperOrigin-RevId: 215479788
9 files changed, 207 insertions, 32 deletions
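A minimal usage sketch of the new flag, following the pattern the commit's own tests use below: transpose batch-major data into time-major layout with a `Lambda` layer, run the layer with `time_major=True`, and transpose back. The dimension names (`batch`, `timesteps`, `features`, `units`) are illustrative, not from the commit.

    import numpy as np
    from tensorflow.python import keras
    from tensorflow.python.ops import array_ops

    batch, timesteps, features, units = 8, 5, 4, 3

    x = keras.Input((timesteps, features))
    # To (timesteps, batch, features); with time_major=True the RNN skips its
    # internal transposes at the beginning and end of the calculation.
    t = keras.layers.Lambda(lambda v: array_ops.transpose(v, [1, 0, 2]))(x)
    y = keras.layers.SimpleRNN(units, time_major=True, return_sequences=True)(t)
    # Back to batch-major for the loss.
    y = keras.layers.Lambda(lambda v: array_ops.transpose(v, [1, 0, 2]))(y)

    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        np.zeros((batch, timesteps, features)),
        np.zeros((batch, timesteps, units)))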
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 584facc859..0d6877e4a1 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3058,7 +3058,8 @@ def rnn(step_function,
         mask=None,
         constants=None,
         unroll=False,
-        input_length=None):
+        input_length=None,
+        time_major=False):
   """Iterates over the time dimension of a tensor.
 
   Arguments:
@@ -3087,6 +3088,13 @@
       constants: List of constant values passed at each step.
       unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
       input_length: If specified, assume time dimension is of this length.
+      time_major: Boolean. If true, the inputs and outputs will be in shape
+          `(timesteps, batch, ...)`, whereas in the False case, it will be
+          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+          efficient because it avoids transposes at the beginning and end of the
+          RNN calculation. However, most TensorFlow data is batch-major, so by
+          default this function accepts input and emits output in batch-major
+          form.
 
   Returns:
       A tuple, `(last_output, outputs, new_states)`.
@@ -3108,15 +3116,17 @@
   if ndim < 3:
     raise ValueError('Input should be at least 3D.')
   inputs_shape = inputs.shape
-  axes = [1, 0] + list(range(2, ndim))
-  inputs = array_ops.transpose(inputs, (axes))
+  if not time_major:
+    axes = [1, 0] + list(range(2, ndim))
+    inputs = array_ops.transpose(inputs, axes)
 
   if mask is not None:
     if mask.dtype != dtypes_module.bool:
       mask = math_ops.cast(mask, dtypes_module.bool)
     if len(mask.shape) == ndim - 1:
       mask = expand_dims(mask)
-    mask = array_ops.transpose(mask, axes)
+    if not time_major:
+      mask = array_ops.transpose(mask, axes)
 
   if constants is None:
     constants = []
@@ -3297,10 +3307,11 @@
     outputs = output_ta.stack()
     last_output = output_ta.read(last_time - 1)
 
-  axes = [1, 0] + list(range(2, len(outputs.shape)))
-  outputs = array_ops.transpose(outputs, axes)
+  if not time_major:
+    axes = [1, 0] + list(range(2, len(outputs.shape)))
+    outputs = array_ops.transpose(outputs, axes)
 
-  # Static shape inference: (samples, time, ...)
+  # Static shape inference: (samples, time, ...) or (time, sample, ...)
   outputs_shape = outputs.shape.as_list()
   outputs_shape[0] = inputs_shape[0]
   outputs_shape[1] = inputs_shape[1]
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index cf2b0c476c..29a09a3d71 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -47,6 +47,9 @@ class _CuDNNRNN(RNN):
     stateful: Boolean (default False). If True, the last state
         for each sample at index i in a batch will be used as initial
         state for the sample of index i in the following batch.
+    time_major: Boolean (default False). If true, the inputs and outputs will be
+        in shape `(timesteps, batch, ...)`, whereas in the False case, it will
+        be `(batch, timesteps, ...)`.
   """
 
   def __init__(self,
@@ -54,6 +57,7 @@ class _CuDNNRNN(RNN):
                return_state=False,
                go_backwards=False,
                stateful=False,
+               time_major=False,
                **kwargs):
     # We invoke the base layer's initializer directly here because we do not
     # want to create RNN cell instance.
@@ -62,6 +66,7 @@ class _CuDNNRNN(RNN):
     self.return_state = return_state
     self.go_backwards = go_backwards
     self.stateful = stateful
+    self.time_major = time_major
     self.supports_masking = False
     self.input_spec = [InputSpec(ndim=3)]
     if hasattr(self.cell.state_size, '__len__'):
@@ -124,7 +129,8 @@ class _CuDNNRNN(RNN):
         'return_sequences': self.return_sequences,
         'return_state': self.return_state,
         'go_backwards': self.go_backwards,
-        'stateful': self.stateful
+        'stateful': self.stateful,
+        'time_major': self.time_major,
     }
     base_config = super(  # pylint: disable=bad-super-call
         RNN, self).get_config()
@@ -267,7 +273,8 @@ class CuDNNGRU(_CuDNNRNN):
     self.built = True
 
   def _process_batch(self, inputs, initial_state):
-    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+    if not self.time_major:
+      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
     input_h = initial_state[0]
     input_h = array_ops.expand_dims(input_h, axis=0)
 
@@ -301,7 +308,10 @@ class CuDNNGRU(_CuDNNRNN):
     if self.stateful or self.return_state:
       h = h[0]
     if self.return_sequences:
-      output = array_ops.transpose(outputs, perm=(1, 0, 2))
+      if self.time_major:
+        output = outputs
+      else:
+        output = array_ops.transpose(outputs, perm=(1, 0, 2))
     else:
       output = outputs[-1]
     return output, [h]
@@ -456,7 +466,8 @@ class CuDNNLSTM(_CuDNNRNN):
     self.built = True
 
   def _process_batch(self, inputs, initial_state):
-    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+    if not self.time_major:
+      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
     input_h = initial_state[0]
     input_c = initial_state[1]
     input_h = array_ops.expand_dims(input_h, axis=0)
@@ -496,7 +507,10 @@ class CuDNNLSTM(_CuDNNRNN):
       h = h[0]
       c = c[0]
     if self.return_sequences:
-      output = array_ops.transpose(outputs, perm=(1, 0, 2))
+      if self.time_major:
+        output = outputs
+      else:
+        output = array_ops.transpose(outputs, perm=(1, 0, 2))
     else:
       output = outputs[-1]
     return output, [h, c]
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 2ed0aa8f26..7becbfede1 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -26,6 +26,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -142,6 +143,32 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
       ('cudnngru', keras.layers.CuDNNGRU),
       ('cudnnlstm', keras.layers.CuDNNLSTM),
   )
+  def test_time_major_input(self, layer_class):
+    if test.is_gpu_available(cuda_only=True):
+      with self.test_session(use_gpu=True):
+        input_size = 10
+        timesteps = 6
+        units = 2
+        num_samples = 32
+
+        model = keras.models.Sequential()
+        model.add(
+            keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+        layer = layer_class(units, time_major=True, return_sequences=True)
+        model.add(layer)
+        model.add(
+            keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+        model.compile(loss='categorical_crossentropy', optimizer='adam')
+        model.fit(
+            np.ones((num_samples, timesteps, input_size)),
+            np.ones((num_samples, timesteps, units)))
+        out = model.predict(np.ones((num_samples, timesteps, input_size)))
+        self.assertEqual(out.shape, (num_samples, timesteps, units))
+
+  @parameterized.named_parameters(
+      ('cudnngru', keras.layers.CuDNNGRU),
+      ('cudnnlstm', keras.layers.CuDNNLSTM),
+  )
   def test_specify_initial_state_keras_tensor(self, layer_class):
     if test.is_gpu_available(cuda_only=True):
       with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index ba7498e7e6..b07ec71178 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -336,9 +336,18 @@ class RNN(Layer):
         in your model, you would need to specify the input length
         at the level of the first layer
         (e.g. via the `input_shape` argument)
+    time_major: The shape format of the `inputs` and `outputs` tensors.
+        If True, the inputs and outputs will be in shape
+        `(timesteps, batch, ...)`, whereas in the False case, it will be
+        `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+        efficient because it avoids transposes at the beginning and end of the
+        RNN calculation. However, most TensorFlow data is batch-major, so by
+        default this function accepts input and emits output in batch-major
+        form.
 
   Input shape:
-    N-D tensor with shape `(batch_size, timesteps, ...)`.
+    N-D tensor with shape `(batch_size, timesteps, ...)` or
+    `(timesteps, batch_size, ...)` when time_major is True.
 
   Output shape:
     - if `return_state`: a list of tensors. The first tensor is
@@ -347,7 +356,8 @@
       be a high dimension tensor shape.
     - if `return_sequences`: N-D tensor with shape
       `(batch_size, timesteps, output_size)`, where `output_size` could
-      be a high dimension tensor shape.
+      be a high dimension tensor shape, or
+      `(timesteps, batch_size, output_size)` when `time_major` is True.
     - else, N-D tensor with shape `(batch_size, output_size)`, where
       `output_size` could be a high dimension tensor shape.
 
@@ -448,6 +458,7 @@
                go_backwards=False,
                stateful=False,
                unroll=False,
+               time_major=False,
                **kwargs):
     if isinstance(cell, (list, tuple)):
       cell = StackedRNNCells(cell)
@@ -468,6 +479,7 @@
     self.go_backwards = go_backwards
     self.stateful = stateful
     self.unroll = unroll
+    self.time_major = time_major
     self.supports_masking = True
     self.input_spec = [None]
     # The input shape is unknown yet, at least rank 3.
@@ -503,14 +515,21 @@
       # Note that state_size[0] could be a tensor_shape or int.
       output_dim = tensor_shape.as_shape(state_size[0]).as_list()
 
+      batch = input_shape[0]
+      time_step = input_shape[1]
+      if self.time_major:
+        batch, time_step = time_step, batch
       if self.return_sequences:
-        output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
+        if self.time_major:
+          output_shape = tuple([time_step, batch] + output_dim)
+        else:
+          output_shape = tuple([batch, time_step] + output_dim)
       else:
-        output_shape = tuple([input_shape[0]] + output_dim)
+        output_shape = tuple([batch] + output_dim)
 
       if self.return_state:
         state_shape = [
-            tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list())
+            tuple([batch] + tensor_shape.as_shape(dim).as_list())
             for dim in state_size
         ]
         return [output_shape] + state_shape
@@ -539,13 +558,18 @@
     if isinstance(input_shape, list):
       input_shape = input_shape[0]
 
-    batch_size = input_shape[0] if self.stateful else None
-    input_dim = input_shape[2:]
-    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim)
+    input_spec_shape = list(input_shape)
+    batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
+    if not self.stateful:
+      input_spec_shape[batch_index] = None
+    input_spec_shape[time_step_index] = None
+    self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape))
+    batch = input_shape[batch_index]
+    input_dim = input_shape[2:]
+    step_input_shape = (batch,) + input_dim
 
     # allow cell (if layer) to build before we set or validate state_spec
     if isinstance(self.cell, Layer):
-      step_input_shape = (input_shape[0],) + input_dim
       if constants_shape is not None:
         self.cell.build([step_input_shape] + constants_shape)
       else:
@@ -598,12 +622,16 @@
 
   def get_initial_state(self, inputs):
     get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+
+    input_shape = array_ops.shape(inputs)
+    batch_size = input_shape[1] if self.time_major else input_shape[0]
+    dtype = inputs.dtype
     if get_initial_state_fn:
       init_state = get_initial_state_fn(
-          inputs=inputs, batch_size=None, dtype=None)
+          inputs=None, batch_size=batch_size, dtype=dtype)
     else:
-      init_state = _generate_zero_filled_state(
-          array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+      init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
+                                               dtype)
     # Keras RNN expect the states in a list, even if it's a single state tensor.
     if not nest.is_sequence(init_state):
       init_state = [init_state]
@@ -696,7 +724,7 @@
           'Layer has ' + str(len(self.states)) + ' states but was passed ' +
           str(len(initial_state)) + ' initial states.')
     input_shape = K.int_shape(inputs)
-    timesteps = input_shape[1]
+    timesteps = input_shape[0] if self.time_major else input_shape[1]
     if self.unroll and timesteps in [None, 1]:
       raise ValueError('Cannot unroll a RNN if the '
                        'time dimension is undefined or equal to 1. \n'
@@ -747,7 +775,8 @@
         go_backwards=self.go_backwards,
         mask=mask,
         unroll=self.unroll,
-        input_length=timesteps)
+        input_length=timesteps,
+        time_major=self.time_major)
     if self.stateful:
       updates = []
       for i in range(len(states)):
@@ -777,7 +806,10 @@
   def reset_states(self, states=None):
     if not self.stateful:
       raise AttributeError('Layer must be stateful.')
-    batch_size = self.input_spec[0].shape[0]
+    if self.time_major:
+      batch_size = self.input_spec[0].shape[1]
+    else:
+      batch_size = self.input_spec[0].shape[0]
     if not batch_size:
       raise ValueError('If a RNN is stateful, it needs to know '
                        'its batch size. Specify the batch size '
@@ -839,7 +871,8 @@
         'return_state': self.return_state,
         'go_backwards': self.go_backwards,
         'stateful': self.stateful,
-        'unroll': self.unroll
+        'unroll': self.unroll,
+        'time_major': self.time_major
     }
     if self._num_constants is not None:
       config['num_constants'] = self._num_constants
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index b9e90095e4..d246be6b45 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -186,6 +186,96 @@ class RNNTest(test.TestCase):
       y_np_2 = model.predict(x_np)
       self.assertAllClose(y_np, y_np_2, atol=1e-4)
 
+  def test_rnn_with_time_major(self):
+    batch = 10
+    time_step = 5
+    embedding_dim = 4
+    units = 3
+
+    with self.cached_session():
+      # Test basic case.
+      x = keras.Input((time_step, embedding_dim))
+      time_major_x = keras.layers.Lambda(
+          lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+      layer = keras.layers.SimpleRNN(
+          units, time_major=True, return_sequences=True)
+      self.assertEqual(
+          layer.compute_output_shape((time_step, None,
+                                      embedding_dim)).as_list(),
+          [time_step, None, units])
+      y = layer(time_major_x)
+      self.assertEqual(layer.output_shape, (time_step, None, units))
+
+      y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          np.zeros((batch, time_step, embedding_dim)),
+          np.zeros((batch, time_step, units)))
+
+    with self.cached_session():
+      # Test stacking.
+      x = keras.Input((time_step, embedding_dim))
+      time_major_x = keras.layers.Lambda(
+          lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+      cell_units = [10, 8, 6]
+      cells = [keras.layers.SimpleRNNCell(cell_units[i]) for i in range(3)]
+      layer = keras.layers.RNN(cells, time_major=True, return_sequences=True)
+      y = layer(time_major_x)
+      self.assertEqual(layer.output_shape, (time_step, None, cell_units[-1]))
+
+      y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          np.zeros((batch, time_step, embedding_dim)),
+          np.zeros((batch, time_step, cell_units[-1])))
+
+    with self.cached_session():
+      # Test masking.
+      x = keras.Input((time_step, embedding_dim))
+      time_major = keras.layers.Lambda(
+          lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+      mask = keras.layers.Masking()(time_major)
+      rnn = keras.layers.SimpleRNN(
+          units, time_major=True, return_sequences=True)(mask)
+      y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(rnn)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          np.zeros((batch, time_step, embedding_dim)),
+          np.zeros((batch, time_step, units)))
+
+    with self.cached_session():
+      # Test layer output
+      x = keras.Input((time_step, embedding_dim))
+      rnn_1 = keras.layers.SimpleRNN(units, return_sequences=True)
+      y = rnn_1(x)
+
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          np.zeros((batch, time_step, embedding_dim)),
+          np.zeros((batch, time_step, units)))
+
+      x_np = np.random.random((batch, time_step, embedding_dim))
+      y_np_1 = model.predict(x_np)
+
+      time_major = keras.layers.Lambda(
+          lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+      rnn_2 = keras.layers.SimpleRNN(
+          units, time_major=True, return_sequences=True)
+      y_2 = rnn_2(time_major)
+      y_2 = keras.layers.Lambda(
+          lambda t: array_ops.transpose(t, [1, 0, 2]))(y_2)
+
+      model_2 = keras.models.Model(x, y_2)
+      rnn_2.set_weights(rnn_1.get_weights())
+
+      y_np_2 = model_2.predict(x_np)
+      self.assertAllClose(y_np_1, y_np_2, atol=1e-4)
+
   def test_rnn_cell_with_constants_layer(self):
 
     class RNNCellWithConstants(keras.layers.Layer):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
   }
   member_method {
     name: "rnn"
-    argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "round"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
   }
   member_method {
     name: "rnn"
-    argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "round"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
  }
  member_method {
    name: "add_loss"
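For the backend entry point whose argspec is updated above, a minimal sketch of calling `keras.backend.rnn` directly with `time_major=True`; the identity step function and the shapes are illustrative assumptions, not taken from the commit:

    import numpy as np
    from tensorflow.python.keras import backend as K

    def step(cell_inputs, states):
      # Identity step: emit the per-timestep input unchanged and pass the
      # state list through.
      return cell_inputs, states

    # With time_major=True, inputs are laid out as (timesteps, batch, features),
    # so no transposes happen at the beginning or end of the loop.
    inputs = K.constant(np.random.random((5, 8, 4)))
    initial_states = [K.zeros((8, 4))]
    last_output, outputs, new_states = K.rnn(
        step, inputs, initial_states, time_major=True)
    # outputs keeps the time-major layout: (timesteps, batch, features).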