diff options
author | 2018-08-20 16:12:40 -0700 | |
---|---|---|
committer | 2018-08-20 16:22:16 -0700 | |
commit | fc4504edb1ab419ae59b0ebb9ff8d943beb61117 (patch) | |
tree | fbbe6c131f46ace2a084fac3280e858800135723 | |
parent | 65b9ed5a83319830db02504d4c69e98bd07665b6 (diff) |
Unify RNN Cell interface between TF and Keras.
PiperOrigin-RevId: 209503416
31 files changed, 509 insertions, 32 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 12c82a53f6..65171acfb6 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -33,7 +33,6 @@ from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable @@ -95,11 +94,27 @@ class StackedRNNCells(Layer): @property def output_size(self): - if hasattr(self.cells[-1], 'output_size'): + if getattr(self.cells[-1], 'output_size', None) is not None: return self.cells[-1].output_size else: return self.state_size[0] + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + # The init state is in reverse order of cell's initial state since the + # state_size is in reverse order. It is flattened into a list also because + # the state_size is a flattened list. + initial_states = [] + for cell in self.cells[::-1]: + get_initial_state_fn = getattr(cell, 'get_initial_state', None) + if get_initial_state_fn: + initial_states.append(get_initial_state_fn( + inputs=inputs, batch_size=batch_size, dtype=dtype)) + else: + initial_states.append(_generate_zero_filled_state_for_cell( + cell, inputs, batch_size, dtype)) + + return nest.flatten(initial_states) + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] @@ -261,6 +276,22 @@ class RNN(Layer): compatible reason, if this attribute is not available for the cell, the value will be inferred by the first element of the `state_size`. + - a `get_initial_state(inputs=None, batch_size=None, dtype=None)` + method that creates a tensor meant to be fed to `call()` as the + initial state, if user didn't specify any initial state via other + means. The returned initial state should be in shape of + [batch, cell.state_size]. Cell might choose to create zero filled + tensor, or with other values based on the cell implementations. + `inputs` is the input tensor to the RNN layer, which should + contain the batch size as its shape[0], and also dtype. Note that + the shape[0] might be None during the graph construction. Either + the `inputs` or the pair of `batch` and `dtype `are provided. + `batch` is a scalar tensor that represent the batch size + of the input. `dtype` is `tf.dtype` that represent the dtype of + the input. + For backward compatible reason, if this method is not implemented + by the cell, RNN layer will create a zero filled tensors with the + size of [batch, cell.state_size]. In the case that `cell` is a list of RNN cell instances, the cells will be stacked on after the other in the RNN, implementing an efficient stacked RNN. @@ -453,7 +484,7 @@ class RNN(Layer): else: state_size = [self.cell.state_size] - if hasattr(self.cell, 'output_size'): + if getattr(self.cell, 'output_size', None) is not None: output_dim = tensor_shape.as_shape(self.cell.output_size).as_list() else: # Note that state_size[0] could be a tensor_shape or int. @@ -553,26 +584,18 @@ class RNN(Layer): raise validation_error def get_initial_state(self, inputs): - # build an all-zero tensor of shape (batch, cell.state_size) - initial_state = array_ops.zeros_like(inputs) - # shape of initial_state = (batch, timesteps, ...) - initial_state = math_ops.reduce_sum( - initial_state, axis=list(range(1, len(inputs.shape)))) - # shape of initial_state = (batch,) - if _is_multiple_state(self.cell.state_size): - states = [] - for dims in self.cell.state_size: - state = initial_state - flat_dims = tensor_shape.as_shape(dims).as_list() - # reshape the state to (batch, 1, 1, ....) and then expand each state. - state = array_ops.reshape(state, [-1,] + [1] * len(flat_dims)) - states.append(K.tile(state, [1] + flat_dims)) - return states + get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) + if get_initial_state_fn: + init_state = get_initial_state_fn( + inputs=inputs, batch_size=None, dtype=None) else: - flat_dims = tensor_shape.as_shape(self.cell.state_size).as_list() - initial_state = array_ops.reshape( - initial_state, [-1] + [1] * len(flat_dims)) - return [K.tile(initial_state, [1] + flat_dims)] + init_state = _generate_zero_filled_state( + array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype) + # Keras RNN expect the states in a list, even if it's a single state tensor. + if not nest.is_sequence(init_state): + init_state = [init_state] + # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple. + return list(init_state) def __call__(self, inputs, initial_state=None, constants=None, **kwargs): inputs, initial_state, constants = _standardize_args(inputs, @@ -986,6 +1009,9 @@ class SimpleRNNCell(Layer): output._uses_learning_phase = True return output, [output] + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) + def get_config(self): config = { 'units': @@ -1517,6 +1543,9 @@ class GRUCell(Layer): base_config = super(GRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) + @tf_export('keras.layers.GRU') class GRU(RNN): @@ -2042,6 +2071,9 @@ class LSTMCell(Layer): base_config = super(LSTMCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) + @tf_export('keras.layers.LSTM') class LSTM(RNN): @@ -2354,3 +2386,30 @@ def _is_multiple_state(state_size): """Check whether the state_size contains multiple states.""" return (hasattr(state_size, '__len__') and not isinstance(state_size, tensor_shape.TensorShape)) + + +def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): + if inputs is not None: + batch_size = array_ops.shape(inputs)[0] + dtype = inputs.dtype + return _generate_zero_filled_state(batch_size, cell.state_size, dtype) + + +def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): + """Generate a zero filled tensor with shape [batch_size, state_size].""" + if None in [batch_size_tensor, dtype]: + raise ValueError( + 'batch_size and dtype cannot be None while constructing initial state: ' + 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) + if _is_multiple_state(state_size): + states = [] + for dims in state_size: + flat_dims = tensor_shape.as_shape(dims).as_list() + init_state_size = [batch_size_tensor] + flat_dims + init_state = array_ops.zeros(init_state_size, dtype=dtype) + states.append(init_state) + return states + else: + flat_dims = tensor_shape.as_shape(state_size).as_list() + init_state_size = [batch_size_tensor] + flat_dims + return array_ops.zeros(init_state_size, dtype=dtype) diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index 13bd070528..f14b36e7e1 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -678,6 +678,23 @@ class RNNTest(test.TestCase): np.zeros((batch, input_size))) self.assertEqual(model.output_shape, (None, input_size)) + def test_get_initial_state(self): + cell = keras.layers.SimpleRNNCell(5) + with self.assertRaisesRegexp(ValueError, + 'batch_size and dtype cannot be None'): + cell.get_initial_state(None, None, None) + + inputs = keras.Input((None, 2, 10)) + initial_state = cell.get_initial_state(inputs, None, None) + self.assertEqual(initial_state.shape.as_list(), [None, 5]) + self.assertEqual(initial_state.dtype, inputs.dtype) + + batch = array_ops.shape(inputs)[0] + dtype = inputs.dtype + initial_state = cell.get_initial_state(None, batch, dtype) + self.assertEqual(initial_state.shape.as_list(), [None, 5]) + self.assertEqual(initial_state.dtype, inputs.dtype) + class Minimal2DRNNCell(keras.layers.Layer): """The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell. diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index c72ada11da..c4f200a22e 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -44,11 +45,13 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_lib import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import +from tensorflow.python.ops.losses import losses import tensorflow.python.ops.nn_grad # pylint: disable=unused-import import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import from tensorflow.python.platform import test from tensorflow.python.training import saver +from tensorflow.python.training import training class Plus1RNNCell(rnn_cell_impl.RNNCell): @@ -250,12 +253,44 @@ class RNNTest(test.TestCase): self.assertAllEqual(4, state[0]) self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1]) + def testCellGetInitialState(self): + cell = rnn_cell_impl.BasicRNNCell(5) + with self.assertRaisesRegexp( + ValueError, "batch_size and dtype cannot be None"): + cell.get_initial_state(None, None, None) + + inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1)) + with self.assertRaisesRegexp( + ValueError, "batch size from input tensor is different from"): + cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None) + + with self.assertRaisesRegexp( + ValueError, "batch size from input tensor is different from"): + cell.get_initial_state( + inputs=inputs, batch_size=constant_op.constant(50), dtype=None) + + with self.assertRaisesRegexp( + ValueError, "dtype from input tensor is different from"): + cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16) + + initial_state = cell.get_initial_state( + inputs=inputs, batch_size=None, dtype=None) + self.assertEqual(initial_state.shape.as_list(), [None, 5]) + self.assertEqual(initial_state.dtype, inputs.dtype) + + batch = array_ops.shape(inputs)[0] + dtype = inputs.dtype + initial_state = cell.get_initial_state(None, batch, dtype) + self.assertEqual(initial_state.shape.as_list(), [None, 5]) + self.assertEqual(initial_state.dtype, inputs.dtype) + def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size, out_size): cell = cell_class(out_size, dtype=dtype) in_shape = tensor_shape.TensorShape((batch_size, in_size)) cell.build(in_shape) - state_output = cell.zero_state(batch_size, dtype) + state_output = cell.get_initial_state( + inputs=None, batch_size=batch_size, dtype=dtype) cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) @@ -278,12 +313,228 @@ class RNNTest(test.TestCase): self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3) self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3) + def testRNNWithKerasSimpleRNNCell(self): + with self.test_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + y_train = keras.utils.to_categorical(y_train) + cell = keras.layers.SimpleRNNCell(output_shape) + + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) + self.assertEqual(state.shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables_lib.global_variables_initializer()]) + _, outputs, state = sess.run( + [train_op, outputs, state], {inputs: x_train, predict: y_train}) + + self.assertEqual(len(outputs), batch) + self.assertEqual(len(state), batch) + + def testRNNWithKerasGRUCell(self): + with self.test_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + y_train = keras.utils.to_categorical(y_train) + cell = keras.layers.GRUCell(output_shape) + + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) + self.assertEqual(state.shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables_lib.global_variables_initializer()]) + _, outputs, state = sess.run( + [train_op, outputs, state], {inputs: x_train, predict: y_train}) + + self.assertEqual(len(outputs), batch) + self.assertEqual(len(state), batch) + + def testRNNWithKerasLSTMCell(self): + with self.test_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + y_train = keras.utils.to_categorical(y_train) + cell = keras.layers.LSTMCell(output_shape) + + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) + self.assertEqual(len(state), 2) + self.assertEqual(state[0].shape.as_list(), [None, output_shape]) + self.assertEqual(state[1].shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state[0]) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables_lib.global_variables_initializer()]) + _, outputs, state = sess.run( + [train_op, outputs, state], {inputs: x_train, predict: y_train}) + + self.assertEqual(len(outputs), batch) + self.assertEqual(len(state), 2) + self.assertEqual(len(state[0]), batch) + self.assertEqual(len(state[1]), batch) + + def testRNNWithStackKerasCell(self): + with self.test_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + y_train = keras.utils.to_categorical(y_train) + cell = keras.layers.StackedRNNCells( + [keras.layers.LSTMCell(2 * output_shape), + keras.layers.LSTMCell(output_shape)]) + + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) + self.assertEqual(len(state), 4) + self.assertEqual(state[0].shape.as_list(), [None, output_shape]) + self.assertEqual(state[1].shape.as_list(), [None, output_shape]) + self.assertEqual(state[2].shape.as_list(), [None, 2 * output_shape]) + self.assertEqual(state[3].shape.as_list(), [None, 2 * output_shape]) + loss = losses.softmax_cross_entropy(predict, state[0]) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables_lib.global_variables_initializer()]) + _, outputs, state = sess.run( + [train_op, outputs, state], {inputs: x_train, predict: y_train}) + + self.assertEqual(len(outputs), batch) + self.assertEqual(len(state), 4) + for s in state: + self.assertEqual(len(s), batch) + + def testStaticRNNWithKerasSimpleRNNCell(self): + with self.test_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + x_train = np.transpose(x_train, (1, 0, 2)) + y_train = keras.utils.to_categorical(y_train) + cell = keras.layers.SimpleRNNCell(output_shape) + + inputs = [array_ops.placeholder( + dtypes.float32, shape=(None, input_shape))] * timestep + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.static_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(len(outputs), timestep) + self.assertEqual(outputs[0].shape.as_list(), [None, output_shape]) + self.assertEqual(state.shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables_lib.global_variables_initializer()]) + feed_dict = {i: d for i, d in zip(inputs, x_train)} + feed_dict[predict] = y_train + _, outputs, state = sess.run( + [train_op, outputs, state], feed_dict) + + self.assertEqual(len(outputs), timestep) + self.assertEqual(len(outputs[0]), batch) + self.assertEqual(len(state), batch) + + def testKerasAndTFRNNLayerOutputComparison(self): + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 20 + (x_train, _), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + fix_weights_generator = keras.layers.SimpleRNNCell(output_shape) + fix_weights_generator.build((None, input_shape)) + weights = fix_weights_generator.get_weights() + + with self.test_session(graph=ops_lib.Graph()) as sess: + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + cell = keras.layers.SimpleRNNCell(output_shape) + tf_out, tf_state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + cell.set_weights(weights) + [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train}) + with self.test_session(graph=ops_lib.Graph()) as sess: + k_input = keras.Input(shape=(timestep, input_shape), + dtype=dtypes.float32) + cell = keras.layers.SimpleRNNCell(output_shape) + layer = keras.layers.RNN(cell, return_sequences=True, return_state=True) + keras_out = layer(k_input) + cell.set_weights(weights) + k_out, k_state = sess.run(keras_out, {k_input: x_train}) + self.assertAllClose(tf_out, k_out) + self.assertAllClose(tf_state, k_state) + def testBasicLSTMCellInterchangeWithLSTMCell(self): with self.test_session(graph=ops_lib.Graph()) as sess: basic_cell = rnn_cell_impl.BasicLSTMCell(1) basic_cell(array_ops.ones([1, 1]), - state=basic_cell.zero_state(batch_size=1, - dtype=dtypes.float32)) + state=basic_cell.get_initial_state(inputs=None, + batch_size=1, + dtype=dtypes.float32)) self.evaluate([v.initializer for v in basic_cell.variables]) self.evaluate(basic_cell._bias.assign([10.] * 4)) save = saver.Saver() @@ -293,8 +544,9 @@ class RNNTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell") lstm_cell(array_ops.ones([1, 1]), - state=lstm_cell.zero_state(batch_size=1, - dtype=dtypes.float32)) + state=lstm_cell.get_initial_state(inputs=None, + batch_size=1, + dtype=dtypes.float32)) self.evaluate([v.initializer for v in lstm_cell.variables]) save = saver.Saver() save.restore(sess, save_path) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 7b6ab20975..38336b64db 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -608,7 +608,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, else: if not dtype: raise ValueError("If there is no initial_state, you must give a dtype.") - state = cell.zero_state(batch_size, dtype) + state = cell.get_initial_state( + inputs=None, batch_size=batch_size, dtype=dtype) def _assert_has_shape(x, shape): x_shape = array_ops.shape(x) @@ -788,6 +789,10 @@ def _dynamic_rnn_loop(cell, input_t = tuple(ta[time.numpy()] for ta in input_ta) input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) + # Keras RNN cells only accept state as list, even if it's a single tensor. + is_keras_rnn_cell = not isinstance(cell, rnn_cell_impl.RNNCell) + if is_keras_rnn_cell and not nest.is_sequence(state): + state = [state] call_cell = lambda: cell(input_t, state) if sequence_length is not None: @@ -804,6 +809,9 @@ def _dynamic_rnn_loop(cell, else: (output, new_state) = call_cell() + # Keras cells always wrap state as list, even if it's a single tensor. + if is_keras_rnn_cell and len(new_state) == 1: + new_state = new_state[0] # Pack state if using state tuples output = nest.flatten(output) @@ -1286,7 +1294,8 @@ def static_rnn(cell, if not dtype: raise ValueError("If no initial_state is provided, " "dtype must be specified") - state = cell.zero_state(batch_size, dtype) + state = cell.get_initial_state( + inputs=None, batch_size=batch_size, dtype=dtype) if sequence_length is not None: # Prepare variables sequence_length = ops.convert_to_tensor( @@ -1315,6 +1324,10 @@ def static_rnn(cell, min_sequence_length = math_ops.reduce_min(sequence_length) max_sequence_length = math_ops.reduce_max(sequence_length) + # Keras RNN cells only accept state as list, even if it's a single tensor. + is_keras_rnn_cell = not isinstance(cell, rnn_cell_impl.RNNCell) + if is_keras_rnn_cell and not nest.is_sequence(state): + state = [state] for time, input_ in enumerate(inputs): if time > 0: varscope.reuse_variables() @@ -1333,8 +1346,10 @@ def static_rnn(cell, state_size=cell.state_size) else: (output, state) = call_cell() - outputs.append(output) + # Keras RNN cells only return state as list, even if it's a single tensor. + if is_keras_rnn_cell and len(state) == 1: + state = state[0] return (outputs, state) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 85a6a2233c..dcc9c6a4d2 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -80,13 +80,13 @@ def assert_like_rnncell(cell_name, cell): conditions = [ hasattr(cell, "output_size"), hasattr(cell, "state_size"), - hasattr(cell, "zero_state"), + hasattr(cell, "get_initial_state"), callable(cell), ] errors = [ "'output_size' property is missing", "'state_size' property is missing", - "'zero_state' method is missing", + "'get_initial_state' method is missing", "is not callable" ] @@ -266,6 +266,36 @@ class RNNCell(base_layer.Layer): # self.add_variable() inside the call() method. pass + def get_initial_state(self, inputs=None, batch_size=None, dtype=None): + if inputs is not None: + # Validate the given batch_size and dtype against inputs if provided. + inputs = ops.convert_to_tensor(inputs, name="inputs") + if batch_size is not None: + if tensor_util.is_tensor(batch_size): + static_batch_size = tensor_util.constant_value( + batch_size, partial=True) + else: + static_batch_size = batch_size + if inputs.shape[0].value != static_batch_size: + raise ValueError( + "batch size from input tensor is different from the " + "input param. Input tensor batch: {}, batch_size: {}".format( + inputs.shape[0].value, batch_size)) + + if dtype is not None and inputs.dtype != dtype: + raise ValueError( + "dtype from input tensor is different from the " + "input param. Input tensor dtype: {}, dtype: {}".format( + inputs.dtype, dtype)) + + batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] + dtype = inputs.dtype + if None in [batch_size, dtype]: + raise ValueError( + "batch_size and dtype cannot be None while constructing initial " + "state: batch_size={}, dtype={}".format(batch_size, dtype)) + return self.zero_state(batch_size, dtype) + def zero_state(self, batch_size, dtype): """Return zero-filled state tensor(s). diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 4ba21a25cd..3b4d703aea 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 511456e740..9fe35571b9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 9dfda96fc8..82b6d2015b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 6718e36dc6..671fc66db0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -141,6 +141,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index e606eab919..88b8f37c4f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 5deb02d569..a4483fefa2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 8a63b49180..381c4975d7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index db1aae2757..912365a28b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -155,6 +155,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 32fa151a8e..a4bb3219c7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 30c6c2ce3b..715bfd5fc7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 72b40cc9f7..b66c0f89cc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index a5c2b4aefd..faeb4f3513 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -150,6 +150,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 61d5f04b22..caa2e60080 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 4ba21a25cd..3b4d703aea 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 511456e740..9fe35571b9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 9dfda96fc8..82b6d2015b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -133,6 +133,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 6718e36dc6..671fc66db0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -141,6 +141,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index e606eab919..88b8f37c4f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 5deb02d569..a4483fefa2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 8a63b49180..381c4975d7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index db1aae2757..912365a28b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -155,6 +155,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 32fa151a8e..a4bb3219c7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 30c6c2ce3b..715bfd5fc7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -152,6 +152,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 72b40cc9f7..b66c0f89cc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index a5c2b4aefd..faeb4f3513 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -150,6 +150,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 61d5f04b22..caa2e60080 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -151,6 +151,10 @@ tf_class { argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { name: "get_input_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } |