diff options
author | Scott Zhu <scottzhu@google.com> | 2018-10-02 16:27:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 16:32:19 -0700 |
commit | 6663959a8a2dd93a4dab9b049767d64761a00adc (patch) | |
tree | bbc84022e57498347247647be27fe19d82118282 /tensorflow/python/keras/backend.py | |
parent | 7c0c0abab5b07528bae982d69257ebf4a8c077cb (diff) |
Update Keras RNN layer to support time major input.
PiperOrigin-RevId: 215479788
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 25 |
1 file changed, 18 insertions, 7 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 584facc859..0d6877e4a1 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -3058,7 +3058,8 @@ def rnn(step_function, mask=None, constants=None, unroll=False, - input_length=None): + input_length=None, + time_major=False): """Iterates over the time dimension of a tensor. Arguments: @@ -3087,6 +3088,13 @@ def rnn(step_function, constants: List of constant values passed at each step. unroll: Whether to unroll the RNN or to use a symbolic `while_loop`. input_length: If specified, assume time dimension is of this length. + time_major: Boolean. If true, the inputs and outputs will be in shape + `(timesteps, batch, ...)`, whereas in the False case, it will be + `(batch, timesteps, ...)`. Using `time_major = True` is a bit more + efficient because it avoids transposes at the beginning and end of the + RNN calculation. However, most TensorFlow data is batch-major, so by + default this function accepts input and emits output in batch-major + form. Returns: A tuple, `(last_output, outputs, new_states)`. 
@@ -3108,15 +3116,17 @@ def rnn(step_function, if ndim < 3: raise ValueError('Input should be at least 3D.') inputs_shape = inputs.shape - axes = [1, 0] + list(range(2, ndim)) - inputs = array_ops.transpose(inputs, (axes)) + if not time_major: + axes = [1, 0] + list(range(2, ndim)) + inputs = array_ops.transpose(inputs, axes) if mask is not None: if mask.dtype != dtypes_module.bool: mask = math_ops.cast(mask, dtypes_module.bool) if len(mask.shape) == ndim - 1: mask = expand_dims(mask) - mask = array_ops.transpose(mask, axes) + if not time_major: + mask = array_ops.transpose(mask, axes) if constants is None: constants = [] @@ -3297,10 +3307,11 @@ def rnn(step_function, outputs = output_ta.stack() last_output = output_ta.read(last_time - 1) - axes = [1, 0] + list(range(2, len(outputs.shape))) - outputs = array_ops.transpose(outputs, axes) + if not time_major: + axes = [1, 0] + list(range(2, len(outputs.shape))) + outputs = array_ops.transpose(outputs, axes) - # Static shape inference: (samples, time, ...) + # Static shape inference: (samples, time, ...) or (time, sample, ...) outputs_shape = outputs.shape.as_list() outputs_shape[0] = inputs_shape[0] outputs_shape[1] = inputs_shape[1] |