From f772b5dc5c047660a54c9f6bf54db13c7f72c012 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 24 Feb 2017 11:12:22 -0800 Subject: Moved common code out of RNN Estimators. Change: 148480357 --- tensorflow/contrib/learn/BUILD | 26 ++ .../learn/estimators/dynamic_rnn_estimator.py | 231 +++-------------- .../learn/estimators/dynamic_rnn_estimator_test.py | 92 +------ .../learn/python/learn/estimators/rnn_common.py | 199 +++++++++++++++ .../python/learn/estimators/rnn_common_test.py | 116 +++++++++ .../learn/estimators/state_saving_rnn_estimator.py | 280 +++++---------------- .../estimators/state_saving_rnn_estimator_test.py | 73 +----- 7 files changed, 453 insertions(+), 564 deletions(-) create mode 100644 tensorflow/contrib/learn/python/learn/estimators/rnn_common.py create mode 100644 tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 00be830686..188e29d500 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -76,6 +76,16 @@ py_library( ], ) +# Exposes constants without having to build the entire :learn target. +py_library( + name = "estimator_constants_py", + srcs = [ + "python/learn/estimators/constants.py", + "python/learn/estimators/prediction_key.py", + ], + srcs_version = "PY2AND3", +) + py_test( name = "data_feeder_test", size = "small", @@ -864,6 +874,22 @@ py_test( ], ) +py_test( + name = "rnn_common_test", + size = "medium", + srcs = ["python/learn/estimators/rnn_common_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 6b5b9a6dd9..dcb9048633 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -27,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.estimators import rnn_common from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -42,90 +43,14 @@ class PredictionType(object): MULTIPLE_VALUE = 2 -# NOTE(jamieas): As of February 7, 2017, some of the `RNNKeys` have been removed -# and replaced with values from `prediction_key.PredictionKey`. The key -# `RNNKeys.PREDICTIONS_KEY` has been replaced by -# `prediction_key.PredictionKey.SCORES` for regression and -# `prediction_key.PredictionKey.CLASSES` for classification. The key -# `RNNKeys.PROBABILITIES_KEY` has been replaced by -# `prediction_key.PredictionKey.PROBABILITIES`. -class RNNKeys(object): - SEQUENCE_LENGTH_KEY = 'sequence_length' - STATE_PREFIX = 'rnn_cell_state' - _CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell, 'lstm': contrib_rnn.LSTMCell, 'gru': contrib_rnn.GRUCell,} -def mask_activations_and_labels(activations, labels, sequence_lengths): - """Remove entries outside `sequence_lengths` and returned flattened results. - - Args: - activations: Output of the RNN, shape `[batch_size, padded_length, k]`. - labels: Label values, shape `[batch_size, padded_length]`. - sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded - length of each sequence. If `None`, then each sequence is unpadded. - - Returns: - activations_masked: `logit` values with those beyond `sequence_lengths` - removed for each batch. Batches are then concatenated. Shape - `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and - shape `[batch_size * padded_length, k]` otherwise. - labels_masked: Label values after removing unneeded entries. Shape - `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape - `[batch_size * padded_length]` otherwise. - """ - with ops.name_scope('mask_activations_and_labels', - values=[activations, labels, sequence_lengths]): - labels_shape = array_ops.shape(labels) - batch_size = labels_shape[0] - padded_length = labels_shape[1] - if sequence_lengths is None: - flattened_dimension = padded_length * batch_size - activations_masked = array_ops.reshape(activations, - [flattened_dimension, -1]) - labels_masked = array_ops.reshape(labels, [flattened_dimension]) - else: - mask = array_ops.sequence_mask(sequence_lengths, padded_length) - activations_masked = array_ops.boolean_mask(activations, mask) - labels_masked = array_ops.boolean_mask(labels, mask) - return activations_masked, labels_masked - - -def select_last_activations(activations, sequence_lengths): - """Selects the nth set of activations for each n in `sequence_length`. - - Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not - `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If - `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. - - Args: - activations: A `Tensor` with shape `[batch_size, padded_length, k]`. - sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. - Returns: - A `Tensor` of shape `[batch_size, k]`. - """ - with ops.name_scope('select_last_activations', - values=[activations, sequence_lengths]): - activations_shape = array_ops.shape(activations) - batch_size = activations_shape[0] - padded_length = activations_shape[1] - num_label_columns = activations_shape[2] - if sequence_lengths is None: - sequence_lengths = padded_length - reshaped_activations = array_ops.reshape(activations, - [-1, num_label_columns]) - indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1 - last_activations = array_ops.gather(reshaped_activations, indices) - last_activations.set_shape( - [activations.get_shape()[0], activations.get_shape()[2]]) - return last_activations - - def _get_state_name(i): """Constructs the name string for state component `i`.""" - return '{}_{}'.format(RNNKeys.STATE_PREFIX, i) + return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i) def state_tuple_to_dict(state): @@ -351,13 +276,11 @@ def _get_eval_metric_ops(problem_type, prediction_type, sequence_length, if problem_type == constants.ProblemType.CLASSIFICATION: # Multi value classification if prediction_type == PredictionType.MULTIPLE_VALUE: - masked_predictions, masked_labels = mask_activations_and_labels( - prediction_dict[prediction_key.PredictionKey.CLASSES], - labels, + mask_predictions, mask_labels = rnn_common.mask_activations_and_labels( + prediction_dict[prediction_key.PredictionKey.CLASSES], labels, sequence_length) eval_metric_ops['accuracy'] = metrics.streaming_accuracy( - predictions=masked_predictions, - labels=masked_labels) + predictions=mask_predictions, labels=mask_labels) # Single value classification elif prediction_type == PredictionType.SINGLE_VALUE: eval_metric_ops['accuracy'] = metrics.streaming_accuracy( @@ -373,64 +296,6 @@ def _get_eval_metric_ops(problem_type, prediction_type, sequence_length, return eval_metric_ops -def _multi_value_predictions( - activations, target_column, problem_type, predict_probabilities): - """Maps `activations` from the RNN to predictions for multi value models. - - If `predict_probabilities` is `False`, this function returns a `dict` - containing single entry with key `PREDICTIONS_KEY`. If `predict_probabilities` - is `True`, it will contain a second entry with key `PROBABILITIES_KEY`. The - value of this entry is a `Tensor` of probabilities with shape - `[batch_size, padded_length, num_classes]`. - - Note that variable length inputs will yield some predictions that don't have - meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]` - has no meaningful interpretation. - - Args: - activations: Output from an RNN. Should have dtype `float32` and shape - `[batch_size, padded_length, ?]`. - target_column: An initialized `TargetColumn`, calculate predictions. - problem_type: Either `ProblemType.CLASSIFICATION` or - `ProblemType.LINEAR_REGRESSION`. - predict_probabilities: A Python boolean, indicating whether probabilities - should be returned. Should only be set to `True` for - classification/logistic regression problems. - Returns: - A `dict` mapping strings to `Tensors`. - """ - with ops.name_scope('MultiValuePrediction'): - activations_shape = array_ops.shape(activations) - flattened_activations = array_ops.reshape(activations, - [-1, activations_shape[2]]) - prediction_dict = {} - if predict_probabilities: - flat_probabilities = target_column.logits_to_predictions( - flattened_activations, proba=True) - flat_predictions = math_ops.argmax(flat_probabilities, 1) - if target_column.num_label_columns == 1: - probability_shape = array_ops.concat([activations_shape[:2], [2]], 0) - else: - probability_shape = activations_shape - probabilities = array_ops.reshape( - flat_probabilities, - probability_shape, - name=prediction_key.PredictionKey.PROBABILITIES) - prediction_dict[ - prediction_key.PredictionKey.PROBABILITIES] = probabilities - else: - flat_predictions = target_column.logits_to_predictions( - flattened_activations, proba=False) - predictions_name = (prediction_key.PredictionKey.CLASSES - if problem_type == constants.ProblemType.CLASSIFICATION - else prediction_key.PredictionKey.SCORES) - predictions = array_ops.reshape( - flat_predictions, [activations_shape[0], activations_shape[1]], - name=predictions_name) - prediction_dict[predictions_name] = predictions - return prediction_dict - - def _single_value_predictions(activations, sequence_length, target_column, @@ -460,7 +325,8 @@ def _single_value_predictions(activations, A `dict` mapping strings to `Tensors`. """ with ops.name_scope('SingleValuePrediction'): - last_activations = select_last_activations(activations, sequence_length) + last_activations = rnn_common.select_last_activations( + activations, sequence_length) predictions_name = (prediction_key.PredictionKey.CLASSES if problem_type == constants.ProblemType.CLASSIFICATION else prediction_key.PredictionKey.SCORES) @@ -495,7 +361,7 @@ def _multi_value_loss( A scalar `Tensor` containing the loss. """ with ops.name_scope('MultiValueLoss'): - activations_masked, labels_masked = mask_activations_and_labels( + activations_masked, labels_masked = rnn_common.mask_activations_and_labels( activations, labels, sequence_length) return target_column.loss(activations_masked, labels_masked, features) @@ -519,7 +385,8 @@ def _single_value_loss( """ with ops.name_scope('SingleValueLoss'): - last_activations = select_last_activations(activations, sequence_length) + last_activations = rnn_common.select_last_activations( + activations, sequence_length) return target_column.loss(last_activations, labels, features) @@ -544,29 +411,33 @@ def _get_output_alternatives(prediction_type, if prediction_type == PredictionType.MULTIPLE_VALUE: return None if prediction_type == PredictionType.SINGLE_VALUE: - prediction_dict_no_state = {k: v for k, v in prediction_dict.items() - if RNNKeys.STATE_PREFIX not in k} + prediction_dict_no_state = { + k: v + for k, v in prediction_dict.items() + if rnn_common.RNNKeys.STATE_PREFIX not in k + } return {'dynamic_rnn_output': (problem_type, prediction_dict_no_state)} raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type)) -def _get_dynamic_rnn_model_fn(cell_type, - num_units, - target_column, - problem_type, - prediction_type, - optimizer, - sequence_feature_columns, - context_feature_columns=None, - predict_probabilities=False, - learning_rate=None, - gradient_clipping_norm=None, - dropout_keep_probabilities=None, - sequence_length_key=RNNKeys.SEQUENCE_LENGTH_KEY, - dtype=dtypes.float32, - parallel_iterations=None, - swap_memory=True, - name='DynamicRNNModel'): +def _get_dynamic_rnn_model_fn( + cell_type, + num_units, + target_column, + problem_type, + prediction_type, + optimizer, + sequence_feature_columns, + context_feature_columns=None, + predict_probabilities=False, + learning_rate=None, + gradient_clipping_norm=None, + dropout_keep_probabilities=None, + sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY, + dtype=dtypes.float32, + parallel_iterations=None, + swap_memory=True, + name='DynamicRNNModel'): """Creates an RNN model function for an `Estimator`. The model function returns an instance of `ModelFnOps`. When @@ -667,7 +538,7 @@ def _get_dynamic_rnn_model_fn(cell_type, loss = None # Created below for modes TRAIN and EVAL. if prediction_type == PredictionType.MULTIPLE_VALUE: - prediction_dict = _multi_value_predictions( + prediction_dict = rnn_common.multi_value_predictions( rnn_activations, target_column, problem_type, predict_probabilities) if mode != model_fn.ModeKeys.INFER: loss = _multi_value_loss( @@ -711,38 +582,6 @@ def _get_dynamic_rnn_model_fn(cell_type, return _dynamic_rnn_model_fn -def _apply_dropout( - cells, dropout_keep_probabilities, random_seed=None): - """Applies dropout to the outputs and inputs of `cell`. - - Args: - cells: A list of `RNNCell`s. - dropout_keep_probabilities: a list whose elements are either floats in - `[0.0, 1.0]` or `None`. It must have length one greater than `cells`. - random_seed: Seed for random dropout. - - Returns: - A list of `RNNCell`s, the result of applying the supplied dropouts. - - Raises: - ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1`. - """ - if len(dropout_keep_probabilities) != len(cells) + 1: - raise ValueError( - 'The number of dropout probabilites must be one greater than the ' - 'number of cells. Got {} cells and {} dropout probabilities.'.format( - len(cells), len(dropout_keep_probabilities))) - wrapped_cells = [ - contrib_rnn.DropoutWrapper(cell, prob, 1.0, random_seed) - for cell, prob in zip(cells[:-1], dropout_keep_probabilities[:-2]) - ] - wrapped_cells.append(contrib_rnn.DropoutWrapper( - cells[-1], - dropout_keep_probabilities[-2], - dropout_keep_probabilities[-1])) - return wrapped_cells - - def _get_single_cell(cell_type, num_units): """Constructs and return an single `RNNCell`. @@ -789,7 +628,7 @@ def _construct_rnn_cell(cell_type, num_units, dropout_keep_probabilities): cells = [_get_single_cell(cell_type, n) for n in num_units] if dropout_keep_probabilities: - cells = _apply_dropout(cells, dropout_keep_probabilities) + cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities) if len(cells) == 1: return cells[0] return contrib_rnn.MultiRNNCell(cells) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index d7d606deed..443d336214 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.estimators import rnn_common from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl from tensorflow.python.client import session @@ -191,93 +192,6 @@ class DynamicRnnEstimatorTest(test.TestCase): expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS]) self.assertAllEqual(expected_state_shape, final_state.shape) - def testMaskActivationsAndLabels(self): - """Test `mask_activations_and_labels`.""" - batch_size = 4 - padded_length = 6 - num_classes = 4 - np.random.seed(1234) - sequence_length = np.random.randint(0, padded_length + 1, batch_size) - activations = np.random.rand(batch_size, padded_length, num_classes) - labels = np.random.randint(0, num_classes, [batch_size, padded_length]) - (activations_masked_t, - labels_masked_t) = dynamic_rnn_estimator.mask_activations_and_labels( - constant_op.constant( - activations, dtype=dtypes.float32), - constant_op.constant( - labels, dtype=dtypes.int32), - constant_op.constant( - sequence_length, dtype=dtypes.int32)) - - with session.Session() as sess: - activations_masked, labels_masked = sess.run( - [activations_masked_t, labels_masked_t]) - - expected_activations_shape = [sum(sequence_length), num_classes] - np.testing.assert_equal( - expected_activations_shape, activations_masked.shape, - 'Wrong activations shape. Expected {}; got {}.'.format( - expected_activations_shape, activations_masked.shape)) - - expected_labels_shape = [sum(sequence_length)] - np.testing.assert_equal(expected_labels_shape, labels_masked.shape, - 'Wrong labels shape. Expected {}; got {}.'.format( - expected_labels_shape, labels_masked.shape)) - masked_index = 0 - for i in range(batch_size): - for j in range(sequence_length[i]): - actual_activations = activations_masked[masked_index] - expected_activations = activations[i, j, :] - np.testing.assert_almost_equal( - expected_activations, - actual_activations, - err_msg='Unexpected logit value at index [{}, {}, :].' - ' Expected {}; got {}.'.format(i, j, expected_activations, - actual_activations)) - - actual_labels = labels_masked[masked_index] - expected_labels = labels[i, j] - np.testing.assert_almost_equal( - expected_labels, - actual_labels, - err_msg='Unexpected logit value at index [{}, {}].' - ' Expected {}; got {}.'.format(i, j, expected_labels, - actual_labels)) - masked_index += 1 - - def testSelectLastActivations(self): - """Test `select_last_activations`.""" - batch_size = 4 - padded_length = 6 - num_classes = 4 - np.random.seed(4444) - sequence_length = np.random.randint(0, padded_length + 1, batch_size) - activations = np.random.rand(batch_size, padded_length, num_classes) - last_activations_t = dynamic_rnn_estimator.select_last_activations( - constant_op.constant( - activations, dtype=dtypes.float32), - constant_op.constant( - sequence_length, dtype=dtypes.int32)) - - with session.Session() as sess: - last_activations = sess.run(last_activations_t) - - expected_activations_shape = [batch_size, num_classes] - np.testing.assert_equal( - expected_activations_shape, last_activations.shape, - 'Wrong activations shape. Expected {}; got {}.'.format( - expected_activations_shape, last_activations.shape)) - - for i in range(batch_size): - actual_activations = last_activations[i, :] - expected_activations = activations[i, sequence_length[i] - 1, :] - np.testing.assert_almost_equal( - expected_activations, - actual_activations, - err_msg='Unexpected logit value at index [{}, :].' - ' Expected {}; got {}.'.format(i, expected_activations, - actual_activations)) - def testGetOutputAlternatives(self): test_cases = ( (dynamic_rnn_estimator.PredictionType.SINGLE_VALUE, @@ -618,7 +532,7 @@ class DynamicRnnEstimatorTest(test.TestCase): incremental_state_dict = { k: v for (k, v) in prediction_dict.items() - if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX) + if k.startswith(rnn_common.RNNKeys.STATE_PREFIX) } return prediction_dict @@ -636,7 +550,7 @@ class DynamicRnnEstimatorTest(test.TestCase): err_msg='Mismatch on last {} predictions.'.format(prediction_steps[-1])) # Check that final states are identical. for k, v in pred_all_at_once.items(): - if k.startswith(dynamic_rnn_estimator.RNNKeys.STATE_PREFIX): + if k.startswith(rnn_common.RNNKeys.STATE_PREFIX): np.testing.assert_array_equal( v, pred_step_by_step[k], err_msg='Mismatch on state {}.'.format(k)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py new file mode 100644 index 0000000000..b3aea01ffa --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -0,0 +1,199 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common operations for RNN Estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import rnn as contrib_rnn +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been +# removed and replaced with values from `prediction_key.PredictionKey`. The key +# `RNNKeys.PREDICTIONS_KEY` has been replaced by +# `prediction_key.PredictionKey.SCORES` for regression and +# `prediction_key.PredictionKey.CLASSES` for classification. The key +# `RNNKeys.PROBABILITIES_KEY` has been replaced by +# `prediction_key.PredictionKey.PROBABILITIES`. +class RNNKeys(object): + FINAL_STATE_KEY = 'final_state' + LABELS_KEY = '__labels__' + SEQUENCE_LENGTH_KEY = 'sequence_length' + STATE_PREFIX = 'rnn_cell_state' + + +def apply_dropout(cells, dropout_keep_probabilities, random_seed=None): + """Applies dropout to the outputs and inputs of `cell`. + + Args: + cells: A list of `RNNCell`s. + dropout_keep_probabilities: a list whose elements are either floats in + `[0.0, 1.0]` or `None`. It must have length one greater than `cells`. + random_seed: Seed for random dropout. + + Returns: + A list of `RNNCell`s, the result of applying the supplied dropouts. + + Raises: + ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1`. + """ + if len(dropout_keep_probabilities) != len(cells) + 1: + raise ValueError( + 'The number of dropout probabilites must be one greater than the ' + 'number of cells. Got {} cells and {} dropout probabilities.'.format( + len(cells), len(dropout_keep_probabilities))) + wrapped_cells = [ + contrib_rnn.DropoutWrapper(cell, prob, 1.0, random_seed) + for cell, prob in zip(cells[:-1], dropout_keep_probabilities[:-2]) + ] + wrapped_cells.append( + contrib_rnn.DropoutWrapper(cells[-1], dropout_keep_probabilities[-2], + dropout_keep_probabilities[-1])) + return wrapped_cells + + +def select_last_activations(activations, sequence_lengths): + """Selects the nth set of activations for each n in `sequence_length`. + + Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not + `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If + `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. + + Args: + activations: A `Tensor` with shape `[batch_size, padded_length, k]`. + sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. + Returns: + A `Tensor` of shape `[batch_size, k]`. + """ + with ops.name_scope( + 'select_last_activations', values=[activations, sequence_lengths]): + activations_shape = array_ops.shape(activations) + batch_size = activations_shape[0] + padded_length = activations_shape[1] + num_label_columns = activations_shape[2] + if sequence_lengths is None: + sequence_lengths = padded_length + reshaped_activations = array_ops.reshape(activations, + [-1, num_label_columns]) + indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1 + last_activations = array_ops.gather(reshaped_activations, indices) + last_activations.set_shape( + [activations.get_shape()[0], activations.get_shape()[2]]) + return last_activations + + +def mask_activations_and_labels(activations, labels, sequence_lengths): + """Remove entries outside `sequence_lengths` and returned flattened results. + + Args: + activations: Output of the RNN, shape `[batch_size, padded_length, k]`. + labels: Label values, shape `[batch_size, padded_length]`. + sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded + length of each sequence. If `None`, then each sequence is unpadded. + + Returns: + activations_masked: `logit` values with those beyond `sequence_lengths` + removed for each batch. Batches are then concatenated. Shape + `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and + shape `[batch_size * padded_length, k]` otherwise. + labels_masked: Label values after removing unneeded entries. Shape + `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape + `[batch_size * padded_length]` otherwise. + """ + with ops.name_scope( + 'mask_activations_and_labels', + values=[activations, labels, sequence_lengths]): + labels_shape = array_ops.shape(labels) + batch_size = labels_shape[0] + padded_length = labels_shape[1] + if sequence_lengths is None: + flattened_dimension = padded_length * batch_size + activations_masked = array_ops.reshape(activations, + [flattened_dimension, -1]) + labels_masked = array_ops.reshape(labels, [flattened_dimension]) + else: + mask = array_ops.sequence_mask(sequence_lengths, padded_length) + activations_masked = array_ops.boolean_mask(activations, mask) + labels_masked = array_ops.boolean_mask(labels, mask) + return activations_masked, labels_masked + + +def multi_value_predictions(activations, target_column, problem_type, + predict_probabilities): + """Maps `activations` from the RNN to predictions for multi value models. + + If `predict_probabilities` is `False`, this function returns a `dict` + containing single entry with key `prediction_key.PredictionKey.CLASSES` for + `problem_type` `ProblemType.CLASSIFICATION` or + `prediction_key.PredictionKey.SCORE` for `problem_type` + `ProblemType.LINEAR_REGRESSION`. + + If `predict_probabilities` is `True`, it will contain a second entry with key + `prediction_key.PredictionKey.PROBABILITIES`. The + value of this entry is a `Tensor` of probabilities with shape + `[batch_size, padded_length, num_classes]`. + + Note that variable length inputs will yield some predictions that don't have + meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]` + has no meaningful interpretation. + + Args: + activations: Output from an RNN. Should have dtype `float32` and shape + `[batch_size, padded_length, ?]`. + target_column: An initialized `TargetColumn`, calculate predictions. + problem_type: Either `ProblemType.CLASSIFICATION` or + `ProblemType.LINEAR_REGRESSION`. + predict_probabilities: A Python boolean, indicating whether probabilities + should be returned. Should only be set to `True` for + classification/logistic regression problems. + Returns: + A `dict` mapping strings to `Tensors`. + """ + with ops.name_scope('MultiValuePrediction'): + activations_shape = array_ops.shape(activations) + flattened_activations = array_ops.reshape(activations, + [-1, activations_shape[2]]) + prediction_dict = {} + if predict_probabilities: + flat_probabilities = target_column.logits_to_predictions( + flattened_activations, proba=True) + flat_predictions = math_ops.argmax(flat_probabilities, 1) + if target_column.num_label_columns == 1: + probability_shape = array_ops.concat([activations_shape[:2], [2]], 0) + else: + probability_shape = activations_shape + probabilities = array_ops.reshape( + flat_probabilities, + probability_shape, + name=prediction_key.PredictionKey.PROBABILITIES) + prediction_dict[ + prediction_key.PredictionKey.PROBABILITIES] = probabilities + else: + flat_predictions = target_column.logits_to_predictions( + flattened_activations, proba=False) + predictions_name = (prediction_key.PredictionKey.CLASSES + if problem_type == constants.ProblemType.CLASSIFICATION + else prediction_key.PredictionKey.SCORES) + predictions = array_ops.reshape( + flat_predictions, [activations_shape[0], activations_shape[1]], + name=predictions_name) + prediction_dict[predictions_name] = predictions + return prediction_dict diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py new file mode 100644 index 0000000000..82563141cc --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py @@ -0,0 +1,116 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for layers.rnn_common.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.learn.python.learn.estimators import rnn_common +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class RnnCommonTest(test.TestCase): + + def testMaskActivationsAndLabels(self): + """Test `mask_activations_and_labels`.""" + batch_size = 4 + padded_length = 6 + num_classes = 4 + np.random.seed(1234) + sequence_length = np.random.randint(0, padded_length + 1, batch_size) + activations = np.random.rand(batch_size, padded_length, num_classes) + labels = np.random.randint(0, num_classes, [batch_size, padded_length]) + (activations_masked_t, + labels_masked_t) = rnn_common.mask_activations_and_labels( + constant_op.constant(activations, dtype=dtypes.float32), + constant_op.constant(labels, dtype=dtypes.int32), + constant_op.constant(sequence_length, dtype=dtypes.int32)) + + with self.test_session() as sess: + activations_masked, labels_masked = sess.run( + [activations_masked_t, labels_masked_t]) + + expected_activations_shape = [sum(sequence_length), num_classes] + np.testing.assert_equal( + expected_activations_shape, activations_masked.shape, + 'Wrong activations shape. Expected {}; got {}.'.format( + expected_activations_shape, activations_masked.shape)) + + expected_labels_shape = [sum(sequence_length)] + np.testing.assert_equal(expected_labels_shape, labels_masked.shape, + 'Wrong labels shape. Expected {}; got {}.'.format( + expected_labels_shape, labels_masked.shape)) + masked_index = 0 + for i in range(batch_size): + for j in range(sequence_length[i]): + actual_activations = activations_masked[masked_index] + expected_activations = activations[i, j, :] + np.testing.assert_almost_equal( + expected_activations, + actual_activations, + err_msg='Unexpected logit value at index [{}, {}, :].' + ' Expected {}; got {}.'.format(i, j, expected_activations, + actual_activations)) + + actual_labels = labels_masked[masked_index] + expected_labels = labels[i, j] + np.testing.assert_almost_equal( + expected_labels, + actual_labels, + err_msg='Unexpected logit value at index [{}, {}].' + ' Expected {}; got {}.'.format(i, j, expected_labels, + actual_labels)) + masked_index += 1 + + def testSelectLastActivations(self): + """Test `select_last_activations`.""" + batch_size = 4 + padded_length = 6 + num_classes = 4 + np.random.seed(4444) + sequence_length = np.random.randint(0, padded_length + 1, batch_size) + activations = np.random.rand(batch_size, padded_length, num_classes) + last_activations_t = rnn_common.select_last_activations( + constant_op.constant(activations, dtype=dtypes.float32), + constant_op.constant(sequence_length, dtype=dtypes.int32)) + + with session.Session() as sess: + last_activations = sess.run(last_activations_t) + + expected_activations_shape = [batch_size, num_classes] + np.testing.assert_equal( + expected_activations_shape, last_activations.shape, + 'Wrong activations shape. Expected {}; got {}.'.format( + expected_activations_shape, last_activations.shape)) + + for i in range(batch_size): + actual_activations = last_activations[i, :] + expected_activations = activations[i, sequence_length[i] - 1, :] + np.testing.assert_almost_equal( + expected_activations, + actual_activations, + err_msg='Unexpected logit value at index [{}, :].' + ' Expected {}; got {}.'.format(i, expected_activations, + actual_activations)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index f15c717c91..48ad279cc0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -31,67 +31,17 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.estimators import rnn_common from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.training import momentum as momentum_opt from tensorflow.python.util import nest -# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been -# removed and replaced with values from `prediction_key.PredictionKey`. The key -# `RNNKeys.PREDICTIONS_KEY` has been replaced by -# `prediction_key.PredictionKey.SCORES` for regression and -# `prediction_key.PredictionKey.CLASSES` for classification. The key -# `RNNKeys.PROBABILITIES_KEY` has been replaced by -# `prediction_key.PredictionKey.PROBABILITIES`. -class RNNKeys(object): - FINAL_STATE_KEY = 'final_state' - LABELS_KEY = '__labels__' - STATE_PREFIX = 'rnn_cell_state' - - -# TODO(b/34272579): mask_activations_and_labels is shared with -# dynamic_rnn_estimator.py. Move it to a common library. -def mask_activations_and_labels(activations, labels, sequence_lengths): - """Remove entries outside `sequence_lengths` and returned flattened results. - - Args: - activations: Output of the RNN, shape `[batch_size, padded_length, k]`. - labels: Label values, shape `[batch_size, padded_length]`. - sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded - length of each sequence. If `None`, then each sequence is unpadded. - - Returns: - activations_masked: `logit` values with those beyond `sequence_lengths` - removed for each batch. Batches are then concatenated. Shape - `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and - shape `[batch_size * padded_length, k]` otherwise. - labels_masked: Label values after removing unneeded entries. Shape - `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape - `[batch_size * padded_length]` otherwise. - """ - with ops.name_scope('mask_activations_and_labels', - values=[activations, labels, sequence_lengths]): - labels_shape = array_ops.shape(labels) - batch_size = labels_shape[0] - padded_length = labels_shape[1] - if sequence_lengths is None: - flattened_dimension = padded_length * batch_size - activations_masked = array_ops.reshape(activations, - [flattened_dimension, -1]) - labels_masked = array_ops.reshape(labels, [flattened_dimension]) - else: - mask = array_ops.sequence_mask(sequence_lengths, padded_length) - activations_masked = array_ops.boolean_mask(activations, mask) - labels_masked = array_ops.boolean_mask(labels, mask) - return activations_masked, labels_masked - - def construct_state_saving_rnn(cell, inputs, num_label_columns, @@ -134,7 +84,8 @@ def construct_state_saving_rnn(cell, activation_fn=None, trainable=True) # Use `identity` to rename `final_state`. - final_state = array_ops.identity(final_state, name=RNNKeys.FINAL_STATE_KEY) + final_state = array_ops.identity( + final_state, name=rnn_common.RNNKeys.FINAL_STATE_KEY) return activations, final_state @@ -156,7 +107,7 @@ def _mask_multivalue(sequence_length, metric): """ @functools.wraps(metric) def _metric(predictions, labels, *args, **kwargs): - predictions, labels = mask_activations_and_labels( + predictions, labels = rnn_common.mask_activations_and_labels( predictions, labels, sequence_length) return metric(predictions, labels, *args, **kwargs) return _metric @@ -184,70 +135,6 @@ def _get_default_metrics(problem_type, sequence_length): return default_metrics -# TODO(b/34272579): _multi_value_predictions is shared with -# dynamic_rnn_estimator.py. Move it to a common library. -def _multi_value_predictions( - activations, target_column, problem_type, predict_probabilities): - """Maps `activations` from the RNN to predictions for multi value models. - - If `predict_probabilities` is `False`, this function returns a `dict` - containing single entry with key `prediction_key.PredictionKey.CLASSES` for - `problem_type` `ProblemType.CLASSIFICATION` or - `prediction_key.PredictionKey.SCORE` for `problem_type` - `ProblemType.LINEAR_REGRESSION`. - - If `predict_probabilities` is `True`, it will contain a second entry with key - `prediction_key.PredictionKey.PROBABILITIES`. The - value of this entry is a `Tensor` of probabilities with shape - `[batch_size, padded_length, num_classes]`. - - Note that variable length inputs will yield some predictions that don't have - meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]` - has no meaningful interpretation. - - Args: - activations: Output from an RNN. Should have dtype `float32` and shape - `[batch_size, padded_length, ?]`. - target_column: An initialized `TargetColumn`, calculate predictions. - problem_type: Either `ProblemType.CLASSIFICATION` or - `ProblemType.LINEAR_REGRESSION`. - predict_probabilities: A Python boolean, indicating whether probabilities - should be returned. Should only be set to `True` for - classification/logistic regression problems. - Returns: - A `dict` mapping strings to `Tensors`. - """ - with ops.name_scope('MultiValuePrediction'): - activations_shape = array_ops.shape(activations) - flattened_activations = array_ops.reshape(activations, - [-1, activations_shape[2]]) - prediction_dict = {} - if predict_probabilities: - flat_probabilities = target_column.logits_to_predictions( - flattened_activations, proba=True) - flat_predictions = math_ops.argmax(flat_probabilities, 1) - if target_column.num_label_columns == 1: - probability_shape = array_ops.concat([activations_shape[:2], [2]], 0) - else: - probability_shape = activations_shape - probabilities = array_ops.reshape( - flat_probabilities, probability_shape, - name=prediction_key.PredictionKey.PROBABILITIES) - prediction_dict[ - prediction_key.PredictionKey.PROBABILITIES] = probabilities - else: - flat_predictions = target_column.logits_to_predictions( - flattened_activations, proba=False) - predictions_name = (prediction_key.PredictionKey.CLASSES - if problem_type == constants.ProblemType.CLASSIFICATION - else prediction_key.PredictionKey.SCORES) - predictions = array_ops.reshape( - flat_predictions, [activations_shape[0], activations_shape[1]], - name=predictions_name) - prediction_dict[predictions_name] = predictions - return prediction_dict - - def _multi_value_loss( activations, labels, sequence_length, target_column, features): """Maps `activations` from the RNN to loss for multi value models. @@ -266,7 +153,7 @@ def _multi_value_loss( A scalar `Tensor` containing the loss. """ with ops.name_scope('MultiValueLoss'): - activations_masked, labels_masked = mask_activations_and_labels( + activations_masked, labels_masked = rnn_common.mask_activations_and_labels( activations, labels, sequence_length) return target_column.loss(activations_masked, labels_masked, features) @@ -343,7 +230,7 @@ def _prepare_features_for_sqss(features, labels, mode, # Add labels to the resulting sequence features dict. if mode != model_fn.ModeKeys.INFER: - sequence_features[RNNKeys.LABELS_KEY] = labels + sequence_features[rnn_common.RNNKeys.LABELS_KEY] = labels return sequence_features, context_features @@ -353,7 +240,7 @@ def _read_batch(cell, labels, mode, num_unroll, - num_layers, + num_rnn_layers, batch_size, sequence_feature_columns, context_feature_columns=None, @@ -373,7 +260,7 @@ def _read_batch(cell, num_unroll: Python integer, how many time steps to unroll at a time. The input sequences of length `k` are then split into `k / num_unroll` many segments. - num_layers: Python integer, number of layers in the RNN. + num_rnn_layers: Python integer, number of layers in the RNN. batch_size: Python integer, the size of the minibatch produced by the SQSS. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances @@ -399,8 +286,8 @@ def _read_batch(cell, # Set up stateful queue reader. states = {} - state_names = _get_lstm_state_names(num_layers) - for i in range(num_layers): + state_names = _get_lstm_state_names(num_rnn_layers) + for i in range(num_rnn_layers): states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0) states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0) @@ -423,36 +310,9 @@ def _read_batch(cell, capacity=queue_capacity) -def apply_dropout( - cell, input_keep_probability, output_keep_probability, random_seed=None): - """Apply dropout to the outputs and inputs of `cell`. - - Args: - cell: An `RNNCell`. - input_keep_probability: Probability to keep inputs to `cell`. If `None`, - no dropout is applied. - output_keep_probability: Probability to keep outputs of `cell`. If `None`, - no dropout is applied. - random_seed: Seed for random dropout. - - Returns: - An `RNNCell`, the result of applying the supplied dropouts to `cell`. - """ - input_prob_none = input_keep_probability is None - output_prob_none = output_keep_probability is None - if input_prob_none and output_prob_none: - return cell - if input_prob_none: - input_keep_probability = 1.0 - if output_prob_none: - output_keep_probability = 1.0 - return rnn_cell.DropoutWrapper( - cell, input_keep_probability, output_keep_probability, random_seed) - - def _get_state_name(i): """Constructs the name string for state component `i`.""" - return '{}_{}'.format(RNNKeys.STATE_PREFIX, i) + return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i) def state_tuple_to_dict(state): @@ -531,12 +391,12 @@ def _prepare_inputs_for_rnn(sequence_features, context_features, axis=1) -def _get_rnn_model_fn(cell, - target_column, +def _get_rnn_model_fn(target_column, problem_type, optimizer, num_unroll, - num_layers, + num_units, + num_rnn_layers, num_threads, queue_capacity, batch_size, @@ -545,14 +405,12 @@ def _get_rnn_model_fn(cell, predict_probabilities=False, learning_rate=None, gradient_clipping_norm=None, - input_keep_probability=None, - output_keep_probability=None, + dropout_keep_probabilities=None, name='StateSavingRNNModel', seed=None): """Creates a state saving RNN model function for an `Estimator`. Args: - cell: An initialized `RNNCell` to be used in the RNN. target_column: An initialized `TargetColumn`, used to calculate prediction and loss. problem_type: `ProblemType.CLASSIFICATION` or @@ -562,7 +420,8 @@ def _get_rnn_model_fn(cell, num_unroll: Python integer, how many time steps to unroll at a time. The input sequences of length `k` are then split into `k / num_unroll` many segments. - num_layers: Python integer, number of layers in the RNN. + num_units: The number of units in the `RNNCell`. + num_rnn_layers: Python integer, number of layers in the RNN. num_threads: The Python integer number of threads enqueuing input examples into a queue. queue_capacity: The max capacity of the queue in number of examples. @@ -583,10 +442,8 @@ def _get_rnn_model_fn(cell, learning_rate: Learning rate used for optimization. This argument has no effect if `optimizer` is an instance of an `Optimizer`. gradient_clipping_norm: A float. Gradients will be clipped to this value. - input_keep_probability: Probability to keep inputs to `cell`. If `None`, - no dropout is applied. - output_keep_probability: Probability to keep outputs of `cell`. If `None`, - no dropout is applied. + dropout_keep_probabilities: a list of dropout keep probabilities or `None`. + If given a list, it must have length `num_rnn_layers + 1`. name: A string that will be used to create a scope for the RNN. seed: Fixes the random seed used for generating input keys by the SQSS. @@ -618,19 +475,18 @@ def _get_rnn_model_fn(cell, def _rnn_model_fn(features, labels, mode): """The model to be passed to an `Estimator`.""" with ops.name_scope(name): - if mode == model_fn.ModeKeys.TRAIN: - cell_for_mode = apply_dropout( - cell, input_keep_probability, output_keep_probability) - else: - cell_for_mode = cell + dropout = (dropout_keep_probabilities + if mode == model_fn.ModeKeys.TRAIN + else None) + cell = lstm_cell(num_units, num_rnn_layers, dropout) batch = _read_batch( - cell=cell_for_mode, + cell=cell, features=features, labels=labels, mode=mode, num_unroll=num_unroll, - num_layers=num_layers, + num_rnn_layers=num_rnn_layers, batch_size=batch_size, sequence_feature_columns=sequence_feature_columns, context_feature_columns=context_feature_columns, @@ -640,22 +496,20 @@ def _get_rnn_model_fn(cell, sequence_features = batch.sequences context_features = batch.context if mode != model_fn.ModeKeys.INFER: - labels = sequence_features.pop(RNNKeys.LABELS_KEY) + labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY) inputs = _prepare_inputs_for_rnn(sequence_features, context_features, sequence_feature_columns, num_unroll) - state_name = _get_lstm_state_names(num_layers) + state_name = _get_lstm_state_names(num_rnn_layers) rnn_activations, final_state = construct_state_saving_rnn( - cell=cell_for_mode, + cell=cell, inputs=inputs, num_label_columns=target_column.num_label_columns, state_saver=batch, state_name=state_name) loss = None # Created below for modes TRAIN and EVAL. - prediction_dict = _multi_value_predictions(rnn_activations, - target_column, - problem_type, - predict_probabilities) + prediction_dict = rnn_common.multi_value_predictions( + rnn_activations, target_column, problem_type, predict_probabilities) if mode != model_fn.ModeKeys.INFER: loss = _multi_value_loss(rnn_activations, labels, batch.length, target_column, features) @@ -686,35 +540,41 @@ def _get_rnn_model_fn(cell, return _rnn_model_fn -def _get_lstm_state_names(num_layers): - """Returns a num_layers long list of lstm state name pairs. +def _get_lstm_state_names(num_rnn_layers): + """Returns a num_rnn_layers long list of lstm state name pairs. Args: - num_layers: The number of layers in the RNN. + num_rnn_layers: The number of layers in the RNN. Returns: - A num_layers long list of lstm state name pairs of the form: - ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_layers. + A num_rnn_layers long list of lstm state name pairs of the form: + ['lstm_state_cN', 'lstm_state_mN'] for all N from 0 to num_rnn_layers. """ return [['lstm_state_c' + str(i), 'lstm_state_m' + str(i)] - for i in range(num_layers)] + for i in range(num_rnn_layers)] # TODO(jtbates): Allow users to specify cell types other than LSTM. -def lstm_cell(num_units, num_layers): - """Constructs a `MultiRNNCell` with num_layers `BasicLSTMCell`s. +def lstm_cell(num_units, num_rnn_layers, dropout_keep_probabilities): + """Constructs a `MultiRNNCell` with num_rnn_layers `BasicLSTMCell`s. Args: num_units: The number of units in the `RNNCell`. - num_layers: The number of layers in the RNN. + num_rnn_layers: The number of layers in the RNN. + dropout_keep_probabilities: a list whose elements are either floats in + `[0.0, 1.0]` or `None`. It must have length `num_rnn_layers + 1`. Returns: An intiialized `MultiRNNCell`. """ - return rnn_cell.MultiRNNCell([ - rnn_cell.BasicLSTMCell( - num_units=num_units, state_is_tuple=True) for _ in range(num_layers) - ]) + + cells = [ + rnn_cell.BasicLSTMCell(num_units=num_units, state_is_tuple=True) + for _ in range(num_rnn_layers) + ] + if dropout_keep_probabilities: + cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities) + return rnn_cell.MultiRNNCell(cells) class StateSavingRnnEstimator(estimator.Estimator): @@ -733,9 +593,7 @@ class StateSavingRnnEstimator(estimator.Estimator): predict_probabilities=False, momentum=None, gradient_clipping_norm=5.0, - # TODO(jtbates): Support lists of input_keep_probability. - input_keep_probability=None, - output_keep_probability=None, + dropout_keep_probabilities=None, model_dir=None, config=None, feature_engineering_fn=None, @@ -773,10 +631,8 @@ class StateSavingRnnEstimator(estimator.Estimator): momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. - input_keep_probability: Probability to keep inputs to `cell`. If `None`, - no dropout is applied. - output_keep_probability: Probability to keep outputs of `cell`. If `None`, - no dropout is applied. + dropout_keep_probabilities: a list of dropout keep probabilities or + `None`. If given a list, it must have length `num_rnn_layers + 1`. model_dir: The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -818,14 +674,13 @@ class StateSavingRnnEstimator(estimator.Estimator): if optimizer_type == 'Momentum': optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum) - cell = lstm_cell(num_units, num_rnn_layers) rnn_model_fn = _get_rnn_model_fn( - cell=cell, target_column=target_column, problem_type=problem_type, optimizer=optimizer_type, num_unroll=num_unroll, - num_layers=num_rnn_layers, + num_units=num_units, + num_rnn_layers=num_rnn_layers, num_threads=num_threads, queue_capacity=queue_capacity, batch_size=batch_size, @@ -834,8 +689,7 @@ class StateSavingRnnEstimator(estimator.Estimator): predict_probabilities=predict_probabilities, learning_rate=learning_rate, gradient_clipping_norm=gradient_clipping_norm, - input_keep_probability=input_keep_probability, - output_keep_probability=output_keep_probability, + dropout_keep_probabilities=dropout_keep_probabilities, name=name, seed=seed) @@ -858,8 +712,7 @@ def multi_value_rnn_regressor(num_units, learning_rate=0.1, momentum=None, gradient_clipping_norm=5.0, - input_keep_probability=None, - output_keep_probability=None, + dropout_keep_probabilities=None, model_dir=None, config=None, feature_engineering_fn=None, @@ -891,10 +744,8 @@ def multi_value_rnn_regressor(num_units, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. - input_keep_probability: Probability to keep inputs to `cell`. If `None`, - no dropout is applied. - output_keep_probability: Probability to keep outputs of `cell`. If `None`, - no dropout is applied. + dropout_keep_probabilities: a list of dropout keep probabilities or `None`. + If given a list, it must have length `num_rnn_layers + 1`. model_dir: The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -926,8 +777,7 @@ def multi_value_rnn_regressor(num_units, predict_probabilities=False, momentum=momentum, gradient_clipping_norm=gradient_clipping_norm, - input_keep_probability=input_keep_probability, - output_keep_probability=output_keep_probability, + dropout_keep_probabilities=dropout_keep_probabilities, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn, @@ -950,8 +800,7 @@ def multi_value_rnn_classifier(num_classes, predict_probabilities=False, momentum=None, gradient_clipping_norm=5.0, - input_keep_probability=None, - output_keep_probability=None, + dropout_keep_probabilities=None, model_dir=None, config=None, feature_engineering_fn=None, @@ -985,10 +834,8 @@ def multi_value_rnn_classifier(num_classes, momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. - input_keep_probability: Probability to keep inputs to `cell`. If `None`, - no dropout is applied. - output_keep_probability: Probability to keep outputs of `cell`. If `None`, - no dropout is applied. + dropout_keep_probabilities: a list of dropout keep probabilities or `None`. + If given a list, it must have length `num_rnn_layers + 1`. model_dir: The directory in which to save and restore the model graph, parameters, etc. config: A `RunConfig` instance. @@ -1020,8 +867,7 @@ def multi_value_rnn_classifier(num_classes, predict_probabilities=predict_probabilities, momentum=momentum, gradient_clipping_norm=gradient_clipping_norm, - input_keep_probability=input_keep_probability, - output_keep_probability=output_keep_probability, + dropout_keep_probabilities=dropout_keep_probabilities, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn, diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 4ad6c01fee..05fc6a89a4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -34,6 +34,7 @@ from tensorflow.contrib.layers.python.layers import target_column as target_colu from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.estimators import rnn_common from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre from tensorflow.python.framework import constant_op @@ -288,7 +289,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): ] expected_sequence = { - ssre.RNNKeys.LABELS_KEY: + rnn_common.RNNKeys.LABELS_KEY: np.array([5., 5., 5., 5.]), seq_feature_name: np.array([1., 1., 1., 1.]), @@ -327,78 +328,24 @@ class StateSavingRnnEstimatorTest(test.TestCase): assert_equal(expected_sequence, actual_sequence) assert_equal(expected_context, actual_context) - def testMaskActivationsAndLabels(self): - """Test `mask_activations_and_labels`.""" - batch_size = 4 - padded_length = 6 - num_classes = 4 - np.random.seed(1234) - sequence_length = np.random.randint(0, padded_length + 1, batch_size) - activations = np.random.rand(batch_size, padded_length, num_classes) - labels = np.random.randint(0, num_classes, [batch_size, padded_length]) - (activations_masked_t, labels_masked_t) = ssre.mask_activations_and_labels( - constant_op.constant( - activations, dtype=dtypes.float32), - constant_op.constant( - labels, dtype=dtypes.int32), - constant_op.constant( - sequence_length, dtype=dtypes.int32)) - - with self.test_session() as sess: - activations_masked, labels_masked = sess.run( - [activations_masked_t, labels_masked_t]) - - expected_activations_shape = [sum(sequence_length), num_classes] - np.testing.assert_equal( - expected_activations_shape, activations_masked.shape, - 'Wrong activations shape. Expected {}; got {}.'.format( - expected_activations_shape, activations_masked.shape)) - - expected_labels_shape = [sum(sequence_length)] - np.testing.assert_equal(expected_labels_shape, labels_masked.shape, - 'Wrong labels shape. Expected {}; got {}.'.format( - expected_labels_shape, labels_masked.shape)) - masked_index = 0 - for i in range(batch_size): - for j in range(sequence_length[i]): - actual_activations = activations_masked[masked_index] - expected_activations = activations[i, j, :] - np.testing.assert_almost_equal( - expected_activations, - actual_activations, - err_msg='Unexpected logit value at index [{}, {}, :].' - ' Expected {}; got {}.'.format(i, j, expected_activations, - actual_activations)) - - actual_labels = labels_masked[masked_index] - expected_labels = labels[i, j] - np.testing.assert_almost_equal( - expected_labels, - actual_labels, - err_msg='Unexpected logit value at index [{}, {}].' - ' Expected {}; got {}.'.format(i, j, expected_labels, - actual_labels)) - masked_index += 1 - def _getModelFnOpsForMode(self, mode): """Helper for testGetRnnModelFn{Train,Eval,Infer}().""" - cell_size = 4 - num_layers = 1 - cell = ssre.lstm_cell(cell_size, num_layers) + num_units = 4 + num_rnn_layers = 1 seq_columns = [ feature_column.real_valued_column( - 'inputs', dimension=cell_size) + 'inputs', dimension=num_units) ] features = { 'inputs': constant_op.constant([1., 2., 3.]), } labels = constant_op.constant([1., 0., 1.]) model_fn = ssre._get_rnn_model_fn( - cell=cell, target_column=target_column_lib.multi_class_target(n_classes=2), optimizer='SGD', num_unroll=2, - num_layers=num_layers, + num_units=num_units, + num_rnn_layers=num_rnn_layers, num_threads=1, queue_capacity=10, batch_size=1, @@ -578,6 +525,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): train_steps = 250 eval_steps = 20 num_units = 4 + num_rnn_layers = 1 learning_rate = 0.3 loss_threshold = 0.035 @@ -600,15 +548,16 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): 'inputs', dimension=num_units) ] config = run_config.RunConfig(tf_random_seed=1234) + dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1) sequence_estimator = ssre.StateSavingRnnEstimator( constants.ProblemType.LINEAR_REGRESSION, num_units=num_units, + num_rnn_layers=num_rnn_layers, num_unroll=num_unroll, batch_size=batch_size, sequence_feature_columns=seq_columns, learning_rate=learning_rate, - input_keep_probability=0.9, - output_keep_probability=0.9, + dropout_keep_probabilities=dropout_keep_probabilities, config=config, queue_capacity=2 * batch_size, seed=1234) -- cgit v1.2.3