aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-10-02 16:27:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 16:32:19 -0700
commit6663959a8a2dd93a4dab9b049767d64761a00adc (patch)
treebbc84022e57498347247647be27fe19d82118282 /tensorflow/python/keras/backend.py
parent7c0c0abab5b07528bae982d69257ebf4a8c077cb (diff)
Update Keras RNN layer to support time major input.
PiperOrigin-RevId: 215479788
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py25
1 file changed, 18 insertions(+), 7 deletions(-)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 584facc859..0d6877e4a1 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3058,7 +3058,8 @@ def rnn(step_function,
mask=None,
constants=None,
unroll=False,
- input_length=None):
+ input_length=None,
+ time_major=False):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -3087,6 +3088,13 @@ def rnn(step_function,
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
+ time_major: Boolean. If true, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
+ default this function accepts input and emits output in batch-major
+ form.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -3108,15 +3116,17 @@ def rnn(step_function,
if ndim < 3:
raise ValueError('Input should be at least 3D.')
inputs_shape = inputs.shape
- axes = [1, 0] + list(range(2, ndim))
- inputs = array_ops.transpose(inputs, (axes))
+ if not time_major:
+ axes = [1, 0] + list(range(2, ndim))
+ inputs = array_ops.transpose(inputs, axes)
if mask is not None:
if mask.dtype != dtypes_module.bool:
mask = math_ops.cast(mask, dtypes_module.bool)
if len(mask.shape) == ndim - 1:
mask = expand_dims(mask)
- mask = array_ops.transpose(mask, axes)
+ if not time_major:
+ mask = array_ops.transpose(mask, axes)
if constants is None:
constants = []
@@ -3297,10 +3307,11 @@ def rnn(step_function,
outputs = output_ta.stack()
last_output = output_ta.read(last_time - 1)
- axes = [1, 0] + list(range(2, len(outputs.shape)))
- outputs = array_ops.transpose(outputs, axes)
+ if not time_major:
+ axes = [1, 0] + list(range(2, len(outputs.shape)))
+ outputs = array_ops.transpose(outputs, axes)
- # Static shape inference: (samples, time, ...)
+  # Static shape inference: (samples, time, ...) or (time, samples, ...)
outputs_shape = outputs.shape.as_list()
outputs_shape[0] = inputs_shape[0]
outputs_shape[1] = inputs_shape[1]