aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar John Bates <jtbates@google.com>2017-03-23 12:56:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-23 14:12:56 -0700
commit 2543f3b60d4b1dd3a5184221655022836bbca2a3 (patch)
tree a95d6e6574d7f861e378447c484f390e8525d4ca
parent be8e608423d4f916c25c7df1142ffbd0e1c2aebf (diff)
Allow users to specify the RNNCell type for the StateSavingRnnEstimator.
Change: 151049836
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py | 149
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py | 27
2 files changed, 95 insertions, 81 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index a02ac4a55b..2707a347b7 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -32,6 +32,7 @@ from tensorflow.contrib.training.python.training import sequence_queueing_state_
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest
@@ -184,12 +185,59 @@ def _prepare_features_for_sqss(features, labels, mode,
return sequence_features, context_features
+def _get_state_names(cell):
+ """Gets the state names for an `RNNCell`.
+
+ Args:
+ cell: A `RNNCell` to be used in the RNN.
+
+ Returns:
+ State names in the form of a string, a list of strings, or a list of
+ string pairs, depending on the type of `cell.state_size`.
+
+ Raises:
+ TypeError: If cell.state_size is of type TensorShape.
+ """
+ state_size = cell.state_size
+ if isinstance(state_size, tensor_shape.TensorShape):
+ raise TypeError('cell.state_size of type TensorShape is not supported.')
+ if isinstance(state_size, int):
+ return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, 0)
+ if isinstance(state_size, rnn_cell.LSTMStateTuple):
+ return [
+ '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
+ '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
+ ]
+ if isinstance(state_size[0], rnn_cell.LSTMStateTuple):
+ return [[
+ '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
+ '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
+ ] for i in range(len(state_size))]
+ return [
+ '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
+ for i in range(len(state_size))]
+
+
+def _get_initial_states(cell):
+ """Gets the initial state of the `RNNCell` used in the RNN.
+
+ Args:
+ cell: A `RNNCell` to be used in the RNN.
+
+ Returns:
+ A Python dict mapping state names to the `RNNCell`'s initial state for
+ consumption by the SQSS.
+ """
+ names = nest.flatten(_get_state_names(cell))
+ values = nest.flatten(cell.zero_state(1, dtype=dtypes.float32))
+ return {n: array_ops.squeeze(v, axis=0) for [n, v] in zip(names, values)}
+
+
def _read_batch(cell,
features,
labels,
mode,
num_unroll,
- num_rnn_layers,
batch_size,
sequence_feature_columns,
context_feature_columns=None,
@@ -209,7 +257,6 @@ def _read_batch(cell,
num_unroll: Python integer, how many time steps to unroll at a time.
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
- num_rnn_layers: Python integer, number of layers in the RNN.
batch_size: Python integer, the size of the minibatch produced by the SQSS.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
@@ -230,15 +277,7 @@ def _read_batch(cell,
batch: A `NextQueuedSequenceBatch` containing batch_size `SequenceExample`
values and their saved internal states.
"""
- # Set batch_size=1 to initialize SQSS with cell's zero state.
- values = cell.zero_state(batch_size=1, dtype=dtypes.float32)
-
- # Set up stateful queue reader.
- states = {}
- state_names = _get_lstm_state_names(num_rnn_layers)
- for i in range(num_rnn_layers):
- states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0)
- states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0)
+ states = _get_initial_states(cell)
sequences, context = _prepare_features_for_sqss(
features, labels, mode, sequence_feature_columns,
@@ -340,12 +379,12 @@ def _prepare_inputs_for_rnn(sequence_features, context_features,
axis=1)
-def _get_rnn_model_fn(target_column,
+def _get_rnn_model_fn(cell_type,
+ target_column,
problem_type,
optimizer,
num_unroll,
num_units,
- num_rnn_layers,
num_threads,
queue_capacity,
batch_size,
@@ -360,6 +399,8 @@ def _get_rnn_model_fn(target_column,
"""Creates a state saving RNN model function for an `Estimator`.
Args:
+ cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
+ 'basic_rnn', 'lstm' or 'gru'.
target_column: An initialized `TargetColumn`, used to calculate prediction
and loss.
problem_type: `ProblemType.CLASSIFICATION` or
@@ -370,7 +411,6 @@ def _get_rnn_model_fn(target_column,
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
num_units: The number of units in the `RNNCell`.
- num_rnn_layers: Python integer, number of layers in the RNN.
num_threads: The Python integer number of threads enqueuing input examples
into a queue.
queue_capacity: The max capacity of the queue in number of examples.
@@ -392,7 +432,7 @@ def _get_rnn_model_fn(target_column,
effect if `optimizer` is an instance of an `Optimizer`.
gradient_clipping_norm: A float. Gradients will be clipped to this value.
dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
- If given a list, it must have length `num_rnn_layers + 1`.
+ If given a list, it must have length `len(num_units) + 1`.
name: A string that will be used to create a scope for the RNN.
seed: Fixes the random seed used for generating input keys by the SQSS.
@@ -427,7 +467,7 @@ def _get_rnn_model_fn(target_column,
dropout = (dropout_keep_probabilities
if mode == model_fn.ModeKeys.TRAIN
else None)
- cell = lstm_cell(num_units, num_rnn_layers, dropout)
+ cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
batch = _read_batch(
cell=cell,
@@ -435,7 +475,6 @@ def _get_rnn_model_fn(target_column,
labels=labels,
mode=mode,
num_unroll=num_unroll,
- num_rnn_layers=num_rnn_layers,
batch_size=batch_size,
sequence_feature_columns=sequence_feature_columns,
context_feature_columns=context_feature_columns,
@@ -448,7 +487,7 @@ def _get_rnn_model_fn(target_column,
labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY)
inputs = _prepare_inputs_for_rnn(sequence_features, context_features,
sequence_feature_columns, num_unroll)
- state_name = _get_lstm_state_names(num_rnn_layers)
+ state_name = _get_state_names(cell)
rnn_activations, final_state = construct_state_saving_rnn(
cell=cell,
inputs=inputs,
@@ -490,54 +529,17 @@ def _get_rnn_model_fn(target_column,
return _rnn_model_fn
-def _get_lstm_state_names(num_rnn_layers):
- """Returns a num_rnn_layers long list of lstm state name pairs.
-
- Args:
- num_rnn_layers: The number of layers in the RNN.
-
- Returns:
- A num_rnn_layers long list of lstm state name pairs of the form:
- ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_rnn_layers.
- """
- return [['lstm_state_c' + str(i), 'lstm_state_m' + str(i)]
- for i in range(num_rnn_layers)]
-
-
-# TODO(jtbates): Allow users to specify cell types other than LSTM.
-def lstm_cell(num_units, num_rnn_layers, dropout_keep_probabilities):
- """Constructs a `MultiRNNCell` with num_rnn_layers `BasicLSTMCell`s.
-
- Args:
- num_units: The number of units in the `RNNCell`.
- num_rnn_layers: The number of layers in the RNN.
- dropout_keep_probabilities: a list whose elements are either floats in
- `[0.0, 1.0]` or `None`. It must have length `num_rnn_layers + 1`.
-
- Returns:
- An intiialized `MultiRNNCell`.
- """
-
- cells = [
- rnn_cell.BasicLSTMCell(num_units=num_units, state_is_tuple=True)
- for _ in range(num_rnn_layers)
- ]
- if dropout_keep_probabilities:
- cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities)
- return rnn_cell.MultiRNNCell(cells)
-
-
class StateSavingRnnEstimator(estimator.Estimator):
def __init__(self,
problem_type,
- num_units,
num_unroll,
batch_size,
sequence_feature_columns,
context_feature_columns=None,
num_classes=None,
- num_rnn_layers=1,
+ num_units=None,
+ cell_type='basic_rnn',
optimizer_type='SGD',
learning_rate=0.1,
predict_probabilities=False,
@@ -555,7 +557,6 @@ class StateSavingRnnEstimator(estimator.Estimator):
Args:
problem_type: `ProblemType.CLASSIFICATION` or
`ProblemType.LINEAR_REGRESSION`.
- num_units: The size of the RNN cells.
num_unroll: Python integer, how many time steps to unroll at a time.
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
@@ -568,8 +569,12 @@ class StateSavingRnnEstimator(estimator.Estimator):
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
num_classes: The number of classes for categorization. Used only and
- required if `problem_type` is `ProblemType.CLASSIFICATION`
- num_rnn_layers: Number of RNN layers.
+ required if `problem_type` is `ProblemType.CLASSIFICATION`.
+ num_units: A list of integers indicating the number of units in the
+ `RNNCell`s in each layer. Must be specified unless `cell_type` is an
+ instance of `RNNCell`, in which case it must be left as `None`.
+ cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
+ 'basic_rnn', 'lstm' or 'gru'.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Adam', 'Ftrl', Momentum', 'RMSProp', or 'SGD'.
@@ -582,7 +587,7 @@ class StateSavingRnnEstimator(estimator.Estimator):
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
dropout_keep_probabilities: a list of dropout keep probabilities or
- `None`. If given a list, it must have length `num_rnn_layers + 1`.
+ `None`. If given a list, it must have length `len(num_units) + 1`.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -599,11 +604,18 @@ class StateSavingRnnEstimator(estimator.Estimator):
seed: Fixes the random seed used for generating input keys by the SQSS.
Raises:
+ ValueError: Both or neither of the following are true: (a) `num_units` is
+ specified and (b) `cell_type` is an instance of `RNNCell`.
ValueError: `problem_type` is not one of
`ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
`num_classes` is not specified.
"""
+ if (num_units is not None) == isinstance(cell_type, rnn_cell.RNNCell):
+ raise ValueError(
+ 'Either num_units is specified OR cell_type is an instance of '
+ 'RNNCell. Got num_units = {} and cell_type = {}.'.format(
+ num_units, cell_type))
name = 'MultiValueStateSavingRNN'
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
@@ -625,12 +637,12 @@ class StateSavingRnnEstimator(estimator.Estimator):
optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)
rnn_model_fn = _get_rnn_model_fn(
+ cell_type=cell_type,
target_column=target_column,
problem_type=problem_type,
optimizer=optimizer_type,
num_unroll=num_unroll,
num_units=num_units,
- num_rnn_layers=num_rnn_layers,
num_threads=num_threads,
queue_capacity=queue_capacity,
batch_size=batch_size,
@@ -684,8 +696,7 @@ def multi_value_rnn_regressor(num_units,
describing context features, i.e., features that apply accross all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- num_rnn_layers: Number of RNN layers. Leave this at its default value 1
- if passing a `cell_type` that is already a MultiRNNCell.
+ num_rnn_layers: Number of RNN layers.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Momentum' or 'SGD'.
@@ -713,15 +724,16 @@ def multi_value_rnn_regressor(num_units,
Returns:
An initialized `Estimator`.
"""
+ num_units = [num_units for _ in range(num_rnn_layers)]
return StateSavingRnnEstimator(
constants.ProblemType.LINEAR_REGRESSION,
- num_units,
num_unroll,
batch_size,
sequence_feature_columns,
context_feature_columns=context_feature_columns,
num_classes=None,
- num_rnn_layers=num_rnn_layers,
+ num_units=num_units,
+ cell_type='lstm',
optimizer_type=optimizer_type,
learning_rate=learning_rate,
predict_probabilities=False,
@@ -803,15 +815,16 @@ def multi_value_rnn_classifier(num_classes,
Returns:
An initialized `Estimator`.
"""
+ num_units = [num_units for _ in range(num_rnn_layers)]
return StateSavingRnnEstimator(
constants.ProblemType.CLASSIFICATION,
- num_units,
num_unroll,
batch_size,
sequence_feature_columns,
context_feature_columns=context_feature_columns,
num_classes=num_classes,
- num_rnn_layers=num_rnn_layers,
+ num_units=num_units,
+ cell_type='lstm',
optimizer_type=optimizer_type,
learning_rate=learning_rate,
predict_probabilities=predict_probabilities,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 643a7d987b..f5bd03429c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -324,22 +324,21 @@ class StateSavingRnnEstimatorTest(test.TestCase):
def _getModelFnOpsForMode(self, mode):
"""Helper for testGetRnnModelFn{Train,Eval,Infer}()."""
- num_units = 4
- num_rnn_layers = 1
+ num_units = [4]
seq_columns = [
feature_column.real_valued_column(
- 'inputs', dimension=num_units)
+ 'inputs', dimension=1)
]
features = {
'inputs': constant_op.constant([1., 2., 3.]),
}
labels = constant_op.constant([1., 0., 1.])
model_fn = ssre._get_rnn_model_fn(
+ cell_type='basic_rnn',
target_column=target_column_lib.multi_class_target(n_classes=2),
optimizer='SGD',
num_unroll=2,
num_units=num_units,
- num_rnn_layers=num_rnn_layers,
num_threads=1,
queue_capacity=10,
batch_size=1,
@@ -380,14 +379,14 @@ class StateSavingRnnEstimatorTest(test.TestCase):
def testExport(self):
input_feature_key = 'magic_input_feature_key'
batch_size = 8
- cell_size = 4
+ num_units = [4]
sequence_length = 10
num_unroll = 2
num_classes = 2
seq_columns = [
feature_column.real_valued_column(
- 'inputs', dimension=cell_size)
+ 'inputs', dimension=4)
]
def get_input_fn(mode, seed):
@@ -414,7 +413,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
def estimator_fn():
return ssre.StateSavingRnnEstimator(
constants.ProblemType.CLASSIFICATION,
- num_units=cell_size,
+ num_units=num_units,
num_unroll=num_unroll,
batch_size=batch_size,
sequence_feature_columns=seq_columns,
@@ -518,8 +517,8 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
sequence_length = 64
train_steps = 250
eval_steps = 20
- num_units = 4
num_rnn_layers = 1
+ num_units = [4] * num_rnn_layers
learning_rate = 0.3
loss_threshold = 0.035
@@ -539,14 +538,14 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
seq_columns = [
feature_column.real_valued_column(
- 'inputs', dimension=num_units)
+ 'inputs', dimension=1)
]
config = run_config.RunConfig(tf_random_seed=1234)
dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1)
sequence_estimator = ssre.StateSavingRnnEstimator(
constants.ProblemType.LINEAR_REGRESSION,
num_units=num_units,
- num_rnn_layers=num_rnn_layers,
+ cell_type='lstm',
num_unroll=num_unroll,
batch_size=batch_size,
sequence_feature_columns=seq_columns,
@@ -578,7 +577,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
sequence_length = 32
train_steps = 200
eval_steps = 20
- num_units = 4
+ num_units = [4]
learning_rate = 0.5
accuracy_threshold = 0.9
@@ -596,12 +595,13 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
seq_columns = [
feature_column.real_valued_column(
- 'inputs', dimension=num_units)
+ 'inputs', dimension=1)
]
config = run_config.RunConfig(tf_random_seed=21212)
sequence_estimator = ssre.StateSavingRnnEstimator(
constants.ProblemType.CLASSIFICATION,
num_units=num_units,
+ cell_type='lstm',
num_unroll=num_unroll,
batch_size=batch_size,
sequence_feature_columns=seq_columns,
@@ -649,7 +649,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
num_unroll = 7 # not a divisor of sequence_length
train_steps = 350
eval_steps = 30
- num_units = 4
+ num_units = [4]
learning_rate = 0.4
accuracy_threshold = 0.65
@@ -684,6 +684,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
sequence_estimator = ssre.StateSavingRnnEstimator(
constants.ProblemType.CLASSIFICATION,
num_units=num_units,
+ cell_type='basic_rnn',
num_unroll=num_unroll,
batch_size=batch_size,
sequence_feature_columns=sequence_feature_columns,