author    Scott Zhu <scottzhu@google.com>                 2018-10-02 16:27:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-02 16:32:19 -0700
commit    6663959a8a2dd93a4dab9b049767d64761a00adc (patch)
tree      bbc84022e57498347247647be27fe19d82118282
parent    7c0c0abab5b07528bae982d69257ebf4a8c077cb (diff)
Update Keras RNN layer to support time major input.
PiperOrigin-RevId: 215479788
-rw-r--r--  tensorflow/python/keras/backend.py                                   | 25
-rw-r--r--  tensorflow/python/keras/layers/cudnn_recurrent.py                    | 24
-rw-r--r--  tensorflow/python/keras/layers/cudnn_recurrent_test.py               | 27
-rw-r--r--  tensorflow/python/keras/layers/recurrent.py                          | 65
-rw-r--r--  tensorflow/python/keras/layers/recurrent_test.py                     | 90
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt        |  2
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt  |  2
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt        |  2
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt  |  2
9 files changed, 207 insertions(+), 32 deletions(-)
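
For orientation before the per-file diffs, here is a minimal usage sketch of the new flag. It mirrors the pattern used in the tests added below; the surrounding Lambda transposes exist only to convert between batch-major model boundaries and the time-major layer:

import numpy as np
from tensorflow.python import keras
from tensorflow.python.ops import array_ops

batch, timesteps, input_dim, units = 32, 6, 10, 2

model = keras.models.Sequential()
# Transpose batch-major input (batch, time, features) to time-major.
model.add(keras.layers.Lambda(
    lambda t: array_ops.transpose(t, [1, 0, 2]),
    input_shape=(timesteps, input_dim)))
# The RNN now consumes and emits (timesteps, batch, ...) directly,
# avoiding the transposes it would otherwise insert around the loop.
model.add(
    keras.layers.SimpleRNN(units, time_major=True, return_sequences=True))
# Transpose back so downstream layers see batch-major data again.
model.add(keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
model.compile(optimizer='rmsprop', loss='mse')

out = model.predict(np.ones((batch, timesteps, input_dim)))
assert out.shape == (batch, timesteps, units)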
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 584facc859..0d6877e4a1 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3058,7 +3058,8 @@ def rnn(step_function,
mask=None,
constants=None,
unroll=False,
- input_length=None):
+ input_length=None,
+ time_major=False):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -3087,6 +3088,13 @@ def rnn(step_function,
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
+ time_major: Boolean. If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`; if False, they will be in shape
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -3108,15 +3116,17 @@ def rnn(step_function,
if ndim < 3:
raise ValueError('Input should be at least 3D.')
inputs_shape = inputs.shape
- axes = [1, 0] + list(range(2, ndim))
- inputs = array_ops.transpose(inputs, (axes))
+ if not time_major:
+ axes = [1, 0] + list(range(2, ndim))
+ inputs = array_ops.transpose(inputs, axes)
if mask is not None:
if mask.dtype != dtypes_module.bool:
mask = math_ops.cast(mask, dtypes_module.bool)
if len(mask.shape) == ndim - 1:
mask = expand_dims(mask)
- mask = array_ops.transpose(mask, axes)
+ if not time_major:
+ mask = array_ops.transpose(mask, axes)
if constants is None:
constants = []
@@ -3297,10 +3307,11 @@ def rnn(step_function,
outputs = output_ta.stack()
last_output = output_ta.read(last_time - 1)
- axes = [1, 0] + list(range(2, len(outputs.shape)))
- outputs = array_ops.transpose(outputs, axes)
+ if not time_major:
+ axes = [1, 0] + list(range(2, len(outputs.shape)))
+ outputs = array_ops.transpose(outputs, axes)
- # Static shape inference: (samples, time, ...)
+ # Static shape inference: (samples, time, ...) or (time, samples, ...)
outputs_shape = outputs.shape.as_list()
outputs_shape[0] = inputs_shape[0]
outputs_shape[1] = inputs_shape[1]
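
A minimal sketch of exercising the new backend argument directly; the running-sum step function here is purely illustrative and not part of this change:

import numpy as np
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import math_ops

timesteps, batch, features = 5, 4, 3
# Inputs are supplied already time-major: (timesteps, batch, features).
time_major_inputs = K.constant(np.ones((timesteps, batch, features)))
initial_state = K.constant(np.zeros((batch, 1)))

def step_function(inputs, states):
  # Toy step: carry a running sum of the input features as the state.
  new_state = states[0] + math_ops.reduce_sum(inputs, axis=-1, keepdims=True)
  return new_state, [new_state]

# With time_major=True, K.rnn skips the entry/exit transposes, and the
# stacked outputs come back as (timesteps, batch, 1) rather than
# (batch, timesteps, 1).
last_output, outputs, new_states = K.rnn(
    step_function,
    inputs=time_major_inputs,
    initial_states=[initial_state],
    time_major=True)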
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index cf2b0c476c..29a09a3d71 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -47,6 +47,9 @@ class _CuDNNRNN(RNN):
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
+ time_major: Boolean (default False). If True, the inputs and outputs will
be in shape `(timesteps, batch, ...)`; if False, they will be in shape
`(batch, timesteps, ...)`.
"""
def __init__(self,
@@ -54,6 +57,7 @@ class _CuDNNRNN(RNN):
return_state=False,
go_backwards=False,
stateful=False,
+ time_major=False,
**kwargs):
# We invoke the base layer's initializer directly here because we do not
# want to create RNN cell instance.
@@ -62,6 +66,7 @@ class _CuDNNRNN(RNN):
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
+ self.time_major = time_major
self.supports_masking = False
self.input_spec = [InputSpec(ndim=3)]
if hasattr(self.cell.state_size, '__len__'):
@@ -124,7 +129,8 @@ class _CuDNNRNN(RNN):
'return_sequences': self.return_sequences,
'return_state': self.return_state,
'go_backwards': self.go_backwards,
- 'stateful': self.stateful
+ 'stateful': self.stateful,
+ 'time_major': self.time_major,
}
base_config = super( # pylint: disable=bad-super-call
RNN, self).get_config()
@@ -267,7 +273,8 @@ class CuDNNGRU(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -301,7 +308,10 @@ class CuDNNGRU(_CuDNNRNN):
if self.stateful or self.return_state:
h = h[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h]
@@ -456,7 +466,8 @@ class CuDNNLSTM(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_c = initial_state[1]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -496,7 +507,10 @@ class CuDNNLSTM(_CuDNNRNN):
h = h[0]
c = c[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h, c]
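
Because `time_major` is now included in `get_config` for the CuDNN layers as well, the flag survives a config round trip. A minimal sketch (constructing the layer does not need a GPU, but actually running a CuDNN layer requires a CUDA-capable one):

from tensorflow.python import keras

layer = keras.layers.CuDNNLSTM(8, time_major=True, return_sequences=True)
config = layer.get_config()
assert config['time_major'] is True

# Rebuilding from the config preserves the time-major setting.
restored = keras.layers.CuDNNLSTM.from_config(config)
assert restored.time_major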
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 2ed0aa8f26..7becbfede1 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -142,6 +143,32 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
+ def test_time_major_input(self, layer_class):
+ if test.is_gpu_available(cuda_only=True):
+ with self.test_session(use_gpu=True):
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ layer = layer_class(units, time_major=True, return_sequences=True)
+ model.add(layer)
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.fit(
+ np.ones((num_samples, timesteps, input_size)),
+ np.ones((num_samples, timesteps, units)))
+ out = model.predict(np.ones((num_samples, timesteps, input_size)))
+ self.assertEqual(out.shape, (num_samples, timesteps, units))
+
+ @parameterized.named_parameters(
+ ('cudnngru', keras.layers.CuDNNGRU),
+ ('cudnnlstm', keras.layers.CuDNNLSTM),
+ )
def test_specify_initial_state_keras_tensor(self, layer_class):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index ba7498e7e6..b07ec71178 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -336,9 +336,18 @@ class RNN(Layer):
in your model, you would need to specify the input length
at the level of the first layer
(e.g. via the `input_shape` argument)
+ time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`; if False, they will be in shape
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this layer accepts input and emits output in batch-major
form.
Input shape:
- N-D tensor with shape `(batch_size, timesteps, ...)`.
+ N-D tensor with shape `(batch_size, timesteps, ...)` or
`(timesteps, batch_size, ...)` when `time_major` is True.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -347,7 +356,8 @@ class RNN(Layer):
be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
`(batch_size, timesteps, output_size)`, where `output_size` could
- be a high dimension tensor shape.
+ be a high dimension tensor shape, or
+ `(timesteps, batch_size, output_size)` when `time_major` is True.
- else, N-D tensor with shape `(batch_size, output_size)`, where
`output_size` could be a high dimension tensor shape.
@@ -448,6 +458,7 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
+ time_major=False,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
@@ -468,6 +479,7 @@ class RNN(Layer):
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
+ self.time_major = time_major
self.supports_masking = True
self.input_spec = [None] # The input shape is unknown yet, at least rank 3.
@@ -503,14 +515,21 @@ class RNN(Layer):
# Note that state_size[0] could be a tensor_shape or int.
output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+ batch = input_shape[0]
+ time_step = input_shape[1]
+ if self.time_major:
+ batch, time_step = time_step, batch
if self.return_sequences:
- output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
+ if self.time_major:
+ output_shape = tuple([time_step, batch] + output_dim)
+ else:
+ output_shape = tuple([batch, time_step] + output_dim)
else:
- output_shape = tuple([input_shape[0]] + output_dim)
+ output_shape = tuple([batch] + output_dim)
if self.return_state:
state_shape = [
- tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list())
+ tuple([batch] + tensor_shape.as_shape(dim).as_list())
for dim in state_size
]
return [output_shape] + state_shape
@@ -539,13 +558,18 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- batch_size = input_shape[0] if self.stateful else None
- input_dim = input_shape[2:]
- self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim)
+ input_spec_shape = list(input_shape)
+ batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
+ if not self.stateful:
+ input_spec_shape[batch_index] = None
+ input_spec_shape[time_step_index] = None
+ self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape))
+ batch = input_shape[batch_index]
+ input_dim = input_shape[2:]
+ step_input_shape = (batch,) + input_dim
# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
- step_input_shape = (input_shape[0],) + input_dim
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
@@ -598,12 +622,16 @@ class RNN(Layer):
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+
+ input_shape = array_ops.shape(inputs)
+ batch_size = input_shape[1] if self.time_major else input_shape[0]
+ dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
- inputs=inputs, batch_size=None, dtype=None)
+ inputs=None, batch_size=batch_size, dtype=dtype)
else:
- init_state = _generate_zero_filled_state(
- array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
+ dtype)
# Keras RNN expect the states in a list, even if it's a single state tensor.
if not nest.is_sequence(init_state):
init_state = [init_state]
@@ -696,7 +724,7 @@ class RNN(Layer):
'Layer has ' + str(len(self.states)) + ' states but was passed ' +
str(len(initial_state)) + ' initial states.')
input_shape = K.int_shape(inputs)
- timesteps = input_shape[1]
+ timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps in [None, 1]:
raise ValueError('Cannot unroll a RNN if the '
'time dimension is undefined or equal to 1. \n'
@@ -747,7 +775,8 @@ class RNN(Layer):
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
- input_length=timesteps)
+ input_length=timesteps,
+ time_major=self.time_major)
if self.stateful:
updates = []
for i in range(len(states)):
@@ -777,7 +806,10 @@ class RNN(Layer):
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
+ if self.time_major:
+ batch_size = self.input_spec[0].shape[1]
+ else:
+ batch_size = self.input_spec[0].shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
@@ -839,7 +871,8 @@ class RNN(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll
+ 'unroll': self.unroll,
+ 'time_major': self.time_major
}
if self._num_constants is not None:
config['num_constants'] = self._num_constants
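
With `time_major=True` the batch dimension moves to axis 1, which the shape-inference, input-spec, and stateful paths above now account for. A minimal sketch of the static shape behavior, matching the assertion in the tests below:

from tensorflow.python import keras

timesteps, input_dim, units = 5, 4, 3

layer = keras.layers.SimpleRNN(units, time_major=True, return_sequences=True)
# The input shape is (timesteps, batch, input_dim); batch is unknown here.
shape = layer.compute_output_shape((timesteps, None, input_dim))
assert shape.as_list() == [timesteps, None, units]

Correspondingly, a stateful time-major layer reads its fixed batch size from axis 1 of its input spec in `reset_states`, rather than from axis 0.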
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index b9e90095e4..d246be6b45 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -186,6 +186,96 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
+ def test_rnn_with_time_major(self):
+ batch = 10
+ time_step = 5
+ embedding_dim = 4
+ units = 3
+
+ with self.cached_session():
+ # Test basic case.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ layer = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ self.assertEqual(
+ layer.compute_output_shape((time_step, None,
+ embedding_dim)).as_list(),
+ [time_step, None, units])
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, units))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test stacking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ cell_units = [10, 8, 6]
+ cells = [keras.layers.SimpleRNNCell(cell_units[i]) for i in range(3)]
+ layer = keras.layers.RNN(cells, time_major=True, return_sequences=True)
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, cell_units[-1]))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, cell_units[-1])))
+
+ with self.cached_session():
+ # Test masking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ mask = keras.layers.Masking()(time_major)
+ rnn = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)(mask)
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(rnn)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test layer output
+ x = keras.Input((time_step, embedding_dim))
+ rnn_1 = keras.layers.SimpleRNN(units, return_sequences=True)
+ y = rnn_1(x)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ x_np = np.random.random((batch, time_step, embedding_dim))
+ y_np_1 = model.predict(x_np)
+
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ rnn_2 = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ y_2 = rnn_2(time_major)
+ y_2 = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(y_2)
+
+ model_2 = keras.models.Model(x, y_2)
+ rnn_2.set_weights(rnn_1.get_weights())
+
+ y_np_2 = model_2.predict(x_np)
+ self.assertAllClose(y_np_1, y_np_2, atol=1e-4)
+
def test_rnn_cell_with_constants_layer(self):
class RNNCellWithConstants(keras.layers.Layer):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"