author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-02-24 11:12:22 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>     2017-02-24 11:26:55 -0800
commit    f772b5dc5c047660a54c9f6bf54db13c7f72c012 (patch)
tree      f644893f1697b61c72a3b49ab72876d3153435e7
parent    7cac355fda28ae7d0139089afe4a95df9d81130b (diff)
Moved common code out of RNN Estimators.
Change: 148480357
-rw-r--r--  tensorflow/contrib/learn/BUILD                                                       |  26
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py            | 231
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py       |  92
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/rnn_common.py                       | 199
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py                  | 116
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py       | 280
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py  |  73
7 files changed, 453 insertions(+), 564 deletions(-)
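
As a quick orientation before the diff: the helpers consolidated into
rnn_common.py are used roughly as follows. This is a minimal sketch assuming
the TF 1.x contrib import paths shown below; the shapes and values are
illustrative and not part of the change.

  import numpy as np
  import tensorflow as tf

  from tensorflow.contrib.learn.python.learn.estimators import rnn_common

  batch_size, padded_length, k = 4, 6, 3
  activations = tf.constant(
      np.random.rand(batch_size, padded_length, k), dtype=tf.float32)
  labels = tf.constant(
      np.random.randint(0, k, [batch_size, padded_length]), dtype=tf.int32)
  sequence_lengths = tf.constant([6, 4, 2, 5], dtype=tf.int32)

  # Drop padded steps beyond each sequence length and flatten across the batch.
  masked_activations, masked_labels = rnn_common.mask_activations_and_labels(
      activations, labels, sequence_lengths)
  # Select the last valid activation of each sequence: shape [batch_size, k].
  last = rnn_common.select_last_activations(activations, sequence_lengths)

  with tf.Session() as sess:
    act, lab, fin = sess.run([masked_activations, masked_labels, last])
  assert act.shape == (6 + 4 + 2 + 5, k)  # [sum(sequence_lengths), k]
  assert lab.shape == (6 + 4 + 2 + 5,)
  assert fin.shape == (batch_size, k)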
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 00be830686..188e29d500 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -76,6 +76,16 @@ py_library(
],
)
+# Exposes constants without having to build the entire :learn target.
+py_library(
+ name = "estimator_constants_py",
+ srcs = [
+ "python/learn/estimators/constants.py",
+ "python/learn/estimators/prediction_key.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
py_test(
name = "data_feeder_test",
size = "small",
@@ -865,6 +875,22 @@ py_test(
)
py_test(
+ name = "rnn_common_test",
+ size = "medium",
+ srcs = ["python/learn/estimators/rnn_common_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+py_test(
name = "ops_test",
size = "small",
srcs = ["python/learn/ops/ops_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 6b5b9a6dd9..dcb9048633 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -42,90 +43,14 @@ class PredictionType(object):
MULTIPLE_VALUE = 2
-# NOTE(jamieas): As of February 7, 2017, some of the `RNNKeys` have been removed
-# and replaced with values from `prediction_key.PredictionKey`. The key
-# `RNNKeys.PREDICTIONS_KEY` has been replaced by
-# `prediction_key.PredictionKey.SCORES` for regression and
-# `prediction_key.PredictionKey.CLASSES` for classification. The key
-# `RNNKeys.PROBABILITIES_KEY` has been replaced by
-# `prediction_key.PredictionKey.PROBABILITIES`.
-class RNNKeys(object):
- SEQUENCE_LENGTH_KEY = 'sequence_length'
- STATE_PREFIX = 'rnn_cell_state'
-
_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
'lstm': contrib_rnn.LSTMCell,
'gru': contrib_rnn.GRUCell,}
-def mask_activations_and_labels(activations, labels, sequence_lengths):
- """Remove entries outside `sequence_lengths` and returned flattened results.
-
- Args:
- activations: Output of the RNN, shape `[batch_size, padded_length, k]`.
- labels: Label values, shape `[batch_size, padded_length]`.
- sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded
- length of each sequence. If `None`, then each sequence is unpadded.
-
- Returns:
- activations_masked: `logit` values with those beyond `sequence_lengths`
- removed for each batch. Batches are then concatenated. Shape
- `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
- shape `[batch_size * padded_length, k]` otherwise.
- labels_masked: Label values after removing unneeded entries. Shape
- `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
- `[batch_size * padded_length]` otherwise.
- """
- with ops.name_scope('mask_activations_and_labels',
- values=[activations, labels, sequence_lengths]):
- labels_shape = array_ops.shape(labels)
- batch_size = labels_shape[0]
- padded_length = labels_shape[1]
- if sequence_lengths is None:
- flattened_dimension = padded_length * batch_size
- activations_masked = array_ops.reshape(activations,
- [flattened_dimension, -1])
- labels_masked = array_ops.reshape(labels, [flattened_dimension])
- else:
- mask = array_ops.sequence_mask(sequence_lengths, padded_length)
- activations_masked = array_ops.boolean_mask(activations, mask)
- labels_masked = array_ops.boolean_mask(labels, mask)
- return activations_masked, labels_masked
-
-
-def select_last_activations(activations, sequence_lengths):
- """Selects the nth set of activations for each n in `sequence_length`.
-
- Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not
- `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If
- `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.
-
- Args:
- activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
- sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`.
- Returns:
- A `Tensor` of shape `[batch_size, k]`.
- """
- with ops.name_scope('select_last_activations',
- values=[activations, sequence_lengths]):
- activations_shape = array_ops.shape(activations)
- batch_size = activations_shape[0]
- padded_length = activations_shape[1]
- num_label_columns = activations_shape[2]
- if sequence_lengths is None:
- sequence_lengths = padded_length
- reshaped_activations = array_ops.reshape(activations,
- [-1, num_label_columns])
- indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
- last_activations = array_ops.gather(reshaped_activations, indices)
- last_activations.set_shape(
- [activations.get_shape()[0], activations.get_shape()[2]])
- return last_activations
-
-
def _get_state_name(i):
"""Constructs the name string for state component `i`."""
- return '{}_{}'.format(RNNKeys.STATE_PREFIX, i)
+ return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
def state_tuple_to_dict(state):
@@ -351,13 +276,11 @@ def _get_eval_metric_ops(problem_type, prediction_type, sequence_length,
if problem_type == constants.ProblemType.CLASSIFICATION:
# Multi value classification
if prediction_type == PredictionType.MULTIPLE_VALUE:
- masked_predictions, masked_labels = mask_activations_and_labels(
- prediction_dict[prediction_key.PredictionKey.CLASSES],
- labels,
+ mask_predictions, mask_labels = rnn_common.mask_activations_and_labels(
+ prediction_dict[prediction_key.PredictionKey.CLASSES], labels,
sequence_length)
eval_metric_ops['accuracy'] = metrics.streaming_accuracy(
- predictions=masked_predictions,
- labels=masked_labels)
+ predictions=mask_predictions, labels=mask_labels)
# Single value classification
elif prediction_type == PredictionType.SINGLE_VALUE:
eval_metric_ops['accuracy'] = metrics.streaming_accuracy(
@@ -373,64 +296,6 @@ def _get_eval_metric_ops(problem_type, prediction_type, sequence_length,
return eval_metric_ops
-def _multi_value_predictions(
- activations, target_column, problem_type, predict_probabilities):
- """Maps `activations` from the RNN to predictions for multi value models.
-
- If `predict_probabilities` is `False`, this function returns a `dict`
- containing single entry with key `PREDICTIONS_KEY`. If `predict_probabilities`
- is `True`, it will contain a second entry with key `PROBABILITIES_KEY`. The
- value of this entry is a `Tensor` of probabilities with shape
- `[batch_size, padded_length, num_classes]`.
-
- Note that variable length inputs will yield some predictions that don't have
- meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]`
- has no meaningful interpretation.
-
- Args:
- activations: Output from an RNN. Should have dtype `float32` and shape
- `[batch_size, padded_length, ?]`.
- target_column: An initialized `TargetColumn`, calculate predictions.
- problem_type: Either `ProblemType.CLASSIFICATION` or
- `ProblemType.LINEAR_REGRESSION`.
- predict_probabilities: A Python boolean, indicating whether probabilities
- should be returned. Should only be set to `True` for
- classification/logistic regression problems.
- Returns:
- A `dict` mapping strings to `Tensors`.
- """
- with ops.name_scope('MultiValuePrediction'):
- activations_shape = array_ops.shape(activations)
- flattened_activations = array_ops.reshape(activations,
- [-1, activations_shape[2]])
- prediction_dict = {}
- if predict_probabilities:
- flat_probabilities = target_column.logits_to_predictions(
- flattened_activations, proba=True)
- flat_predictions = math_ops.argmax(flat_probabilities, 1)
- if target_column.num_label_columns == 1:
- probability_shape = array_ops.concat([activations_shape[:2], [2]], 0)
- else:
- probability_shape = activations_shape
- probabilities = array_ops.reshape(
- flat_probabilities,
- probability_shape,
- name=prediction_key.PredictionKey.PROBABILITIES)
- prediction_dict[
- prediction_key.PredictionKey.PROBABILITIES] = probabilities
- else:
- flat_predictions = target_column.logits_to_predictions(
- flattened_activations, proba=False)
- predictions_name = (prediction_key.PredictionKey.CLASSES
- if problem_type == constants.ProblemType.CLASSIFICATION
- else prediction_key.PredictionKey.SCORES)
- predictions = array_ops.reshape(
- flat_predictions, [activations_shape[0], activations_shape[1]],
- name=predictions_name)
- prediction_dict[predictions_name] = predictions
- return prediction_dict
-
-
def _single_value_predictions(activations,
sequence_length,
target_column,
@@ -460,7 +325,8 @@ def _single_value_predictions(activations,
A `dict` mapping strings to `Tensors`.
"""
with ops.name_scope('SingleValuePrediction'):
- last_activations = select_last_activations(activations, sequence_length)
+ last_activations = rnn_common.select_last_activations(
+ activations, sequence_length)
predictions_name = (prediction_key.PredictionKey.CLASSES
if problem_type == constants.ProblemType.CLASSIFICATION
else prediction_key.PredictionKey.SCORES)
@@ -495,7 +361,7 @@ def _multi_value_loss(
A scalar `Tensor` containing the loss.
"""
with ops.name_scope('MultiValueLoss'):
- activations_masked, labels_masked = mask_activations_and_labels(
+ activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
activations, labels, sequence_length)
return target_column.loss(activations_masked, labels_masked, features)
@@ -519,7 +385,8 @@ def _single_value_loss(
"""
with ops.name_scope('SingleValueLoss'):
- last_activations = select_last_activations(activations, sequence_length)
+ last_activations = rnn_common.select_last_activations(
+ activations, sequence_length)
return target_column.loss(last_activations, labels, features)
@@ -544,29 +411,33 @@ def _get_output_alternatives(prediction_type,
if prediction_type == PredictionType.MULTIPLE_VALUE:
return None
if prediction_type == PredictionType.SINGLE_VALUE:
- prediction_dict_no_state = {k: v for k, v in prediction_dict.items()
- if RNNKeys.STATE_PREFIX not in k}
+ prediction_dict_no_state = {
+ k: v
+ for k, v in prediction_dict.items()
+ if rnn_common.RNNKeys.STATE_PREFIX not in k
+ }
return {'dynamic_rnn_output': (problem_type, prediction_dict_no_state)}
raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type))
-def _get_dynamic_rnn_model_fn(cell_type,
- num_units,
- target_column,
- problem_type,
- prediction_type,
- optimizer,
- sequence_feature_columns,
- context_feature_columns=None,
- predict_probabilities=False,
- learning_rate=None,
- gradient_clipping_norm=None,
- dropout_keep_probabilities=None,
- sequence_length_key=RNNKeys.SEQUENCE_LENGTH_KEY,
- dtype=dtypes.float32,
- parallel_iterations=None,
- swap_memory=True,
- name='DynamicRNNModel'):
+def _get_dynamic_rnn_model_fn(
+ cell_type,
+ num_units,
+ target_column,
+ problem_type,
+ prediction_type,
+ optimizer,
+ sequence_feature_columns,
+ context_feature_columns=None,
+ predict_probabilities=False,
+ learning_rate=None,
+ gradient_clipping_norm=None,
+ dropout_keep_probabilities=None,
+ sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY,
+ dtype=dtypes.float32,
+ parallel_iterations=None,
+ swap_memory=True,
+ name='DynamicRNNModel'):
"""Creates an RNN model function for an `Estimator`.
The model function returns an instance of `ModelFnOps`. When
@@ -667,7 +538,7 @@ def _get_dynamic_rnn_model_fn(cell_type,
loss = None # Created below for modes TRAIN and EVAL.
if prediction_type == PredictionType.MULTIPLE_VALUE:
- prediction_dict = _multi_value_predictions(
+ prediction_dict = rnn_common.multi_value_predictions(
rnn_activations, target_column, problem_type, predict_probabilities)
if mode != model_fn.ModeKeys.INFER:
loss = _multi_value_loss(
@@ -711,38 +582,6 @@ def _get_dynamic_rnn_model_fn(cell_type,
return _dynamic_rnn_model_fn
-def _apply_dropout(
- cells, dropout_keep_probabilities, random_seed=None):
- """Applies dropout to the outputs and inputs of `cell`.
-
- Args:
- cells: A list of `RNNCell`s.
- dropout_keep_probabilities: a list whose elements are either floats in
- `[0.0, 1.0]` or `None`. It must have length one greater than `cells`.
- random_seed: Seed for random dropout.
-
- Returns:
- A list of `RNNCell`s, the result of applying the supplied dropouts.
-
- Raises:
- ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1`.
- """
- if len(dropout_keep_probabilities) != len(cells) + 1:
- raise ValueError(
- 'The number of dropout probabilites must be one greater than the '
- 'number of cells. Got {} cells and {} dropout probabilities.'.format(
- len(cells), len(dropout_keep_probabilities)))
- wrapped_cells = [
- contrib_rnn.DropoutWrapper(cell, prob, 1.0, random_seed)
- for cell, prob in zip(cells[:-1], dropout_keep_probabilities[:-2])
- ]
- wrapped_cells.append(contrib_rnn.DropoutWrapper(
- cells[-1],
- dropout_keep_probabilities[-2],
- dropout_keep_probabilities[-1]))
- return wrapped_cells
-
-
def _get_single_cell(cell_type, num_units):
"""Constructs and return an single `RNNCell`.
@@ -789,7 +628,7 @@ def _construct_rnn_cell(cell_type, num_units, dropout_keep_probabilities):
cells = [_get_single_cell(cell_type, n) for n in num_units]
if dropout_keep_probabilities:
- cells = _apply_dropout(cells, dropout_keep_probabilities)
+ cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities)
if len(cells) == 1:
return cells[0]
return contrib_rnn.MultiRNNCell(cells)
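
With this change, dropout wrapping for a stack of cells is delegated to the
shared helper. A short sketch of the delegated call, assuming the contrib RNN
cells already imported in this file (cell sizes are illustrative):

  from tensorflow.contrib import rnn as contrib_rnn
  from tensorflow.contrib.learn.python.learn.estimators import rnn_common

  cells = [contrib_rnn.GRUCell(8), contrib_rnn.GRUCell(8)]
  # One keep probability per cell input plus one for the last cell's output,
  # so len(dropout_keep_probabilities) == len(cells) + 1.
  cells = rnn_common.apply_dropout(cells, [0.9, 0.9, 0.9])
  multi_cell = contrib_rnn.MultiRNNCell(cells)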
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index d7d606deed..443d336214 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.client import session
@@ -191,93 +192,6 @@ class DynamicRnnEstimatorTest(test.TestCase):
expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
self.assertAllEqual(expected_state_shape, final_state.shape)
- def testMaskActivationsAndLabels(self):
- """Test `mask_activations_and_labels`."""
- batch_size = 4
- padded_length = 6
- num_classes = 4
- np.random.seed(1234)
- sequence_length = np.random.randint(0, padded_length + 1, batch_size)
- activations = np.random.rand(batch_size, padded_length, num_classes)
- labels = np.random.randint(0, num_classes, [batch_size, padded_length])
- (activations_masked_t,
- labels_masked_t) = dynamic_rnn_estimator.mask_activations_and_labels(
- constant_op.constant(
- activations, dtype=dtypes.float32),
- constant_op.constant(
- labels, dtype=dtypes.int32),
- constant_op.constant(
- sequence_length, dtype=dtypes.int32))
-
- with session.Session() as sess:
- activations_masked, labels_masked = sess.run(
- [activations_masked_t, labels_masked_t])
-
- expected_activations_shape = [sum(sequence_length), num_classes]
- np.testing.assert_equal(
- expected_activations_shape, activations_masked.shape,
- 'Wrong activations shape. Expected {}; got {}.'.format(
- expected_activations_shape, activations_masked.shape))
-
- expected_labels_shape = [sum(sequence_length)]
- np.testing.assert_equal(expected_labels_shape, labels_masked.shape,
- 'Wrong labels shape. Expected {}; got {}.'.format(
- expected_labels_shape, labels_masked.shape))
- masked_index = 0
- for i in range(batch_size):
- for j in range(sequence_length[i]):
- actual_activations = activations_masked[masked_index]
- expected_activations = activations[i, j, :]
- np.testing.assert_almost_equal(
- expected_activations,
- actual_activations,
- err_msg='Unexpected logit value at index [{}, {}, :].'
- ' Expected {}; got {}.'.format(i, j, expected_activations,
- actual_activations))
-
- actual_labels = labels_masked[masked_index]
- expected_labels = labels[i, j]
- np.testing.assert_almost_equal(
- expected_labels,
- actual_labels,
- err_msg='Unexpected logit value at index [{}, {}].'
- ' Expected {}; got {}.'.format(i, j, expected_labels,
- actual_labels))
- masked_index += 1
-
- def testSelectLastActivations(self):
- """Test `select_last_activations`."""
- batch_size = 4
- padded_length = 6
- num_classes = 4
- np.random.seed(4444)
- sequence_length = np.random.randint(0, padded_length + 1, batch_size)
- activations = np.random.rand(batch_size, padded_length, num_classes)
- last_activations_t = dynamic_rnn_estimator.select_last_activations(
- constant_op.constant(
- activations, dtype=dtypes.float32),
- constant_op.constant(
- sequence_length, dtype=dtypes.int32))
-
- with session.Session() as sess:
- last_activations = sess.run(last_activations_t)
-
- expected_activations_shape = [batch_size, num_classes]
- np.testing.assert_equal(
- expected_activations_shape, last_activations.shape,
- 'Wrong activations shape. Expected {}; got {}.'.format(
- expected_activations_shape, last_activations.shape))
-
- for i in range(batch_size):
- actual_activations = last_activations[i, :]
- expected_activations = activations[i, sequence_length[i] - 1, :]
- np.testing.assert_almost_equal(
- expected_activations,
- actual_activations,
- err_msg='Unexpected logit value at index [{}, :].'
- ' Expected {}; got {}.'.format(i, expected_activations,
- actual_activations))
-
def testGetOutputAlternatives(self):
test_cases = (
(dynamic_rnn_estimator.PredictionType.SINGLE_VALUE,
@@ -618,7 +532,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
incremental_state_dict = {
k: v
for (k, v) in prediction_dict.items()
- if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX)
+ if k.startswith(rnn_common.RNNKeys.STATE_PREFIX)
}
return prediction_dict
@@ -636,7 +550,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
err_msg='Mismatch on last {} predictions.'.format(prediction_steps[-1]))
# Check that final states are identical.
for k, v in pred_all_at_once.items():
- if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX):
+ if k.startswith(rnn_common.RNNKeys.STATE_PREFIX):
np.testing.assert_array_equal(
v, pred_step_by_step[k], err_msg='Mismatch on state {}.'.format(k))
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
new file mode 100644
index 0000000000..b3aea01ffa
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common operations for RNN Estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import rnn as contrib_rnn
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been
+# removed and replaced with values from `prediction_key.PredictionKey`. The key
+# `RNNKeys.PREDICTIONS_KEY` has been replaced by
+# `prediction_key.PredictionKey.SCORES` for regression and
+# `prediction_key.PredictionKey.CLASSES` for classification. The key
+# `RNNKeys.PROBABILITIES_KEY` has been replaced by
+# `prediction_key.PredictionKey.PROBABILITIES`.
+class RNNKeys(object):
+ FINAL_STATE_KEY = 'final_state'
+ LABELS_KEY = '__labels__'
+ SEQUENCE_LENGTH_KEY = 'sequence_length'
+ STATE_PREFIX = 'rnn_cell_state'
+
+
+def apply_dropout(cells, dropout_keep_probabilities, random_seed=None):
+ """Applies dropout to the outputs and inputs of `cell`.
+
+ Args:
+ cells: A list of `RNNCell`s.
+ dropout_keep_probabilities: a list whose elements are either floats in
+ `[0.0, 1.0]` or `None`. It must have exactly one more element than `cells`.
+ random_seed: Seed for random dropout.
+
+ Returns:
+ A list of `RNNCell`s, the result of applying the supplied dropouts.
+
+ Raises:
+ ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1`.
+ """
+ if len(dropout_keep_probabilities) != len(cells) + 1:
+ raise ValueError(
+ 'The number of dropout probabilities must be one greater than the '
+ 'number of cells. Got {} cells and {} dropout probabilities.'.format(
+ len(cells), len(dropout_keep_probabilities)))
+ wrapped_cells = [
+ contrib_rnn.DropoutWrapper(cell, prob, 1.0, random_seed)
+ for cell, prob in zip(cells[:-1], dropout_keep_probabilities[:-2])
+ ]
+ wrapped_cells.append(
+ contrib_rnn.DropoutWrapper(cells[-1], dropout_keep_probabilities[-2],
+ dropout_keep_probabilities[-1]))
+ return wrapped_cells
+
+
+def select_last_activations(activations, sequence_lengths):
+ """Selects the nth set of activations for each n in `sequence_length`.
+
+ Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not
+ `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If
+ `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.
+
+ Args:
+ activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
+ sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`.
+ Returns:
+ A `Tensor` of shape `[batch_size, k]`.
+ """
+ with ops.name_scope(
+ 'select_last_activations', values=[activations, sequence_lengths]):
+ activations_shape = array_ops.shape(activations)
+ batch_size = activations_shape[0]
+ padded_length = activations_shape[1]
+ num_label_columns = activations_shape[2]
+ if sequence_lengths is None:
+ sequence_lengths = padded_length
+ reshaped_activations = array_ops.reshape(activations,
+ [-1, num_label_columns])
+ indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
+ last_activations = array_ops.gather(reshaped_activations, indices)
+ last_activations.set_shape(
+ [activations.get_shape()[0], activations.get_shape()[2]])
+ return last_activations
+
+
+def mask_activations_and_labels(activations, labels, sequence_lengths):
+ """Remove entries outside `sequence_lengths` and returned flattened results.
+
+ Args:
+ activations: Output of the RNN, shape `[batch_size, padded_length, k]`.
+ labels: Label values, shape `[batch_size, padded_length]`.
+ sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded
+ length of each sequence. If `None`, then each sequence is unpadded.
+
+ Returns:
+ activations_masked: `logit` values with those beyond `sequence_lengths`
+ removed for each batch. Batches are then concatenated. Shape
+ `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
+ shape `[batch_size * padded_length, k]` otherwise.
+ labels_masked: Label values after removing unneeded entries. Shape
+ `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
+ `[batch_size * padded_length]` otherwise.
+ """
+ with ops.name_scope(
+ 'mask_activations_and_labels',
+ values=[activations, labels, sequence_lengths]):
+ labels_shape = array_ops.shape(labels)
+ batch_size = labels_shape[0]
+ padded_length = labels_shape[1]
+ if sequence_lengths is None:
+ flattened_dimension = padded_length * batch_size
+ activations_masked = array_ops.reshape(activations,
+ [flattened_dimension, -1])
+ labels_masked = array_ops.reshape(labels, [flattened_dimension])
+ else:
+ mask = array_ops.sequence_mask(sequence_lengths, padded_length)
+ activations_masked = array_ops.boolean_mask(activations, mask)
+ labels_masked = array_ops.boolean_mask(labels, mask)
+ return activations_masked, labels_masked
+
+
+def multi_value_predictions(activations, target_column, problem_type,
+ predict_probabilities):
+ """Maps `activations` from the RNN to predictions for multi value models.
+
+ If `predict_probabilities` is `False`, this function returns a `dict`
+ containing a single entry with key `prediction_key.PredictionKey.CLASSES` for
+ `problem_type` `ProblemType.CLASSIFICATION` or
+ `prediction_key.PredictionKey.SCORES` for `problem_type`
+ `ProblemType.LINEAR_REGRESSION`.
+
+ If `predict_probabilities` is `True`, it will contain a second entry with key
+ `prediction_key.PredictionKey.PROBABILITIES`. The
+ value of this entry is a `Tensor` of probabilities with shape
+ `[batch_size, padded_length, num_classes]`.
+
+ Note that variable length inputs will yield some predictions that don't have
+ meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]`
+ has no meaningful interpretation.
+
+ Args:
+ activations: Output from an RNN. Should have dtype `float32` and shape
+ `[batch_size, padded_length, ?]`.
+ target_column: An initialized `TargetColumn`, used to calculate predictions.
+ problem_type: Either `ProblemType.CLASSIFICATION` or
+ `ProblemType.LINEAR_REGRESSION`.
+ predict_probabilities: A Python boolean, indicating whether probabilities
+ should be returned. Should only be set to `True` for
+ classification/logistic regression problems.
+ Returns:
+ A `dict` mapping strings to `Tensors`.
+ """
+ with ops.name_scope('MultiValuePrediction'):
+ activations_shape = array_ops.shape(activations)
+ flattened_activations = array_ops.reshape(activations,
+ [-1, activations_shape[2]])
+ prediction_dict = {}
+ if predict_probabilities:
+ flat_probabilities = target_column.logits_to_predictions(
+ flattened_activations, proba=True)
+ flat_predictions = math_ops.argmax(flat_probabilities, 1)
+ if target_column.num_label_columns == 1:
+ probability_shape = array_ops.concat([activations_shape[:2], [2]], 0)
+ else:
+ probability_shape = activations_shape
+ probabilities = array_ops.reshape(
+ flat_probabilities,
+ probability_shape,
+ name=prediction_key.PredictionKey.PROBABILITIES)
+ prediction_dict[
+ prediction_key.PredictionKey.PROBABILITIES] = probabilities
+ else:
+ flat_predictions = target_column.logits_to_predictions(
+ flattened_activations, proba=False)
+ predictions_name = (prediction_key.PredictionKey.CLASSES
+ if problem_type == constants.ProblemType.CLASSIFICATION
+ else prediction_key.PredictionKey.SCORES)
+ predictions = array_ops.reshape(
+ flat_predictions, [activations_shape[0], activations_shape[1]],
+ name=predictions_name)
+ prediction_dict[predictions_name] = predictions
+ return prediction_dict
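
A sketch of the moved multi_value_predictions helper with a multi-class target
column; the target_column_lib import mirrors the one used in the estimator
tests, and the shapes are illustrative:

  import tensorflow as tf

  from tensorflow.contrib.layers.python.layers import target_column as target_column_lib
  from tensorflow.contrib.learn.python.learn.estimators import constants
  from tensorflow.contrib.learn.python.learn.estimators import rnn_common

  num_classes = 4
  # Padded RNN output: [batch_size, padded_length, num_classes].
  activations = tf.random_normal([2, 5, num_classes])
  target_column = target_column_lib.multi_class_target(n_classes=num_classes)

  prediction_dict = rnn_common.multi_value_predictions(
      activations, target_column, constants.ProblemType.CLASSIFICATION,
      predict_probabilities=True)
  # Expected keys: PredictionKey.CLASSES with shape [batch_size, padded_length]
  # and PredictionKey.PROBABILITIES with shape
  # [batch_size, padded_length, num_classes].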
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
new file mode 100644
index 0000000000..82563141cc
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
@@ -0,0 +1,116 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for layers.rnn_common."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.learn.python.learn.estimators import rnn_common
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
+class RnnCommonTest(test.TestCase):
+
+ def testMaskActivationsAndLabels(self):
+ """Test `mask_activations_and_labels`."""
+ batch_size = 4
+ padded_length = 6
+ num_classes = 4
+ np.random.seed(1234)
+ sequence_length = np.random.randint(0, padded_length + 1, batch_size)
+ activations = np.random.rand(batch_size, padded_length, num_classes)
+ labels = np.random.randint(0, num_classes, [batch_size, padded_length])
+ (activations_masked_t,
+ labels_masked_t) = rnn_common.mask_activations_and_labels(
+ constant_op.constant(activations, dtype=dtypes.float32),
+ constant_op.constant(labels, dtype=dtypes.int32),
+ constant_op.constant(sequence_length, dtype=dtypes.int32))
+
+ with self.test_session() as sess:
+ activations_masked, labels_masked = sess.run(
+ [activations_masked_t, labels_masked_t])
+
+ expected_activations_shape = [sum(sequence_length), num_classes]
+ np.testing.assert_equal(
+ expected_activations_shape, activations_masked.shape,
+ 'Wrong activations shape. Expected {}; got {}.'.format(
+ expected_activations_shape, activations_masked.shape))
+
+ expected_labels_shape = [sum(sequence_length)]
+ np.testing.assert_equal(expected_labels_shape, labels_masked.shape,
+ 'Wrong labels shape. Expected {}; got {}.'.format(
+ expected_labels_shape, labels_masked.shape))
+ masked_index = 0
+ for i in range(batch_size):
+ for j in range(sequence_length[i]):
+ actual_activations = activations_masked[masked_index]
+ expected_activations = activations[i, j, :]
+ np.testing.assert_almost_equal(
+ expected_activations,
+ actual_activations,
+ err_msg='Unexpected logit value at index [{}, {}, :].'
+ ' Expected {}; got {}.'.format(i, j, expected_activations,
+ actual_activations))
+
+ actual_labels = labels_masked[masked_index]
+ expected_labels = labels[i, j]
+ np.testing.assert_almost_equal(
+ expected_labels,
+ actual_labels,
+ err_msg='Unexpected logit value at index [{}, {}].'
+ ' Expected {}; got {}.'.format(i, j, expected_labels,
+ actual_labels))
+ masked_index += 1
+
+ def testSelectLastActivations(self):
+ """Test `select_last_activations`."""
+ batch_size = 4
+ padded_length = 6
+ num_classes = 4
+ np.random.seed(4444)
+ sequence_length = np.random.randint(0, padded_length + 1, batch_size)
+ activations = np.random.rand(batch_size, padded_length, num_classes)
+ last_activations_t = rnn_common.select_last_activations(
+ constant_op.constant(activations, dtype=dtypes.float32),
+ constant_op.constant(sequence_length, dtype=dtypes.int32))
+
+ with session.Session() as sess:
+ last_activations = sess.run(last_activations_t)
+
+ expected_activations_shape = [batch_size, num_classes]
+ np.testing.assert_equal(
+ expected_activations_shape, last_activations.shape,
+ 'Wrong activations shape. Expected {}; got {}.'.format(
+ expected_activations_shape, last_activations.shape))
+
+ for i in range(batch_size):
+ actual_activations = last_activations[i, :]
+ expected_activations = activations[i, sequence_length[i] - 1, :]
+ np.testing.assert_almost_equal(
+ expected_activations,
+ actual_activations,
+ err_msg='Unexpected logit value at index [{}, :].'
+ ' Expected {}; got {}.'.format(i, expected_activations,
+ actual_activations))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index f15c717c91..48ad279cc0 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -31,67 +31,17 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest
-# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been
-# removed and replaced with values from `prediction_key.PredictionKey`. The key
-# `RNNKeys.PREDICTIONS_KEY` has been replaced by
-# `prediction_key.PredictionKey.SCORES` for regression and
-# `prediction_key.PredictionKey.CLASSES` for classification. The key
-# `RNNKeys.PROBABILITIES_KEY` has been replaced by
-# `prediction_key.PredictionKey.PROBABILITIES`.
-class RNNKeys(object):
- FINAL_STATE_KEY = 'final_state'
- LABELS_KEY = '__labels__'
- STATE_PREFIX = 'rnn_cell_state'
-
-
-# TODO(b/34272579): mask_activations_and_labels is shared with
-# dynamic_rnn_estimator.py. Move it to a common library.
-def mask_activations_and_labels(activations, labels, sequence_lengths):
- """Remove entries outside `sequence_lengths` and returned flattened results.
-
- Args:
- activations: Output of the RNN, shape `[batch_size, padded_length, k]`.
- labels: Label values, shape `[batch_size, padded_length]`.
- sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded
- length of each sequence. If `None`, then each sequence is unpadded.
-
- Returns:
- activations_masked: `logit` values with those beyond `sequence_lengths`
- removed for each batch. Batches are then concatenated. Shape
- `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
- shape `[batch_size * padded_length, k]` otherwise.
- labels_masked: Label values after removing unneeded entries. Shape
- `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
- `[batch_size * padded_length]` otherwise.
- """
- with ops.name_scope('mask_activations_and_labels',
- values=[activations, labels, sequence_lengths]):
- labels_shape = array_ops.shape(labels)
- batch_size = labels_shape[0]
- padded_length = labels_shape[1]
- if sequence_lengths is None:
- flattened_dimension = padded_length * batch_size
- activations_masked = array_ops.reshape(activations,
- [flattened_dimension, -1])
- labels_masked = array_ops.reshape(labels, [flattened_dimension])
- else:
- mask = array_ops.sequence_mask(sequence_lengths, padded_length)
- activations_masked = array_ops.boolean_mask(activations, mask)
- labels_masked = array_ops.boolean_mask(labels, mask)
- return activations_masked, labels_masked
-
-
def construct_state_saving_rnn(cell,
inputs,
num_label_columns,
@@ -134,7 +84,8 @@ def construct_state_saving_rnn(cell,
activation_fn=None,
trainable=True)
# Use `identity` to rename `final_state`.
- final_state = array_ops.identity(final_state, name=RNNKeys.FINAL_STATE_KEY)
+ final_state = array_ops.identity(
+ final_state, name=rnn_common.RNNKeys.FINAL_STATE_KEY)
return activations, final_state
@@ -156,7 +107,7 @@ def _mask_multivalue(sequence_length, metric):
"""
@functools.wraps(metric)
def _metric(predictions, labels, *args, **kwargs):
- predictions, labels = mask_activations_and_labels(
+ predictions, labels = rnn_common.mask_activations_and_labels(
predictions, labels, sequence_length)
return metric(predictions, labels, *args, **kwargs)
return _metric
@@ -184,70 +135,6 @@ def _get_default_metrics(problem_type, sequence_length):
return default_metrics
-# TODO(b/34272579): _multi_value_predictions is shared with
-# dynamic_rnn_estimator.py. Move it to a common library.
-def _multi_value_predictions(
- activations, target_column, problem_type, predict_probabilities):
- """Maps `activations` from the RNN to predictions for multi value models.
-
- If `predict_probabilities` is `False`, this function returns a `dict`
- containing single entry with key `prediction_key.PredictionKey.CLASSES` for
- `problem_type` `ProblemType.CLASSIFICATION` or
- `prediction_key.PredictionKey.SCORE` for `problem_type`
- `ProblemType.LINEAR_REGRESSION`.
-
- If `predict_probabilities` is `True`, it will contain a second entry with key
- `prediction_key.PredictionKey.PROBABILITIES`. The
- value of this entry is a `Tensor` of probabilities with shape
- `[batch_size, padded_length, num_classes]`.
-
- Note that variable length inputs will yield some predictions that don't have
- meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]`
- has no meaningful interpretation.
-
- Args:
- activations: Output from an RNN. Should have dtype `float32` and shape
- `[batch_size, padded_length, ?]`.
- target_column: An initialized `TargetColumn`, calculate predictions.
- problem_type: Either `ProblemType.CLASSIFICATION` or
- `ProblemType.LINEAR_REGRESSION`.
- predict_probabilities: A Python boolean, indicating whether probabilities
- should be returned. Should only be set to `True` for
- classification/logistic regression problems.
- Returns:
- A `dict` mapping strings to `Tensors`.
- """
- with ops.name_scope('MultiValuePrediction'):
- activations_shape = array_ops.shape(activations)
- flattened_activations = array_ops.reshape(activations,
- [-1, activations_shape[2]])
- prediction_dict = {}
- if predict_probabilities:
- flat_probabilities = target_column.logits_to_predictions(
- flattened_activations, proba=True)
- flat_predictions = math_ops.argmax(flat_probabilities, 1)
- if target_column.num_label_columns == 1:
- probability_shape = array_ops.concat([activations_shape[:2], [2]], 0)
- else:
- probability_shape = activations_shape
- probabilities = array_ops.reshape(
- flat_probabilities, probability_shape,
- name=prediction_key.PredictionKey.PROBABILITIES)
- prediction_dict[
- prediction_key.PredictionKey.PROBABILITIES] = probabilities
- else:
- flat_predictions = target_column.logits_to_predictions(
- flattened_activations, proba=False)
- predictions_name = (prediction_key.PredictionKey.CLASSES
- if problem_type == constants.ProblemType.CLASSIFICATION
- else prediction_key.PredictionKey.SCORES)
- predictions = array_ops.reshape(
- flat_predictions, [activations_shape[0], activations_shape[1]],
- name=predictions_name)
- prediction_dict[predictions_name] = predictions
- return prediction_dict
-
-
def _multi_value_loss(
activations, labels, sequence_length, target_column, features):
"""Maps `activations` from the RNN to loss for multi value models.
@@ -266,7 +153,7 @@ def _multi_value_loss(
A scalar `Tensor` containing the loss.
"""
with ops.name_scope('MultiValueLoss'):
- activations_masked, labels_masked = mask_activations_and_labels(
+ activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
activations, labels, sequence_length)
return target_column.loss(activations_masked, labels_masked, features)
@@ -343,7 +230,7 @@ def _prepare_features_for_sqss(features, labels, mode,
# Add labels to the resulting sequence features dict.
if mode != model_fn.ModeKeys.INFER:
- sequence_features[RNNKeys.LABELS_KEY] = labels
+ sequence_features[rnn_common.RNNKeys.LABELS_KEY] = labels
return sequence_features, context_features
@@ -353,7 +240,7 @@ def _read_batch(cell,
labels,
mode,
num_unroll,
- num_layers,
+ num_rnn_layers,
batch_size,
sequence_feature_columns,
context_feature_columns=None,
@@ -373,7 +260,7 @@ def _read_batch(cell,
num_unroll: Python integer, how many time steps to unroll at a time.
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
- num_layers: Python integer, number of layers in the RNN.
+ num_rnn_layers: Python integer, number of layers in the RNN.
batch_size: Python integer, the size of the minibatch produced by the SQSS.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
@@ -399,8 +286,8 @@ def _read_batch(cell,
# Set up stateful queue reader.
states = {}
- state_names = _get_lstm_state_names(num_layers)
- for i in range(num_layers):
+ state_names = _get_lstm_state_names(num_rnn_layers)
+ for i in range(num_rnn_layers):
states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0)
states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0)
@@ -423,36 +310,9 @@ def _read_batch(cell,
capacity=queue_capacity)
-def apply_dropout(
- cell, input_keep_probability, output_keep_probability, random_seed=None):
- """Apply dropout to the outputs and inputs of `cell`.
-
- Args:
- cell: An `RNNCell`.
- input_keep_probability: Probability to keep inputs to `cell`. If `None`,
- no dropout is applied.
- output_keep_probability: Probability to keep outputs of `cell`. If `None`,
- no dropout is applied.
- random_seed: Seed for random dropout.
-
- Returns:
- An `RNNCell`, the result of applying the supplied dropouts to `cell`.
- """
- input_prob_none = input_keep_probability is None
- output_prob_none = output_keep_probability is None
- if input_prob_none and output_prob_none:
- return cell
- if input_prob_none:
- input_keep_probability = 1.0
- if output_prob_none:
- output_keep_probability = 1.0
- return rnn_cell.DropoutWrapper(
- cell, input_keep_probability, output_keep_probability, random_seed)
-
-
def _get_state_name(i):
"""Constructs the name string for state component `i`."""
- return '{}_{}'.format(RNNKeys.STATE_PREFIX, i)
+ return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
def state_tuple_to_dict(state):
@@ -531,12 +391,12 @@ def _prepare_inputs_for_rnn(sequence_features, context_features,
axis=1)
-def _get_rnn_model_fn(cell,
- target_column,
+def _get_rnn_model_fn(target_column,
problem_type,
optimizer,
num_unroll,
- num_layers,
+ num_units,
+ num_rnn_layers,
num_threads,
queue_capacity,
batch_size,
@@ -545,14 +405,12 @@ def _get_rnn_model_fn(cell,
predict_probabilities=False,
learning_rate=None,
gradient_clipping_norm=None,
- input_keep_probability=None,
- output_keep_probability=None,
+ dropout_keep_probabilities=None,
name='StateSavingRNNModel',
seed=None):
"""Creates a state saving RNN model function for an `Estimator`.
Args:
- cell: An initialized `RNNCell` to be used in the RNN.
target_column: An initialized `TargetColumn`, used to calculate prediction
and loss.
problem_type: `ProblemType.CLASSIFICATION` or
@@ -562,7 +420,8 @@ def _get_rnn_model_fn(cell,
num_unroll: Python integer, how many time steps to unroll at a time.
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
- num_layers: Python integer, number of layers in the RNN.
+ num_units: The number of units in the `RNNCell`.
+ num_rnn_layers: Python integer, number of layers in the RNN.
num_threads: The Python integer number of threads enqueuing input examples
into a queue.
queue_capacity: The max capacity of the queue in number of examples.
@@ -583,10 +442,8 @@ def _get_rnn_model_fn(cell,
learning_rate: Learning rate used for optimization. This argument has no
effect if `optimizer` is an instance of an `Optimizer`.
gradient_clipping_norm: A float. Gradients will be clipped to this value.
- input_keep_probability: Probability to keep inputs to `cell`. If `None`,
- no dropout is applied.
- output_keep_probability: Probability to keep outputs of `cell`. If `None`,
- no dropout is applied.
+ dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
+ If given a list, it must have length `num_rnn_layers + 1`.
name: A string that will be used to create a scope for the RNN.
seed: Fixes the random seed used for generating input keys by the SQSS.
@@ -618,19 +475,18 @@ def _get_rnn_model_fn(cell,
def _rnn_model_fn(features, labels, mode):
"""The model to be passed to an `Estimator`."""
with ops.name_scope(name):
- if mode == model_fn.ModeKeys.TRAIN:
- cell_for_mode = apply_dropout(
- cell, input_keep_probability, output_keep_probability)
- else:
- cell_for_mode = cell
+ dropout = (dropout_keep_probabilities
+ if mode == model_fn.ModeKeys.TRAIN
+ else None)
+ cell = lstm_cell(num_units, num_rnn_layers, dropout)
batch = _read_batch(
- cell=cell_for_mode,
+ cell=cell,
features=features,
labels=labels,
mode=mode,
num_unroll=num_unroll,
- num_layers=num_layers,
+ num_rnn_layers=num_rnn_layers,
batch_size=batch_size,
sequence_feature_columns=sequence_feature_columns,
context_feature_columns=context_feature_columns,
@@ -640,22 +496,20 @@ def _get_rnn_model_fn(cell,
sequence_features = batch.sequences
context_features = batch.context
if mode != model_fn.ModeKeys.INFER:
- labels = sequence_features.pop(RNNKeys.LABELS_KEY)
+ labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY)
inputs = _prepare_inputs_for_rnn(sequence_features, context_features,
sequence_feature_columns, num_unroll)
- state_name = _get_lstm_state_names(num_layers)
+ state_name = _get_lstm_state_names(num_rnn_layers)
rnn_activations, final_state = construct_state_saving_rnn(
- cell=cell_for_mode,
+ cell=cell,
inputs=inputs,
num_label_columns=target_column.num_label_columns,
state_saver=batch,
state_name=state_name)
loss = None # Created below for modes TRAIN and EVAL.
- prediction_dict = _multi_value_predictions(rnn_activations,
- target_column,
- problem_type,
- predict_probabilities)
+ prediction_dict = rnn_common.multi_value_predictions(
+ rnn_activations, target_column, problem_type, predict_probabilities)
if mode != model_fn.ModeKeys.INFER:
loss = _multi_value_loss(rnn_activations, labels, batch.length,
target_column, features)
@@ -686,35 +540,41 @@ def _get_rnn_model_fn(cell,
return _rnn_model_fn
-def _get_lstm_state_names(num_layers):
- """Returns a num_layers long list of lstm state name pairs.
+def _get_lstm_state_names(num_rnn_layers):
+ """Returns a num_rnn_layers long list of lstm state name pairs.
Args:
- num_layers: The number of layers in the RNN.
+ num_rnn_layers: The number of layers in the RNN.
Returns:
- A num_layers long list of lstm state name pairs of the form:
- ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_layers.
+ A num_rnn_layers long list of lstm state name pairs of the form:
+ ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_rnn_layers - 1.
"""
return [['lstm_state_c' + str(i), 'lstm_state_m' + str(i)]
- for i in range(num_layers)]
+ for i in range(num_rnn_layers)]
# TODO(jtbates): Allow users to specify cell types other than LSTM.
-def lstm_cell(num_units, num_layers):
- """Constructs a `MultiRNNCell` with num_layers `BasicLSTMCell`s.
+def lstm_cell(num_units, num_rnn_layers, dropout_keep_probabilities):
+ """Constructs a `MultiRNNCell` with num_rnn_layers `BasicLSTMCell`s.
Args:
num_units: The number of units in the `RNNCell`.
- num_layers: The number of layers in the RNN.
+ num_rnn_layers: The number of layers in the RNN.
+ dropout_keep_probabilities: a list whose elements are either floats in
+ `[0.0, 1.0]` or `None`. It must have length `num_rnn_layers + 1`.
Returns:
An initialized `MultiRNNCell`.
"""
- return rnn_cell.MultiRNNCell([
- rnn_cell.BasicLSTMCell(
- num_units=num_units, state_is_tuple=True) for _ in range(num_layers)
- ])
+
+ cells = [
+ rnn_cell.BasicLSTMCell(num_units=num_units, state_is_tuple=True)
+ for _ in range(num_rnn_layers)
+ ]
+ if dropout_keep_probabilities:
+ cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities)
+ return rnn_cell.MultiRNNCell(cells)
class StateSavingRnnEstimator(estimator.Estimator):
@@ -733,9 +593,7 @@ class StateSavingRnnEstimator(estimator.Estimator):
predict_probabilities=False,
momentum=None,
gradient_clipping_norm=5.0,
- # TODO(jtbates): Support lists of input_keep_probability.
- input_keep_probability=None,
- output_keep_probability=None,
+ dropout_keep_probabilities=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
@@ -773,10 +631,8 @@ class StateSavingRnnEstimator(estimator.Estimator):
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
- input_keep_probability: Probability to keep inputs to `cell`. If `None`,
- no dropout is applied.
- output_keep_probability: Probability to keep outputs of `cell`. If `None`,
- no dropout is applied.
+ dropout_keep_probabilities: a list of dropout keep probabilities or
+ `None`. If given a list, it must have length `num_rnn_layers + 1`.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -818,14 +674,13 @@ class StateSavingRnnEstimator(estimator.Estimator):
if optimizer_type == 'Momentum':
optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)
- cell = lstm_cell(num_units, num_rnn_layers)
rnn_model_fn = _get_rnn_model_fn(
- cell=cell,
target_column=target_column,
problem_type=problem_type,
optimizer=optimizer_type,
num_unroll=num_unroll,
- num_layers=num_rnn_layers,
+ num_units=num_units,
+ num_rnn_layers=num_rnn_layers,
num_threads=num_threads,
queue_capacity=queue_capacity,
batch_size=batch_size,
@@ -834,8 +689,7 @@ class StateSavingRnnEstimator(estimator.Estimator):
predict_probabilities=predict_probabilities,
learning_rate=learning_rate,
gradient_clipping_norm=gradient_clipping_norm,
- input_keep_probability=input_keep_probability,
- output_keep_probability=output_keep_probability,
+ dropout_keep_probabilities=dropout_keep_probabilities,
name=name,
seed=seed)
@@ -858,8 +712,7 @@ def multi_value_rnn_regressor(num_units,
learning_rate=0.1,
momentum=None,
gradient_clipping_norm=5.0,
- input_keep_probability=None,
- output_keep_probability=None,
+ dropout_keep_probabilities=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
@@ -891,10 +744,8 @@ def multi_value_rnn_regressor(num_units,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
- input_keep_probability: Probability to keep inputs to `cell`. If `None`,
- no dropout is applied.
- output_keep_probability: Probability to keep outputs of `cell`. If `None`,
- no dropout is applied.
+ dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
+ If given a list, it must have length `num_rnn_layers + 1`.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -926,8 +777,7 @@ def multi_value_rnn_regressor(num_units,
predict_probabilities=False,
momentum=momentum,
gradient_clipping_norm=gradient_clipping_norm,
- input_keep_probability=input_keep_probability,
- output_keep_probability=output_keep_probability,
+ dropout_keep_probabilities=dropout_keep_probabilities,
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn,
@@ -950,8 +800,7 @@ def multi_value_rnn_classifier(num_classes,
predict_probabilities=False,
momentum=None,
gradient_clipping_norm=5.0,
- input_keep_probability=None,
- output_keep_probability=None,
+ dropout_keep_probabilities=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
@@ -985,10 +834,8 @@ def multi_value_rnn_classifier(num_classes,
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
- input_keep_probability: Probability to keep inputs to `cell`. If `None`,
- no dropout is applied.
- output_keep_probability: Probability to keep outputs of `cell`. If `None`,
- no dropout is applied.
+ dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
+ If given a list, it must have length `num_rnn_layers + 1`.
model_dir: The directory in which to save and restore the model graph,
parameters, etc.
config: A `RunConfig` instance.
@@ -1020,8 +867,7 @@ def multi_value_rnn_classifier(num_classes,
predict_probabilities=predict_probabilities,
momentum=momentum,
gradient_clipping_norm=gradient_clipping_norm,
- input_keep_probability=input_keep_probability,
- output_keep_probability=output_keep_probability,
+ dropout_keep_probabilities=dropout_keep_probabilities,
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn,
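
With the constructor change above, dropout for the state saving estimator is
configured through a single list rather than separate input/output keep
probabilities. A construction sketch mirroring the learning test below (batch
size and queue capacity are illustrative):

  from tensorflow.contrib import layers
  from tensorflow.contrib.learn.python.learn.estimators import constants
  from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre

  num_units, num_rnn_layers = 4, 1
  seq_columns = [layers.real_valued_column('inputs', dimension=num_units)]
  estimator = ssre.StateSavingRnnEstimator(
      constants.ProblemType.LINEAR_REGRESSION,
      num_units=num_units,
      num_rnn_layers=num_rnn_layers,
      num_unroll=2,
      batch_size=8,
      queue_capacity=16,
      sequence_feature_columns=seq_columns,
      learning_rate=0.3,
      # One keep probability per layer input plus one for the final output.
      dropout_keep_probabilities=[0.9] * (num_rnn_layers + 1))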
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 4ad6c01fee..05fc6a89a4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -34,6 +34,7 @@ from tensorflow.contrib.layers.python.layers import target_column as target_colu
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre
from tensorflow.python.framework import constant_op
@@ -288,7 +289,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
]
expected_sequence = {
- ssre.RNNKeys.LABELS_KEY:
+ rnn_common.RNNKeys.LABELS_KEY:
np.array([5., 5., 5., 5.]),
seq_feature_name:
np.array([1., 1., 1., 1.]),
@@ -327,78 +328,24 @@ class StateSavingRnnEstimatorTest(test.TestCase):
assert_equal(expected_sequence, actual_sequence)
assert_equal(expected_context, actual_context)
- def testMaskActivationsAndLabels(self):
- """Test `mask_activations_and_labels`."""
- batch_size = 4
- padded_length = 6
- num_classes = 4
- np.random.seed(1234)
- sequence_length = np.random.randint(0, padded_length + 1, batch_size)
- activations = np.random.rand(batch_size, padded_length, num_classes)
- labels = np.random.randint(0, num_classes, [batch_size, padded_length])
- (activations_masked_t, labels_masked_t) = ssre.mask_activations_and_labels(
- constant_op.constant(
- activations, dtype=dtypes.float32),
- constant_op.constant(
- labels, dtype=dtypes.int32),
- constant_op.constant(
- sequence_length, dtype=dtypes.int32))
-
- with self.test_session() as sess:
- activations_masked, labels_masked = sess.run(
- [activations_masked_t, labels_masked_t])
-
- expected_activations_shape = [sum(sequence_length), num_classes]
- np.testing.assert_equal(
- expected_activations_shape, activations_masked.shape,
- 'Wrong activations shape. Expected {}; got {}.'.format(
- expected_activations_shape, activations_masked.shape))
-
- expected_labels_shape = [sum(sequence_length)]
- np.testing.assert_equal(expected_labels_shape, labels_masked.shape,
- 'Wrong labels shape. Expected {}; got {}.'.format(
- expected_labels_shape, labels_masked.shape))
- masked_index = 0
- for i in range(batch_size):
- for j in range(sequence_length[i]):
- actual_activations = activations_masked[masked_index]
- expected_activations = activations[i, j, :]
- np.testing.assert_almost_equal(
- expected_activations,
- actual_activations,
- err_msg='Unexpected logit value at index [{}, {}, :].'
- ' Expected {}; got {}.'.format(i, j, expected_activations,
- actual_activations))
-
- actual_labels = labels_masked[masked_index]
- expected_labels = labels[i, j]
- np.testing.assert_almost_equal(
- expected_labels,
- actual_labels,
- err_msg='Unexpected logit value at index [{}, {}].'
- ' Expected {}; got {}.'.format(i, j, expected_labels,
- actual_labels))
- masked_index += 1
-
def _getModelFnOpsForMode(self, mode):
"""Helper for testGetRnnModelFn{Train,Eval,Infer}()."""
- cell_size = 4
- num_layers = 1
- cell = ssre.lstm_cell(cell_size, num_layers)
+ num_units = 4
+ num_rnn_layers = 1
seq_columns = [
feature_column.real_valued_column(
- 'inputs', dimension=cell_size)
+ 'inputs', dimension=num_units)
]
features = {
'inputs': constant_op.constant([1., 2., 3.]),
}
labels = constant_op.constant([1., 0., 1.])
model_fn = ssre._get_rnn_model_fn(
- cell=cell,
target_column=target_column_lib.multi_class_target(n_classes=2),
optimizer='SGD',
num_unroll=2,
- num_layers=num_layers,
+ num_units=num_units,
+ num_rnn_layers=num_rnn_layers,
num_threads=1,
queue_capacity=10,
batch_size=1,
@@ -578,6 +525,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
train_steps = 250
eval_steps = 20
num_units = 4
+ num_rnn_layers = 1
learning_rate = 0.3
loss_threshold = 0.035
@@ -600,15 +548,16 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
'inputs', dimension=num_units)
]
config = run_config.RunConfig(tf_random_seed=1234)
+ dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1)
sequence_estimator = ssre.StateSavingRnnEstimator(
constants.ProblemType.LINEAR_REGRESSION,
num_units=num_units,
+ num_rnn_layers=num_rnn_layers,
num_unroll=num_unroll,
batch_size=batch_size,
sequence_feature_columns=seq_columns,
learning_rate=learning_rate,
- input_keep_probability=0.9,
- output_keep_probability=0.9,
+ dropout_keep_probabilities=dropout_keep_probabilities,
config=config,
queue_capacity=2 * batch_size,
seed=1234)