diff options
author | John Bates <jtbates@google.com> | 2017-03-23 12:56:57 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 14:12:56 -0700 |
commit | 2543f3b60d4b1dd3a5184221655022836bbca2a3 (patch) | |
tree | a95d6e6574d7f861e378447c484f390e8525d4ca | |
parent | be8e608423d4f916c25c7df1142ffbd0e1c2aebf (diff) |
Allow users to specify the RNNCell type for the StateSavingRnnEstimator.
Change: 151049836
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py | 149 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py | 27 |
2 files changed, 95 insertions, 81 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index a02ac4a55b..2707a347b7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -32,6 +32,7 @@ from tensorflow.contrib.training.python.training import sequence_queueing_state_ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.training import momentum as momentum_opt from tensorflow.python.util import nest @@ -184,12 +185,59 @@ def _prepare_features_for_sqss(features, labels, mode, return sequence_features, context_features +def _get_state_names(cell): + """Gets the state names for an `RNNCell`. + + Args: + cell: A `RNNCell` to be used in the RNN. + + Returns: + State names in the form of a string, a list of strings, or a list of + string pairs, depending on the type of `cell.state_size`. + + Raises: + TypeError: If cell.state_size is of type TensorShape. + """ + state_size = cell.state_size + if isinstance(state_size, tensor_shape.TensorShape): + raise TypeError('cell.state_size of type TensorShape is not supported.') + if isinstance(state_size, int): + return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, 0) + if isinstance(state_size, rnn_cell.LSTMStateTuple): + return [ + '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, 0), + '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, 0), + ] + if isinstance(state_size[0], rnn_cell.LSTMStateTuple): + return [[ + '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, i), + '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, i), + ] for i in range(len(state_size))] + return [ + '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i) + for i in range(len(state_size))] + + +def _get_initial_states(cell): + """Gets the initial state of the `RNNCell` used in the RNN. + + Args: + cell: A `RNNCell` to be used in the RNN. + + Returns: + A Python dict mapping state names to the `RNNCell`'s initial state for + consumption by the SQSS. + """ + names = nest.flatten(_get_state_names(cell)) + values = nest.flatten(cell.zero_state(1, dtype=dtypes.float32)) + return {n: array_ops.squeeze(v, axis=0) for [n, v] in zip(names, values)} + + def _read_batch(cell, features, labels, mode, num_unroll, - num_rnn_layers, batch_size, sequence_feature_columns, context_feature_columns=None, @@ -209,7 +257,6 @@ def _read_batch(cell, num_unroll: Python integer, how many time steps to unroll at a time. The input sequences of length `k` are then split into `k / num_unroll` many segments. - num_rnn_layers: Python integer, number of layers in the RNN. batch_size: Python integer, the size of the minibatch produced by the SQSS. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances @@ -230,15 +277,7 @@ def _read_batch(cell, batch: A `NextQueuedSequenceBatch` containing batch_size `SequenceExample` values and their saved internal states. """ - # Set batch_size=1 to initialize SQSS with cell's zero state. - values = cell.zero_state(batch_size=1, dtype=dtypes.float32) - - # Set up stateful queue reader. - states = {} - state_names = _get_lstm_state_names(num_rnn_layers) - for i in range(num_rnn_layers): - states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0) - states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0) + states = _get_initial_states(cell) sequences, context = _prepare_features_for_sqss( features, labels, mode, sequence_feature_columns, @@ -340,12 +379,12 @@ def _prepare_inputs_for_rnn(sequence_features, context_features, axis=1) -def _get_rnn_model_fn(target_column, +def _get_rnn_model_fn(cell_type, + target_column, problem_type, optimizer, num_unroll, num_units, - num_rnn_layers, num_threads, queue_capacity, batch_size, @@ -360,6 +399,8 @@ def _get_rnn_model_fn(target_column, """Creates a state saving RNN model function for an `Estimator`. Args: + cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of + 'basic_rnn,' 'lstm' or 'gru'. target_column: An initialized `TargetColumn`, used to calculate prediction and loss. problem_type: `ProblemType.CLASSIFICATION` or @@ -370,7 +411,6 @@ def _get_rnn_model_fn(target_column, The input sequences of length `k` are then split into `k / num_unroll` many segments. num_units: The number of units in the `RNNCell`. - num_rnn_layers: Python integer, number of layers in the RNN. num_threads: The Python integer number of threads enqueuing input examples into a queue. queue_capacity: The max capacity of the queue in number of examples. @@ -392,7 +432,7 @@ def _get_rnn_model_fn(target_column, effect if `optimizer` is an instance of an `Optimizer`. gradient_clipping_norm: A float. Gradients will be clipped to this value. dropout_keep_probabilities: a list of dropout keep probabilities or `None`. - If given a list, it must have length `num_rnn_layers + 1`. + If given a list, it must have length `len(num_units) + 1`. name: A string that will be used to create a scope for the RNN. seed: Fixes the random seed used for generating input keys by the SQSS. @@ -427,7 +467,7 @@ def _get_rnn_model_fn(target_column, dropout = (dropout_keep_probabilities if mode == model_fn.ModeKeys.TRAIN else None) - cell = lstm_cell(num_units, num_rnn_layers, dropout) + cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout) batch = _read_batch( cell=cell, @@ -435,7 +475,6 @@ def _get_rnn_model_fn(target_column, labels=labels, mode=mode, num_unroll=num_unroll, - num_rnn_layers=num_rnn_layers, batch_size=batch_size, sequence_feature_columns=sequence_feature_columns, context_feature_columns=context_feature_columns, @@ -448,7 +487,7 @@ def _get_rnn_model_fn(target_column, labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY) inputs = _prepare_inputs_for_rnn(sequence_features, context_features, sequence_feature_columns, num_unroll) - state_name = _get_lstm_state_names(num_rnn_layers) + state_name = _get_state_names(cell) rnn_activations, final_state = construct_state_saving_rnn( cell=cell, inputs=inputs, @@ -490,54 +529,17 @@ def _get_rnn_model_fn(target_column, return _rnn_model_fn -def _get_lstm_state_names(num_rnn_layers): - """Returns a num_rnn_layers long list of lstm state name pairs. - - Args: - num_rnn_layers: The number of layers in the RNN. - - Returns: - A num_rnn_layers long list of lstm state name pairs of the form: - ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_rnn_layers. - """ - return [['lstm_state_c' + str(i), 'lstm_state_m' + str(i)] - for i in range(num_rnn_layers)] - - -# TODO(jtbates): Allow users to specify cell types other than LSTM. -def lstm_cell(num_units, num_rnn_layers, dropout_keep_probabilities): - """Constructs a `MultiRNNCell` with num_rnn_layers `BasicLSTMCell`s. - - Args: - num_units: The number of units in the `RNNCell`. - num_rnn_layers: The number of layers in the RNN. - dropout_keep_probabilities: a list whose elements are either floats in - `[0.0, 1.0]` or `None`. It must have length `num_rnn_layers + 1`. - - Returns: - An intiialized `MultiRNNCell`. - """ - - cells = [ - rnn_cell.BasicLSTMCell(num_units=num_units, state_is_tuple=True) - for _ in range(num_rnn_layers) - ] - if dropout_keep_probabilities: - cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities) - return rnn_cell.MultiRNNCell(cells) - - class StateSavingRnnEstimator(estimator.Estimator): def __init__(self, problem_type, - num_units, num_unroll, batch_size, sequence_feature_columns, context_feature_columns=None, num_classes=None, - num_rnn_layers=1, + num_units=None, + cell_type='basic_rnn', optimizer_type='SGD', learning_rate=0.1, predict_probabilities=False, @@ -555,7 +557,6 @@ class StateSavingRnnEstimator(estimator.Estimator): Args: problem_type: `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`. - num_units: The size of the RNN cells. num_unroll: Python integer, how many time steps to unroll at a time. The input sequences of length `k` are then split into `k / num_unroll` many segments. @@ -568,8 +569,12 @@ class StateSavingRnnEstimator(estimator.Estimator): steps. All items in the set should be instances of classes derived from `FeatureColumn`. num_classes: The number of classes for categorization. Used only and - required if `problem_type` is `ProblemType.CLASSIFICATION` - num_rnn_layers: Number of RNN layers. + required if `problem_type` is `ProblemType.CLASSIFICATION`. + num_units: A list of integers indicating the number of units in the + `RNNCell`s in each layer. Either `num_units` is specified or `cell_type` + is an instance of `RNNCell`. + cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of + 'basic_rnn,' 'lstm' or 'gru'. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Adam', 'Ftrl', Momentum', 'RMSProp', or 'SGD'. @@ -582,7 +587,7 @@ class StateSavingRnnEstimator(estimator.Estimator): gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. dropout_keep_probabilities: a list of dropout keep probabilities or - `None`. If given a list, it must have length `num_rnn_layers + 1`. + `None`. If given a list, it must have length `len(num_units) + 1`. model_dir: The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -599,11 +604,18 @@ class StateSavingRnnEstimator(estimator.Estimator): seed: Fixes the random seed used for generating input keys by the SQSS. Raises: + ValueError: Both or neither of the following are true: (a) `num_units` is + specified and (b) `cell_type` is an instance of `RNNCell`. ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but `num_classes` is not specified. """ + if (num_units is not None) == isinstance(cell_type, rnn_cell.RNNCell): + raise ValueError( + 'Either num_units is specified OR cell_type is an instance of ' + 'RNNCell. Got num_units = {} and cell_type = {}.'.format( + num_units, cell_type)) name = 'MultiValueStateSavingRNN' if problem_type == constants.ProblemType.LINEAR_REGRESSION: @@ -625,12 +637,12 @@ class StateSavingRnnEstimator(estimator.Estimator): optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum) rnn_model_fn = _get_rnn_model_fn( + cell_type=cell_type, target_column=target_column, problem_type=problem_type, optimizer=optimizer_type, num_unroll=num_unroll, num_units=num_units, - num_rnn_layers=num_rnn_layers, num_threads=num_threads, queue_capacity=queue_capacity, batch_size=batch_size, @@ -684,8 +696,7 @@ def multi_value_rnn_regressor(num_units, describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - num_rnn_layers: Number of RNN layers. Leave this at its default value 1 - if passing a `cell_type` that is already a MultiRNNCell. + num_rnn_layers: Number of RNN layers. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Momentum' or 'SGD'. @@ -713,15 +724,16 @@ def multi_value_rnn_regressor(num_units, Returns: An initialized `Estimator`. """ + num_units = [num_units for _ in range(num_rnn_layers)] return StateSavingRnnEstimator( constants.ProblemType.LINEAR_REGRESSION, - num_units, num_unroll, batch_size, sequence_feature_columns, context_feature_columns=context_feature_columns, num_classes=None, - num_rnn_layers=num_rnn_layers, + num_units=num_units, + cell_type='lstm', optimizer_type=optimizer_type, learning_rate=learning_rate, predict_probabilities=False, @@ -803,15 +815,16 @@ def multi_value_rnn_classifier(num_classes, Returns: An initialized `Estimator`. """ + num_units = [num_units for _ in range(num_rnn_layers)] return StateSavingRnnEstimator( constants.ProblemType.CLASSIFICATION, - num_units, num_unroll, batch_size, sequence_feature_columns, context_feature_columns=context_feature_columns, num_classes=num_classes, - num_rnn_layers=num_rnn_layers, + num_units=num_units, + cell_type='lstm', optimizer_type=optimizer_type, learning_rate=learning_rate, predict_probabilities=predict_probabilities, diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 643a7d987b..f5bd03429c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -324,22 +324,21 @@ class StateSavingRnnEstimatorTest(test.TestCase): def _getModelFnOpsForMode(self, mode): """Helper for testGetRnnModelFn{Train,Eval,Infer}().""" - num_units = 4 - num_rnn_layers = 1 + num_units = [4] seq_columns = [ feature_column.real_valued_column( - 'inputs', dimension=num_units) + 'inputs', dimension=1) ] features = { 'inputs': constant_op.constant([1., 2., 3.]), } labels = constant_op.constant([1., 0., 1.]) model_fn = ssre._get_rnn_model_fn( + cell_type='basic_rnn', target_column=target_column_lib.multi_class_target(n_classes=2), optimizer='SGD', num_unroll=2, num_units=num_units, - num_rnn_layers=num_rnn_layers, num_threads=1, queue_capacity=10, batch_size=1, @@ -380,14 +379,14 @@ class StateSavingRnnEstimatorTest(test.TestCase): def testExport(self): input_feature_key = 'magic_input_feature_key' batch_size = 8 - cell_size = 4 + num_units = [4] sequence_length = 10 num_unroll = 2 num_classes = 2 seq_columns = [ feature_column.real_valued_column( - 'inputs', dimension=cell_size) + 'inputs', dimension=4) ] def get_input_fn(mode, seed): @@ -414,7 +413,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): def estimator_fn(): return ssre.StateSavingRnnEstimator( constants.ProblemType.CLASSIFICATION, - num_units=cell_size, + num_units=num_units, num_unroll=num_unroll, batch_size=batch_size, sequence_feature_columns=seq_columns, @@ -518,8 +517,8 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): sequence_length = 64 train_steps = 250 eval_steps = 20 - num_units = 4 num_rnn_layers = 1 + num_units = [4] * num_rnn_layers learning_rate = 0.3 loss_threshold = 0.035 @@ -539,14 +538,14 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): seq_columns = [ feature_column.real_valued_column( - 'inputs', dimension=num_units) + 'inputs', dimension=1) ] config = run_config.RunConfig(tf_random_seed=1234) dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1) sequence_estimator = ssre.StateSavingRnnEstimator( constants.ProblemType.LINEAR_REGRESSION, num_units=num_units, - num_rnn_layers=num_rnn_layers, + cell_type='lstm', num_unroll=num_unroll, batch_size=batch_size, sequence_feature_columns=seq_columns, @@ -578,7 +577,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): sequence_length = 32 train_steps = 200 eval_steps = 20 - num_units = 4 + num_units = [4] learning_rate = 0.5 accuracy_threshold = 0.9 @@ -596,12 +595,13 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): seq_columns = [ feature_column.real_valued_column( - 'inputs', dimension=num_units) + 'inputs', dimension=1) ] config = run_config.RunConfig(tf_random_seed=21212) sequence_estimator = ssre.StateSavingRnnEstimator( constants.ProblemType.CLASSIFICATION, num_units=num_units, + cell_type='lstm', num_unroll=num_unroll, batch_size=batch_size, sequence_feature_columns=seq_columns, @@ -649,7 +649,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): num_unroll = 7 # not a divisor of sequence_length train_steps = 350 eval_steps = 30 - num_units = 4 + num_units = [4] learning_rate = 0.4 accuracy_threshold = 0.65 @@ -684,6 +684,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): sequence_estimator = ssre.StateSavingRnnEstimator( constants.ProblemType.CLASSIFICATION, num_units=num_units, + cell_type='basic_rnn', num_unroll=num_unroll, batch_size=batch_size, sequence_feature_columns=sequence_feature_columns, |