aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-08-20 16:12:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 16:22:16 -0700
commitfc4504edb1ab419ae59b0ebb9ff8d943beb61117 (patch)
treefbbe6c131f46ace2a084fac3280e858800135723
parent65b9ed5a83319830db02504d4c69e98bd07665b6 (diff)
Unify RNN Cell interface between TF and Keras.
PiperOrigin-RevId: 209503416
-rw-r--r--tensorflow/python/keras/layers/recurrent.py103
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py17
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py262
-rw-r--r--tensorflow/python/ops/rnn.py21
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py34
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt4
31 files changed, 509 insertions, 32 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 12c82a53f6..65171acfb6 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -33,7 +33,6 @@ from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -95,11 +94,27 @@ class StackedRNNCells(Layer):
@property
def output_size(self):
- if hasattr(self.cells[-1], 'output_size'):
+ if getattr(self.cells[-1], 'output_size', None) is not None:
return self.cells[-1].output_size
else:
return self.state_size[0]
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ # The init state is in reverse order of cell's initial state since the
+ # state_size is in reverse order. It is flattened into a list also because
+ # the state_size is a flattened list.
+ initial_states = []
+ for cell in self.cells[::-1]:
+ get_initial_state_fn = getattr(cell, 'get_initial_state', None)
+ if get_initial_state_fn:
+ initial_states.append(get_initial_state_fn(
+ inputs=inputs, batch_size=batch_size, dtype=dtype))
+ else:
+ initial_states.append(_generate_zero_filled_state_for_cell(
+ cell, inputs, batch_size, dtype))
+
+ return nest.flatten(initial_states)
+
def call(self, inputs, states, constants=None, **kwargs):
# Recover per-cell states.
nested_states = []
@@ -261,6 +276,22 @@ class RNN(Layer):
compatible reason, if this attribute is not available for the
cell, the value will be inferred by the first element of the
`state_size`.
+      - a `get_initial_state(inputs=None, batch_size=None, dtype=None)`
+        method that creates a tensor meant to be fed to `call()` as the
+        initial state, if the user didn't specify any initial state via
+        other means. The returned initial state should have the shape
+        [batch, cell.state_size]. The cell might choose to create a
+        zero-filled tensor, or a tensor with other values, depending on
+        the cell implementation.
+        `inputs` is the input tensor to the RNN layer, which should
+        contain the batch size as its shape[0], and also dtype. Note that
+        the shape[0] might be None during the graph construction. Either
+        the `inputs` or the pair of `batch_size` and `dtype` is provided.
+        `batch_size` is a scalar tensor that represents the batch size
+        of the input. `dtype` is a `tf.dtype` that represents the dtype
+        of the input.
+        For backward compatibility reasons, if this method is not
+        implemented by the cell, the RNN layer will create a zero-filled
+        tensor with the shape of [batch, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked on after the other in the RNN, implementing an
efficient stacked RNN.
@@ -453,7 +484,7 @@ class RNN(Layer):
else:
state_size = [self.cell.state_size]
- if hasattr(self.cell, 'output_size'):
+ if getattr(self.cell, 'output_size', None) is not None:
output_dim = tensor_shape.as_shape(self.cell.output_size).as_list()
else:
# Note that state_size[0] could be a tensor_shape or int.
@@ -553,26 +584,18 @@ class RNN(Layer):
raise validation_error
def get_initial_state(self, inputs):
- # build an all-zero tensor of shape (batch, cell.state_size)
- initial_state = array_ops.zeros_like(inputs)
- # shape of initial_state = (batch, timesteps, ...)
- initial_state = math_ops.reduce_sum(
- initial_state, axis=list(range(1, len(inputs.shape))))
- # shape of initial_state = (batch,)
- if _is_multiple_state(self.cell.state_size):
- states = []
- for dims in self.cell.state_size:
- state = initial_state
- flat_dims = tensor_shape.as_shape(dims).as_list()
- # reshape the state to (batch, 1, 1, ....) and then expand each state.
- state = array_ops.reshape(state, [-1,] + [1] * len(flat_dims))
- states.append(K.tile(state, [1] + flat_dims))
- return states
+ get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+ if get_initial_state_fn:
+ init_state = get_initial_state_fn(
+ inputs=inputs, batch_size=None, dtype=None)
else:
- flat_dims = tensor_shape.as_shape(self.cell.state_size).as_list()
- initial_state = array_ops.reshape(
- initial_state, [-1] + [1] * len(flat_dims))
- return [K.tile(initial_state, [1] + flat_dims)]
+ init_state = _generate_zero_filled_state(
+ array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ # Keras RNN expect the states in a list, even if it's a single state tensor.
+ if not nest.is_sequence(init_state):
+ init_state = [init_state]
+ # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
+ return list(init_state)
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = _standardize_args(inputs,
@@ -986,6 +1009,9 @@ class SimpleRNNCell(Layer):
output._uses_learning_phase = True
return output, [output]
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
def get_config(self):
config = {
'units':
@@ -1517,6 +1543,9 @@ class GRUCell(Layer):
base_config = super(GRUCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
@tf_export('keras.layers.GRU')
class GRU(RNN):
@@ -2042,6 +2071,9 @@ class LSTMCell(Layer):
base_config = super(LSTMCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
@tf_export('keras.layers.LSTM')
class LSTM(RNN):
@@ -2354,3 +2386,30 @@ def _is_multiple_state(state_size):
"""Check whether the state_size contains multiple states."""
return (hasattr(state_size, '__len__') and
not isinstance(state_size, tensor_shape.TensorShape))
+
+
+def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
+ if inputs is not None:
+ batch_size = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
+
+
+def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
+ """Generate a zero filled tensor with shape [batch_size, state_size]."""
+ if None in [batch_size_tensor, dtype]:
+ raise ValueError(
+ 'batch_size and dtype cannot be None while constructing initial state: '
+ 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
+ if _is_multiple_state(state_size):
+ states = []
+ for dims in state_size:
+ flat_dims = tensor_shape.as_shape(dims).as_list()
+ init_state_size = [batch_size_tensor] + flat_dims
+ init_state = array_ops.zeros(init_state_size, dtype=dtype)
+ states.append(init_state)
+ return states
+ else:
+ flat_dims = tensor_shape.as_shape(state_size).as_list()
+ init_state_size = [batch_size_tensor] + flat_dims
+ return array_ops.zeros(init_state_size, dtype=dtype)
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 13bd070528..f14b36e7e1 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -678,6 +678,23 @@ class RNNTest(test.TestCase):
np.zeros((batch, input_size)))
self.assertEqual(model.output_shape, (None, input_size))
+ def test_get_initial_state(self):
+ cell = keras.layers.SimpleRNNCell(5)
+ with self.assertRaisesRegexp(ValueError,
+ 'batch_size and dtype cannot be None'):
+ cell.get_initial_state(None, None, None)
+
+ inputs = keras.Input((None, 2, 10))
+ initial_state = cell.get_initial_state(inputs, None, None)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
+ batch = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ initial_state = cell.get_initial_state(None, batch, dtype)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
class Minimal2DRNNCell(keras.layers.Layer):
"""The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index c72ada11da..c4f200a22e 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -44,11 +45,13 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
+from tensorflow.python.ops.losses import losses
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver
+from tensorflow.python.training import training
class Plus1RNNCell(rnn_cell_impl.RNNCell):
@@ -250,12 +253,44 @@ class RNNTest(test.TestCase):
self.assertAllEqual(4, state[0])
self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
+ def testCellGetInitialState(self):
+ cell = rnn_cell_impl.BasicRNNCell(5)
+ with self.assertRaisesRegexp(
+ ValueError, "batch_size and dtype cannot be None"):
+ cell.get_initial_state(None, None, None)
+
+ inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
+ with self.assertRaisesRegexp(
+ ValueError, "batch size from input tensor is different from"):
+ cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "batch size from input tensor is different from"):
+ cell.get_initial_state(
+ inputs=inputs, batch_size=constant_op.constant(50), dtype=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "dtype from input tensor is different from"):
+ cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16)
+
+ initial_state = cell.get_initial_state(
+ inputs=inputs, batch_size=None, dtype=None)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
+ batch = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ initial_state = cell.get_initial_state(None, batch, dtype)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
out_size):
cell = cell_class(out_size, dtype=dtype)
in_shape = tensor_shape.TensorShape((batch_size, in_size))
cell.build(in_shape)
- state_output = cell.zero_state(batch_size, dtype)
+ state_output = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())
@@ -278,12 +313,228 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3)
self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
+ def testRNNWithKerasSimpleRNNCell(self):
+ with self.test_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
+ def testRNNWithKerasGRUCell(self):
+ with self.test_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.GRUCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
+ def testRNNWithKerasLSTMCell(self):
+ with self.test_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.LSTMCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(len(state), 2)
+ self.assertEqual(state[0].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[0])
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), 2)
+ self.assertEqual(len(state[0]), batch)
+ self.assertEqual(len(state[1]), batch)
+
+ def testRNNWithStackKerasCell(self):
+ with self.test_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.StackedRNNCells(
+ [keras.layers.LSTMCell(2 * output_shape),
+ keras.layers.LSTMCell(output_shape)])
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(len(state), 4)
+ self.assertEqual(state[0].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[2].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[3].shape.as_list(), [None, 2 * output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[0])
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), 4)
+ for s in state:
+ self.assertEqual(len(s), batch)
+
+ def testStaticRNNWithKerasSimpleRNNCell(self):
+ with self.test_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ x_train = np.transpose(x_train, (1, 0, 2))
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+
+ inputs = [array_ops.placeholder(
+ dtypes.float32, shape=(None, input_shape))] * timestep
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.static_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(len(outputs), timestep)
+ self.assertEqual(outputs[0].shape.as_list(), [None, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ feed_dict = {i: d for i, d in zip(inputs, x_train)}
+ feed_dict[predict] = y_train
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], feed_dict)
+
+ self.assertEqual(len(outputs), timestep)
+ self.assertEqual(len(outputs[0]), batch)
+ self.assertEqual(len(state), batch)
+
+ def testKerasAndTFRNNLayerOutputComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ weights = fix_weights_generator.get_weights()
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ k_input = keras.Input(shape=(timestep, input_shape),
+ dtype=dtypes.float32)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ layer = keras.layers.RNN(cell, return_sequences=True, return_state=True)
+ keras_out = layer(k_input)
+ cell.set_weights(weights)
+ k_out, k_state = sess.run(keras_out, {k_input: x_train})
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.test_session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
basic_cell(array_ops.ones([1, 1]),
- state=basic_cell.zero_state(batch_size=1,
- dtype=dtypes.float32))
+ state=basic_cell.get_initial_state(inputs=None,
+ batch_size=1,
+ dtype=dtypes.float32))
self.evaluate([v.initializer for v in basic_cell.variables])
self.evaluate(basic_cell._bias.assign([10.] * 4))
save = saver.Saver()
@@ -293,8 +544,9 @@ class RNNTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
lstm_cell(array_ops.ones([1, 1]),
- state=lstm_cell.zero_state(batch_size=1,
- dtype=dtypes.float32))
+ state=lstm_cell.get_initial_state(inputs=None,
+ batch_size=1,
+ dtype=dtypes.float32))
self.evaluate([v.initializer for v in lstm_cell.variables])
save = saver.Saver()
save.restore(sess, save_path)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 7b6ab20975..38336b64db 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -608,7 +608,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
else:
if not dtype:
raise ValueError("If there is no initial_state, you must give a dtype.")
- state = cell.zero_state(batch_size, dtype)
+ state = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
def _assert_has_shape(x, shape):
x_shape = array_ops.shape(x)
@@ -788,6 +789,10 @@ def _dynamic_rnn_loop(cell,
input_t = tuple(ta[time.numpy()] for ta in input_ta)
input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
+ # Keras RNN cells only accept state as list, even if it's a single tensor.
+ is_keras_rnn_cell = not isinstance(cell, rnn_cell_impl.RNNCell)
+ if is_keras_rnn_cell and not nest.is_sequence(state):
+ state = [state]
call_cell = lambda: cell(input_t, state)
if sequence_length is not None:
@@ -804,6 +809,9 @@ def _dynamic_rnn_loop(cell,
else:
(output, new_state) = call_cell()
+ # Keras cells always wrap state as list, even if it's a single tensor.
+ if is_keras_rnn_cell and len(new_state) == 1:
+ new_state = new_state[0]
# Pack state if using state tuples
output = nest.flatten(output)
@@ -1286,7 +1294,8 @@ def static_rnn(cell,
if not dtype:
raise ValueError("If no initial_state is provided, "
"dtype must be specified")
- state = cell.zero_state(batch_size, dtype)
+ state = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
if sequence_length is not None: # Prepare variables
sequence_length = ops.convert_to_tensor(
@@ -1315,6 +1324,10 @@ def static_rnn(cell,
min_sequence_length = math_ops.reduce_min(sequence_length)
max_sequence_length = math_ops.reduce_max(sequence_length)
+ # Keras RNN cells only accept state as list, even if it's a single tensor.
+ is_keras_rnn_cell = not isinstance(cell, rnn_cell_impl.RNNCell)
+ if is_keras_rnn_cell and not nest.is_sequence(state):
+ state = [state]
for time, input_ in enumerate(inputs):
if time > 0:
varscope.reuse_variables()
@@ -1333,8 +1346,10 @@ def static_rnn(cell,
state_size=cell.state_size)
else:
(output, state) = call_cell()
-
outputs.append(output)
+ # Keras RNN cells only return state as list, even if it's a single tensor.
+ if is_keras_rnn_cell and len(state) == 1:
+ state = state[0]
return (outputs, state)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 85a6a2233c..dcc9c6a4d2 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -80,13 +80,13 @@ def assert_like_rnncell(cell_name, cell):
conditions = [
hasattr(cell, "output_size"),
hasattr(cell, "state_size"),
- hasattr(cell, "zero_state"),
+ hasattr(cell, "get_initial_state"),
callable(cell),
]
errors = [
"'output_size' property is missing",
"'state_size' property is missing",
- "'zero_state' method is missing",
+ "'get_initial_state' method is missing",
"is not callable"
]
@@ -266,6 +266,36 @@ class RNNCell(base_layer.Layer):
# self.add_variable() inside the call() method.
pass
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ if inputs is not None:
+ # Validate the given batch_size and dtype against inputs if provided.
+ inputs = ops.convert_to_tensor(inputs, name="inputs")
+ if batch_size is not None:
+ if tensor_util.is_tensor(batch_size):
+ static_batch_size = tensor_util.constant_value(
+ batch_size, partial=True)
+ else:
+ static_batch_size = batch_size
+ if inputs.shape[0].value != static_batch_size:
+ raise ValueError(
+ "batch size from input tensor is different from the "
+ "input param. Input tensor batch: {}, batch_size: {}".format(
+ inputs.shape[0].value, batch_size))
+
+ if dtype is not None and inputs.dtype != dtype:
+ raise ValueError(
+ "dtype from input tensor is different from the "
+ "input param. Input tensor dtype: {}, dtype: {}".format(
+ inputs.dtype, dtype))
+
+ batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ if None in [batch_size, dtype]:
+ raise ValueError(
+ "batch_size and dtype cannot be None while constructing initial "
+ "state: batch_size={}, dtype={}".format(batch_size, dtype))
+ return self.zero_state(batch_size, dtype)
+
def zero_state(self, batch_size, dtype):
"""Return zero-filled state tensor(s).
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 4ba21a25cd..3b4d703aea 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 511456e740..9fe35571b9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 9dfda96fc8..82b6d2015b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 6718e36dc6..671fc66db0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -141,6 +141,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index e606eab919..88b8f37c4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 5deb02d569..a4483fefa2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 8a63b49180..381c4975d7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index db1aae2757..912365a28b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -155,6 +155,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 32fa151a8e..a4bb3219c7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 30c6c2ce3b..715bfd5fc7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 72b40cc9f7..b66c0f89cc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index a5c2b4aefd..faeb4f3513 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -150,6 +150,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index 61d5f04b22..caa2e60080 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 4ba21a25cd..3b4d703aea 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 511456e740..9fe35571b9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 9dfda96fc8..82b6d2015b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 6718e36dc6..671fc66db0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -141,6 +141,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index e606eab919..88b8f37c4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 5deb02d569..a4483fefa2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 8a63b49180..381c4975d7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index db1aae2757..912365a28b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -155,6 +155,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 32fa151a8e..a4bb3219c7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 30c6c2ce3b..715bfd5fc7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 72b40cc9f7..b66c0f89cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index a5c2b4aefd..faeb4f3513 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -150,6 +150,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index 61d5f04b22..caa2e60080 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}