author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-17 07:53:47 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-17 08:03:57 -0800
commit    20b211ffef7dc6e7aecda361962503d018450621 (patch)
tree      88ab83a2ad95235780dddd03b87eed042884d9b5
parent    abc6089dfb547bc52b967ef1007e61f1d191f81d (diff)
Add dropout to dynamic_rnn_estimator.
Change: 139460368
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py      | 97
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py |  4
2 files changed, 87 insertions(+), 14 deletions(-)
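
For orientation before the diff: a minimal, hypothetical sketch of how a caller might pass the two estimator arguments this commit introduces. The feature column, hyperparameter values, and the 'inputs' column name below are illustrative placeholders, not part of this patch; the constructor itself (multi_value_rnn_regressor) lives in the file being modified.

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator

    # Hypothetical sequence feature column; 'inputs' is a placeholder name.
    seq_columns = [tf.contrib.layers.real_valued_column('inputs')]

    # The two keep-probability arguments are the ones added by this commit.
    # Both default to None, which disables dropout entirely.
    regressor = dynamic_rnn_estimator.multi_value_rnn_regressor(
        num_units=8,
        sequence_feature_columns=seq_columns,
        learning_rate=0.1,
        input_keep_probability=0.9,    # keep 90% of cell inputs when training
        output_keep_probability=0.9)   # keep 90% of cell outputs when training
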
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 7326ec3e10..b054883252 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -45,6 +45,7 @@ class PredictionType(object):
SINGLE_VALUE = 1
MULTIPLE_VALUE = 2
+
class RNNKeys(object):
SEQUENCE_LENGTH_KEY = 'sequence_length'
INITIAL_STATE_KEY = 'initial_state'
@@ -191,10 +192,10 @@ def _concatenate_context_input(sequence_input, context_input):
def build_sequence_input(features,
- sequence_feature_columns,
- context_feature_columns,
- weight_collections=None,
- scope=None):
+ sequence_feature_columns,
+ context_feature_columns,
+ weight_collections=None,
+ scope=None):
"""Combine sequence and context features into input for an RNN.
Args:
@@ -229,12 +230,12 @@ def build_sequence_input(features,
def construct_rnn(initial_state,
- sequence_input,
- cell,
- num_label_columns,
- dtype=dtypes.float32,
- parallel_iterations=32,
- swap_memory=False):
+ sequence_input,
+ cell,
+ num_label_columns,
+ dtype=dtypes.float32,
+ parallel_iterations=32,
+ swap_memory=False):
"""Build an RNN and apply a fully connected layer to get the desired output.
Args:
@@ -464,6 +465,33 @@ def _single_value_loss(
return target_column.loss(last_activations, labels, features)
+def apply_dropout(
+ cell, input_keep_probability, output_keep_probability, random_seed=None):
+ """Apply dropout to the outputs and inputs of `cell`.
+
+ Args:
+ cell: An `RNNCell`.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
+ random_seed: Seed for random dropout.
+
+ Returns:
+ An `RNNCell`, the result of applying the supplied dropouts to `cell`.
+ """
+ input_prob_none = input_keep_probability is None
+ output_prob_none = output_keep_probability is None
+ if input_prob_none and output_prob_none:
+ return cell
+ if input_prob_none:
+ input_keep_probability = 1.0
+ if output_prob_none:
+ output_keep_probability = 1.0
+ return rnn_cell.DropoutWrapper(
+ cell, input_keep_probability, output_keep_probability, random_seed)
+
+
def _get_dynamic_rnn_model_fn(cell,
target_column,
problem_type,
@@ -474,6 +502,8 @@ def _get_dynamic_rnn_model_fn(cell,
predict_probabilities=False,
learning_rate=None,
gradient_clipping_norm=None,
+ input_keep_probability=None,
+ output_keep_probability=None,
sequence_length_key=RNNKeys.SEQUENCE_LENGTH_KEY,
initial_state_key=RNNKeys.INITIAL_STATE_KEY,
dtype=dtypes.float32,
@@ -504,6 +534,10 @@ def _get_dynamic_rnn_model_fn(cell,
learning_rate: Learning rate used for optimization. This argument has no
effect if `optimizer` is an instance of an `Optimizer`.
gradient_clipping_norm: A float. Gradients will be clipped to this value.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
sequence_length_key: The key that will be used to look up sequence length in
the `features` dict.
initial_state_key: The key that will be used to look up initial_state in
@@ -551,12 +585,17 @@ def _get_dynamic_rnn_model_fn(cell,
initial_state = features.get(initial_state_key)
sequence_length = features.get(sequence_length_key)
sequence_input = build_sequence_input(features,
- sequence_feature_columns,
- context_feature_columns)
+ sequence_feature_columns,
+ context_feature_columns)
+ if mode == estimator.ModeKeys.TRAIN:
+ cell_for_mode = apply_dropout(
+ cell, input_keep_probability, output_keep_probability)
+ else:
+ cell_for_mode = cell
rnn_activations, final_state = construct_rnn(
initial_state,
sequence_input,
- cell,
+ cell_for_mode,
target_column.num_label_columns,
dtype=dtype,
parallel_iterations=parallel_iterations,
@@ -635,6 +674,8 @@ def multi_value_rnn_regressor(num_units,
learning_rate=0.1,
momentum=None,
gradient_clipping_norm=10.0,
+ input_keep_probability=None,
+ output_keep_probability=None,
model_dir=None,
config=None,
params=None,
@@ -661,6 +702,10 @@ def multi_value_rnn_regressor(num_units,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -686,6 +731,8 @@ def multi_value_rnn_regressor(num_units,
context_feature_columns=context_feature_columns,
learning_rate=learning_rate,
gradient_clipping_norm=gradient_clipping_norm,
+ input_keep_probability=input_keep_probability,
+ output_keep_probability=output_keep_probability,
name='MultiValueRnnRegressor')
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
@@ -707,6 +754,8 @@ def multi_value_rnn_classifier(num_classes,
predict_probabilities=False,
momentum=None,
gradient_clipping_norm=10.0,
+ input_keep_probability=None,
+ output_keep_probability=None,
model_dir=None,
config=None,
params=None,
@@ -735,6 +784,10 @@ def multi_value_rnn_classifier(num_classes,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -761,6 +814,8 @@ def multi_value_rnn_classifier(num_classes,
predict_probabilities=predict_probabilities,
learning_rate=learning_rate,
gradient_clipping_norm=gradient_clipping_norm,
+ input_keep_probability=input_keep_probability,
+ output_keep_probability=output_keep_probability,
name='MultiValueRnnClassifier')
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
@@ -780,6 +835,8 @@ def single_value_rnn_regressor(num_units,
learning_rate=0.1,
momentum=None,
gradient_clipping_norm=10.0,
+ input_keep_probability=None,
+ output_keep_probability=None,
model_dir=None,
config=None,
params=None,
@@ -805,6 +862,10 @@ def single_value_rnn_regressor(num_units,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -830,6 +891,8 @@ def single_value_rnn_regressor(num_units,
context_feature_columns=context_feature_columns,
learning_rate=learning_rate,
gradient_clipping_norm=gradient_clipping_norm,
+ input_keep_probability=input_keep_probability,
+ output_keep_probability=output_keep_probability,
name='SingleValueRnnRegressor')
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
@@ -851,6 +914,8 @@ def single_value_rnn_classifier(num_classes,
predict_probabilities=False,
momentum=None,
gradient_clipping_norm=10.0,
+ input_keep_probability=None,
+ output_keep_probability=None,
model_dir=None,
config=None,
params=None,
@@ -879,6 +944,10 @@ def single_value_rnn_classifier(num_classes,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -905,6 +974,8 @@ def single_value_rnn_classifier(num_classes,
predict_probabilities=predict_probabilities,
learning_rate=learning_rate,
gradient_clipping_norm=gradient_clipping_norm,
+ input_keep_probability=input_keep_probability,
+ output_keep_probability=output_keep_probability,
name='SingleValueRnnClassifier')
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
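
Before the test-file hunk, a note on the new helper: as the model_fn hunk above shows, apply_dropout is invoked only when mode == ModeKeys.TRAIN; evaluation and prediction use the bare cell. When both keep probabilities are None, the cell is returned unwrapped, so there is zero overhead with dropout disabled; a single None is treated as "keep everything" (1.0). A small sketch of equivalent usage with the rnn_cell API of this era (the GRUCell and the 0.9 values are illustrative choices, not mandated by the patch):

    from tensorflow.python.ops import rnn_cell

    cell = rnn_cell.GRUCell(num_units=8)  # any RNNCell works here

    # What the patched TRAIN path effectively does:
    train_cell = rnn_cell.DropoutWrapper(
        cell, input_keep_prob=0.9, output_keep_prob=0.9)

    # Same effect via the helper added in this patch:
    train_cell = apply_dropout(
        cell, input_keep_probability=0.9, output_keep_probability=0.9)
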
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 4cdd7c29f3..a2df6de6fd 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -275,7 +275,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
' Expected {}; got {}.'.format(i, expected_activations,
actual_activations))
-
+# TODO(jamieas): move all tests below to a benchmark test.
class DynamicRNNEstimatorLearningTest(tf.test.TestCase):
"""Learning tests for dymanic RNN Estimators."""
@@ -313,6 +313,8 @@ class DynamicRNNEstimatorLearningTest(tf.test.TestCase):
num_units=cell_size,
sequence_feature_columns=seq_columns,
learning_rate=learning_rate,
+ input_keep_probability=0.9,
+ output_keep_probability=0.9,
config=config)
train_input_fn = get_sin_input_fn(