diff options
author | 2016-11-17 07:53:47 -0800 | |
---|---|---|
committer | 2016-11-17 08:03:57 -0800 | |
commit | 20b211ffef7dc6e7aecda361962503d018450621 (patch) | |
tree | 88ab83a2ad95235780dddd03b87eed042884d9b5 | |
parent | abc6089dfb547bc52b967ef1007e61f1d191f81d (diff) |
Add dropout to dynamic_rnn_estimator.
Change: 139460368
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 97 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py | 4 |
2 files changed, 87 insertions, 14 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 7326ec3e10..b054883252 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -45,6 +45,7 @@ class PredictionType(object): SINGLE_VALUE = 1 MULTIPLE_VALUE = 2 + class RNNKeys(object): SEQUENCE_LENGTH_KEY = 'sequence_length' INITIAL_STATE_KEY = 'initial_state' @@ -191,10 +192,10 @@ def _concatenate_context_input(sequence_input, context_input): def build_sequence_input(features, - sequence_feature_columns, - context_feature_columns, - weight_collections=None, - scope=None): + sequence_feature_columns, + context_feature_columns, + weight_collections=None, + scope=None): """Combine sequence and context features into input for an RNN. Args: @@ -229,12 +230,12 @@ def build_sequence_input(features, def construct_rnn(initial_state, - sequence_input, - cell, - num_label_columns, - dtype=dtypes.float32, - parallel_iterations=32, - swap_memory=False): + sequence_input, + cell, + num_label_columns, + dtype=dtypes.float32, + parallel_iterations=32, + swap_memory=False): """Build an RNN and apply a fully connected layer to get the desired output. Args: @@ -464,6 +465,33 @@ def _single_value_loss( return target_column.loss(last_activations, labels, features) +def apply_dropout( + cell, input_keep_probability, output_keep_probability, random_seed=None): + """Apply dropout to the outputs and inputs of `cell`. + + Args: + cell: An `RNNCell`. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. + random_seed: Seed for random dropout. + + Returns: + An `RNNCell`, the result of applying the supplied dropouts to `cell`. + """ + input_prob_none = input_keep_probability is None + output_prob_none = output_keep_probability is None + if input_prob_none and output_prob_none: + return cell + if input_prob_none: + input_keep_probability = 1.0 + if output_prob_none: + output_keep_probability = 1.0 + return rnn_cell.DropoutWrapper( + cell, input_keep_probability, output_keep_probability, random_seed) + + def _get_dynamic_rnn_model_fn(cell, target_column, problem_type, @@ -474,6 +502,8 @@ def _get_dynamic_rnn_model_fn(cell, predict_probabilities=False, learning_rate=None, gradient_clipping_norm=None, + input_keep_probability=None, + output_keep_probability=None, sequence_length_key=RNNKeys.SEQUENCE_LENGTH_KEY, initial_state_key=RNNKeys.INITIAL_STATE_KEY, dtype=dtypes.float32, @@ -504,6 +534,10 @@ def _get_dynamic_rnn_model_fn(cell, learning_rate: Learning rate used for optimization. This argument has no effect if `optimizer` is an instance of an `Optimizer`. gradient_clipping_norm: A float. Gradients will be clipped to this value. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. sequence_length_key: The key that will be used to look up sequence length in the `features` dict. initial_state_key: The key that will be used to look up initial_state in @@ -551,12 +585,17 @@ def _get_dynamic_rnn_model_fn(cell, initial_state = features.get(initial_state_key) sequence_length = features.get(sequence_length_key) sequence_input = build_sequence_input(features, - sequence_feature_columns, - context_feature_columns) + sequence_feature_columns, + context_feature_columns) + if mode == estimator.ModeKeys.TRAIN: + cell_for_mode = apply_dropout( + cell, input_keep_probability, output_keep_probability) + else: + cell_for_mode = cell rnn_activations, final_state = construct_rnn( initial_state, sequence_input, - cell, + cell_for_mode, target_column.num_label_columns, dtype=dtype, parallel_iterations=parallel_iterations, @@ -635,6 +674,8 @@ def multi_value_rnn_regressor(num_units, learning_rate=0.1, momentum=None, gradient_clipping_norm=10.0, + input_keep_probability=None, + output_keep_probability=None, model_dir=None, config=None, params=None, @@ -661,6 +702,10 @@ def multi_value_rnn_regressor(num_units, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. model_dir: Directory to use for The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -686,6 +731,8 @@ def multi_value_rnn_regressor(num_units, context_feature_columns=context_feature_columns, learning_rate=learning_rate, gradient_clipping_norm=gradient_clipping_norm, + input_keep_probability=input_keep_probability, + output_keep_probability=output_keep_probability, name='MultiValueRnnRegressor') return estimator.Estimator(model_fn=dynamic_rnn_model_fn, @@ -707,6 +754,8 @@ def multi_value_rnn_classifier(num_classes, predict_probabilities=False, momentum=None, gradient_clipping_norm=10.0, + input_keep_probability=None, + output_keep_probability=None, model_dir=None, config=None, params=None, @@ -735,6 +784,10 @@ def multi_value_rnn_classifier(num_classes, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. model_dir: Directory to use for The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -761,6 +814,8 @@ def multi_value_rnn_classifier(num_classes, predict_probabilities=predict_probabilities, learning_rate=learning_rate, gradient_clipping_norm=gradient_clipping_norm, + input_keep_probability=input_keep_probability, + output_keep_probability=output_keep_probability, name='MultiValueRnnClassifier') return estimator.Estimator(model_fn=dynamic_rnn_model_fn, @@ -780,6 +835,8 @@ def single_value_rnn_regressor(num_units, learning_rate=0.1, momentum=None, gradient_clipping_norm=10.0, + input_keep_probability=None, + output_keep_probability=None, model_dir=None, config=None, params=None, @@ -805,6 +862,10 @@ def single_value_rnn_regressor(num_units, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. model_dir: Directory to use for The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -830,6 +891,8 @@ def single_value_rnn_regressor(num_units, context_feature_columns=context_feature_columns, learning_rate=learning_rate, gradient_clipping_norm=gradient_clipping_norm, + input_keep_probability=input_keep_probability, + output_keep_probability=output_keep_probability, name='SingleValueRnnRegressor') return estimator.Estimator(model_fn=dynamic_rnn_model_fn, @@ -851,6 +914,8 @@ def single_value_rnn_classifier(num_classes, predict_probabilities=False, momentum=None, gradient_clipping_norm=10.0, + input_keep_probability=None, + output_keep_probability=None, model_dir=None, config=None, params=None, @@ -879,6 +944,10 @@ def single_value_rnn_classifier(num_classes, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs to `cell`. If `None`, + no dropout is applied. model_dir: Directory to use for The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -905,6 +974,8 @@ def single_value_rnn_classifier(num_classes, predict_probabilities=predict_probabilities, learning_rate=learning_rate, gradient_clipping_norm=gradient_clipping_norm, + input_keep_probability=input_keep_probability, + output_keep_probability=output_keep_probability, name='SingleValueRnnClassifier') return estimator.Estimator(model_fn=dynamic_rnn_model_fn, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 4cdd7c29f3..a2df6de6fd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -275,7 +275,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): ' Expected {}; got {}.'.format(i, expected_activations, actual_activations)) - +# TODO(jamieas): move all tests below to a benchmark test. class DynamicRNNEstimatorLearningTest(tf.test.TestCase): """Learning tests for dymanic RNN Estimators.""" @@ -313,6 +313,8 @@ class DynamicRNNEstimatorLearningTest(tf.test.TestCase): num_units=cell_size, sequence_feature_columns=seq_columns, learning_rate=learning_rate, + input_keep_probability=0.9, + output_keep_probability=0.9, config=config) train_input_fn = get_sin_input_fn( |