diff options
author | 2017-02-15 12:00:42 -0800 | |
---|---|---|
committer | 2017-02-15 12:08:43 -0800 | |
commit | 57fba1aa11a9196dc24e5ee5116dd54f048297de (patch) | |
tree | 4d3f3d8f6ce037f9f33e4067717f887811ff4762 | |
parent | 9b8b1213a3276ebb8a9bbd2e6fd23b5499869f4f (diff) |
Refactor state_saving_rnn_estimator.py to use inheritance from Estimator.
Change: 147626256
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py | 354 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py | 172 |
2 files changed, 339 insertions, 187 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 8d77684726..f15c717c91 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -23,12 +23,14 @@ import functools from tensorflow.contrib import layers from tensorflow.contrib import metrics from tensorflow.contrib import rnn as rnn_cell -from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.framework.python.framework import deprecated from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss from tensorflow.python.framework import dtypes @@ -40,19 +42,21 @@ from tensorflow.python.training import momentum as momentum_opt from tensorflow.python.util import nest -class ProblemType(object): - REGRESSION = 1 - CLASSIFICATION = 2 - - +# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been +# removed and replaced with values from `prediction_key.PredictionKey`. The key +# `RNNKeys.PREDICTIONS_KEY` has been replaced by +# `prediction_key.PredictionKey.SCORES` for regression and +# `prediction_key.PredictionKey.CLASSES` for classification. The key +# `RNNKeys.PROBABILITIES_KEY` has been replaced by +# `prediction_key.PredictionKey.PROBABILITIES`. class RNNKeys(object): - PREDICTIONS_KEY = 'predictions' - PROBABILITIES_KEY = 'probabilities' FINAL_STATE_KEY = 'final_state' LABELS_KEY = '__labels__' STATE_PREFIX = 'rnn_cell_state' +# TODO(b/34272579): mask_activations_and_labels is shared with +# dynamic_rnn_estimator.py. Move it to a common library. def mask_activations_and_labels(activations, labels, sequence_lengths): """Remove entries outside `sequence_lengths` and returned flattened results. @@ -134,6 +138,8 @@ def construct_state_saving_rnn(cell, return activations, final_state +# TODO(jtbates): As per cl/14156248, remove this function and switch from +# MetricSpec to metric ops. def _mask_multivalue(sequence_length, metric): """Wrapper function that masks values by `sequence_length`. @@ -160,7 +166,8 @@ def _get_default_metrics(problem_type, sequence_length): """Returns default `MetricSpec`s for `problem_type`. Args: - problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`. + problem_type: `ProblemType.CLASSIFICATION` or + `ProblemType.LINEAR_REGRESSION`. sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32` containing the length of each sequence in the batch. If `None`, sequences are assumed to be unpadded. @@ -168,22 +175,29 @@ def _get_default_metrics(problem_type, sequence_length): A `dict` mapping strings to `MetricSpec`s. """ default_metrics = {} - if problem_type == ProblemType.CLASSIFICATION: + if problem_type == constants.ProblemType.CLASSIFICATION: default_metrics['accuracy'] = metric_spec.MetricSpec( metric_fn=_mask_multivalue(sequence_length, metrics.streaming_accuracy), - prediction_key=RNNKeys.PREDICTIONS_KEY) - elif problem_type == ProblemType.REGRESSION: + prediction_key=prediction_key.PredictionKey.CLASSES) + elif problem_type == constants.ProblemType.LINEAR_REGRESSION: pass return default_metrics +# TODO(b/34272579): _multi_value_predictions is shared with +# dynamic_rnn_estimator.py. Move it to a common library. def _multi_value_predictions( - activations, target_column, predict_probabilities): + activations, target_column, problem_type, predict_probabilities): """Maps `activations` from the RNN to predictions for multi value models. If `predict_probabilities` is `False`, this function returns a `dict` - containing single entry with key `PREDICTIONS_KEY`. If `predict_probabilities` - is `True`, it will contain a second entry with key `PROBABILITIES_KEY`. The + containing single entry with key `prediction_key.PredictionKey.CLASSES` for + `problem_type` `ProblemType.CLASSIFICATION` or + `prediction_key.PredictionKey.SCORE` for `problem_type` + `ProblemType.LINEAR_REGRESSION`. + + If `predict_probabilities` is `True`, it will contain a second entry with key + `prediction_key.PredictionKey.PROBABILITIES`. The value of this entry is a `Tensor` of probabilities with shape `[batch_size, padded_length, num_classes]`. @@ -195,6 +209,8 @@ def _multi_value_predictions( activations: Output from an RNN. Should have dtype `float32` and shape `[batch_size, padded_length, ?]`. target_column: An initialized `TargetColumn`, calculate predictions. + problem_type: Either `ProblemType.CLASSIFICATION` or + `ProblemType.LINEAR_REGRESSION`. predict_probabilities: A Python boolean, indicating whether probabilities should be returned. Should only be set to `True` for classification/logistic regression problems. @@ -215,15 +231,20 @@ def _multi_value_predictions( else: probability_shape = activations_shape probabilities = array_ops.reshape( - flat_probabilities, probability_shape, name=RNNKeys.PROBABILITIES_KEY) - prediction_dict[RNNKeys.PROBABILITIES_KEY] = probabilities + flat_probabilities, probability_shape, + name=prediction_key.PredictionKey.PROBABILITIES) + prediction_dict[ + prediction_key.PredictionKey.PROBABILITIES] = probabilities else: flat_predictions = target_column.logits_to_predictions( flattened_activations, proba=False) + predictions_name = (prediction_key.PredictionKey.CLASSES + if problem_type == constants.ProblemType.CLASSIFICATION + else prediction_key.PredictionKey.SCORES) predictions = array_ops.reshape( flat_predictions, [activations_shape[0], activations_shape[1]], - name=RNNKeys.PREDICTIONS_KEY) - prediction_dict[RNNKeys.PREDICTIONS_KEY] = predictions + name=predictions_name) + prediction_dict[predictions_name] = predictions return prediction_dict @@ -267,7 +288,7 @@ def _get_name_or_parent_names(column): return [column.name] -def _prepare_features_for_sqss(features, labels, mode, input_key_column_name, +def _prepare_features_for_sqss(features, labels, mode, sequence_feature_columns, context_feature_columns): """Prepares features for batching by the SQSS. @@ -283,9 +304,6 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name, labels: An iterable of `Tensor`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. - input_key_column_name: Python string, the name of the feature column - containing a string scalar `Tensor` that serves as a unique key to - identify the input sequence across minibatches. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -295,25 +313,15 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name, `FeatureColumn`. Returns: - input_key: The string scalar `Tensor` that serves as a unique key to - identify the input sequence across minibatches. sequence_features: A dict mapping feature names to sequence features. context_features: A dict mapping feature names to context features. Raises: - ValueError: If `features` does not contain a value for - `input_key_column_name`. ValueError: If `features` does not contain a value for every key in `sequence_feature_columns` or `context_feature_columns`. """ - # Pop the input key from the features dict. - input_key = features.pop(input_key_column_name, None) - if input_key is None: - raise ValueError('No key in features for input_key_column_name: ' + - input_key_column_name) # Extract sequence features. - feature_column_ops._check_supported_sequence_columns(sequence_feature_columns) # pylint: disable=protected-access sequence_features = {} for column in sequence_feature_columns: @@ -337,7 +345,7 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name, if mode != model_fn.ModeKeys.INFER: sequence_features[RNNKeys.LABELS_KEY] = labels - return input_key, sequence_features, context_features + return sequence_features, context_features def _read_batch(cell, @@ -347,11 +355,11 @@ def _read_batch(cell, num_unroll, num_layers, batch_size, - input_key_column_name, sequence_feature_columns, context_feature_columns=None, num_threads=3, - queue_capacity=1000): + queue_capacity=1000, + seed=None): """Reads a batch from a state saving sequence queue. Args: @@ -367,9 +375,6 @@ def _read_batch(cell, many segments. num_layers: Python integer, number of layers in the RNN. batch_size: Python integer, the size of the minibatch produced by the SQSS. - input_key_column_name: Python string, the name of the feature column - containing a string scalar `Tensor` that serves as a unique key to - identify input sequence across minibatches. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -383,6 +388,7 @@ def _read_batch(cell, Needs to be at least `batch_size`. Defaults to 1000. When iterating over the same input example multiple times reusing their keys the `queue_capacity` must be smaller than the number of examples. + seed: Fixes the random seed used for generating input keys by the SQSS. Returns: batch: A `NextQueuedSequenceBatch` containing batch_size `SequenceExample` @@ -398,12 +404,12 @@ def _read_batch(cell, states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0) states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0) - input_key, sequences, context = _prepare_features_for_sqss( - features, labels, mode, input_key_column_name, sequence_feature_columns, + sequences, context = _prepare_features_for_sqss( + features, labels, mode, sequence_feature_columns, context_feature_columns) return sqss.batch_sequences_with_states( - input_key=input_key, + input_key='key', input_sequences=sequences, input_context=context, input_length=None, # infer sequence lengths @@ -411,6 +417,8 @@ def _read_batch(cell, num_unroll=num_unroll, batch_size=batch_size, pad=True, # pad to a multiple of num_unroll + make_keys_unique=True, + make_keys_unique_seed=seed, num_threads=num_threads, capacity=queue_capacity) @@ -532,7 +540,6 @@ def _get_rnn_model_fn(cell, num_threads, queue_capacity, batch_size, - input_key_column_name, sequence_feature_columns, context_feature_columns=None, predict_probabilities=False, @@ -540,14 +547,16 @@ def _get_rnn_model_fn(cell, gradient_clipping_norm=None, input_keep_probability=None, output_keep_probability=None, - name='StateSavingRNNModel'): + name='StateSavingRNNModel', + seed=None): """Creates a state saving RNN model function for an `Estimator`. Args: cell: An initialized `RNNCell` to be used in the RNN. target_column: An initialized `TargetColumn`, used to calculate prediction and loss. - problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`. + problem_type: `ProblemType.CLASSIFICATION` or + `ProblemType.LINEAR_REGRESSION`. optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a string. num_unroll: Python integer, how many time steps to unroll at a time. @@ -561,9 +570,6 @@ def _get_rnn_model_fn(cell, example multiple times reusing their keys the `queue_capacity` must be smaller than the number of examples. batch_size: Python integer, the size of the minibatch produced by the SQSS. - input_key_column_name: Python string, the name of the feature column - containing a string scalar `Tensor` that serves as a unique key to - identify input sequence across minibatches. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -572,7 +578,8 @@ def _get_rnn_model_fn(cell, steps. All items in the set should be instances of classes derived from `FeatureColumn`. predict_probabilities: A boolean indicating whether to predict probabilities - for all classes. Must only be used with `ProblemType.CLASSIFICATION`. + for all classes. + Must only be used with `ProblemType.CLASSIFICATION`. learning_rate: Learning rate used for optimization. This argument has no effect if `optimizer` is an instance of an `Optimizer`. gradient_clipping_norm: A float. Gradients will be clipped to this value. @@ -581,31 +588,32 @@ def _get_rnn_model_fn(cell, output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. name: A string that will be used to create a scope for the RNN. + seed: Fixes the random seed used for generating input keys by the SQSS. Returns: A model function to be passed to an `Estimator`. Raises: - ValueError: `problem_type` is not one of `ProblemType.REGRESSION` or - `ProblemType.CLASSIFICATION`. + ValueError: `problem_type` is not one of + `ProblemType.LINEAR_REGRESSION` + or `ProblemType.CLASSIFICATION`. ValueError: `predict_probabilities` is `True` for `problem_type` other than `ProblemType.CLASSIFICATION`. ValueError: `num_unroll` is not positive. - ValueError: `input_key_column_name` is empty. """ - if problem_type not in (ProblemType.CLASSIFICATION, ProblemType.REGRESSION): + if problem_type not in (constants.ProblemType.CLASSIFICATION, + constants.ProblemType.LINEAR_REGRESSION): raise ValueError( - 'problem_type must be ProblemType.REGRESSION or ' + 'problem_type must be ProblemType.LINEAR_REGRESSION or ' 'ProblemType.CLASSIFICATION; got {}'. format(problem_type)) - if problem_type != ProblemType.CLASSIFICATION and predict_probabilities: + if (problem_type != constants.ProblemType.CLASSIFICATION and + predict_probabilities): raise ValueError( 'predict_probabilities can only be set to True for problem_type' ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type)) if num_unroll <= 0: raise ValueError('num_unroll must be positive; got {}.'.format(num_unroll)) - if not input_key_column_name: - raise ValueError('input_key_column_name must not be empty') def _rnn_model_fn(features, labels, mode): """The model to be passed to an `Estimator`.""" @@ -624,11 +632,11 @@ def _get_rnn_model_fn(cell, num_unroll=num_unroll, num_layers=num_layers, batch_size=batch_size, - input_key_column_name=input_key_column_name, sequence_feature_columns=sequence_feature_columns, context_feature_columns=context_feature_columns, num_threads=num_threads, - queue_capacity=queue_capacity) + queue_capacity=queue_capacity, + seed=seed) sequence_features = batch.sequences context_features = batch.context if mode != model_fn.ModeKeys.INFER: @@ -644,7 +652,9 @@ def _get_rnn_model_fn(cell, state_name=state_name) loss = None # Created below for modes TRAIN and EVAL. - prediction_dict = _multi_value_predictions(rnn_activations, target_column, + prediction_dict = _multi_value_predictions(rnn_activations, + target_column, + problem_type, predict_probabilities) if mode != model_fn.ModeKeys.INFER: loss = _multi_value_loss(rnn_activations, labels, batch.length, @@ -707,11 +717,140 @@ def lstm_cell(num_units, num_layers): ]) -@experimental +class StateSavingRnnEstimator(estimator.Estimator): + + def __init__(self, + problem_type, + num_units, + num_unroll, + batch_size, + sequence_feature_columns, + context_feature_columns=None, + num_classes=None, + num_rnn_layers=1, + optimizer_type='SGD', + learning_rate=0.1, + predict_probabilities=False, + momentum=None, + gradient_clipping_norm=5.0, + # TODO(jtbates): Support lists of input_keep_probability. + input_keep_probability=None, + output_keep_probability=None, + model_dir=None, + config=None, + feature_engineering_fn=None, + num_threads=3, + queue_capacity=1000, + seed=None): + """Initializes a StateSavingRnnEstimator. + + Args: + problem_type: `ProblemType.CLASSIFICATION` or + `ProblemType.LINEAR_REGRESSION`. + num_units: The size of the RNN cells. + num_unroll: Python integer, how many time steps to unroll at a time. + The input sequences of length `k` are then split into `k / num_unroll` + many segments. + batch_size: Python integer, the size of the minibatch. + sequence_feature_columns: An iterable containing all the feature columns + describing sequence features. All items in the set should be instances + of classes derived from `FeatureColumn`. + context_feature_columns: An iterable containing all the feature columns + describing context features, i.e., features that apply accross all time + steps. All items in the set should be instances of classes derived from + `FeatureColumn`. + num_classes: The number of classes for categorization. Used only and + required if `problem_type` is `ProblemType.CLASSIFICATION` + num_rnn_layers: Number of RNN layers. + optimizer_type: The type of optimizer to use. Either a subclass of + `Optimizer`, an instance of an `Optimizer` or a string. Strings must be + one of 'Adagrad', 'Adam', 'Ftrl', Momentum', 'RMSProp', or 'SGD'. + learning_rate: Learning rate. This argument has no effect if `optimizer` + is an instance of an `Optimizer`. + predict_probabilities: A boolean indicating whether to predict + probabilities for all classes. Used only if `problem_type` is + `ProblemType.CLASSIFICATION`. + momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. + gradient_clipping_norm: Parameter used for gradient clipping. If `None`, + then no clipping is performed. + input_keep_probability: Probability to keep inputs to `cell`. If `None`, + no dropout is applied. + output_keep_probability: Probability to keep outputs of `cell`. If `None`, + no dropout is applied. + model_dir: The directory in which to save and restore the model graph, + parameters, etc. + config: A `RunConfig` instance. + feature_engineering_fn: Takes features and labels which are the output of + `input_fn` and returns features and labels which will be fed into + `model_fn`. Please check `model_fn` for a definition of features and + labels. + num_threads: The Python integer number of threads enqueuing input examples + into a queue. Defaults to 3. + queue_capacity: The max capacity of the queue in number of examples. + Needs to be at least `batch_size`. Defaults to 1000. When iterating + over the same input example multiple times reusing their keys the + `queue_capacity` must be smaller than the number of examples. + seed: Fixes the random seed used for generating input keys by the SQSS. + + Raises: + ValueError: `problem_type` is not one of + `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. + ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but + `num_classes` is not specified. + """ + + name = 'MultiValueStateSavingRNN' + if problem_type == constants.ProblemType.LINEAR_REGRESSION: + name += 'Regressor' + target_column = layers.regression_target() + elif problem_type == constants.ProblemType.CLASSIFICATION: + if not num_classes: + raise ValueError('For CLASSIFICATION problem_type, num_classes must be ' + 'specified.') + target_column = layers.multi_class_target(n_classes=num_classes) + name += 'Classifier' + else: + raise ValueError( + 'problem_type must be either ProblemType.LINEAR_REGRESSION ' + 'or ProblemType.CLASSIFICATION; got {}'.format( + problem_type)) + + if optimizer_type == 'Momentum': + optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum) + + cell = lstm_cell(num_units, num_rnn_layers) + rnn_model_fn = _get_rnn_model_fn( + cell=cell, + target_column=target_column, + problem_type=problem_type, + optimizer=optimizer_type, + num_unroll=num_unroll, + num_layers=num_rnn_layers, + num_threads=num_threads, + queue_capacity=queue_capacity, + batch_size=batch_size, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + predict_probabilities=predict_probabilities, + learning_rate=learning_rate, + gradient_clipping_norm=gradient_clipping_norm, + input_keep_probability=input_keep_probability, + output_keep_probability=output_keep_probability, + name=name, + seed=seed) + + super(StateSavingRnnEstimator, self).__init__( + model_fn=rnn_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + + +@deprecated('2017-04-01', 'multi_value_rnn_regressor is deprecated. ' + 'Please construct a StateSavingRnnEstimator directly.') def multi_value_rnn_regressor(num_units, num_unroll, batch_size, - input_key_column_name, sequence_feature_columns, context_feature_columns=None, num_rnn_layers=1, @@ -725,7 +864,8 @@ def multi_value_rnn_regressor(num_units, config=None, feature_engineering_fn=None, num_threads=3, - queue_capacity=1000): + queue_capacity=1000, + seed=None): """Creates a RNN `Estimator` that predicts sequences of values. Args: @@ -734,9 +874,6 @@ def multi_value_rnn_regressor(num_units, The input sequences of length `k` are then split into `k / num_unroll` many segments. batch_size: Python integer, the size of the minibatch. - input_key_column_name: Python string, the name of the feature column - containing a string scalar `Tensor` that serves as a unique key to - identify input sequence across minibatches. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -771,45 +908,40 @@ def multi_value_rnn_regressor(num_units, Needs to be at least `batch_size`. Defaults to 1000. When iterating over the same input example multiple times reusing their keys the `queue_capacity` must be smaller than the number of examples. + seed: Fixes the random seed used for generating input keys by the SQSS. Returns: An initialized `Estimator`. """ - cell = lstm_cell(num_units, num_rnn_layers) - target_column = layers.regression_target() - if optimizer_type == 'Momentum': - optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum) - rnn_model_fn = _get_rnn_model_fn( - cell=cell, - target_column=target_column, - problem_type=ProblemType.REGRESSION, - optimizer=optimizer_type, - num_unroll=num_unroll, - num_layers=num_rnn_layers, - num_threads=num_threads, - queue_capacity=queue_capacity, - batch_size=batch_size, - input_key_column_name=input_key_column_name, - sequence_feature_columns=sequence_feature_columns, + return StateSavingRnnEstimator( + constants.ProblemType.LINEAR_REGRESSION, + num_units, + num_unroll, + batch_size, + sequence_feature_columns, context_feature_columns=context_feature_columns, + num_classes=None, + num_rnn_layers=num_rnn_layers, + optimizer_type=optimizer_type, learning_rate=learning_rate, + predict_probabilities=False, + momentum=momentum, gradient_clipping_norm=gradient_clipping_norm, input_keep_probability=input_keep_probability, output_keep_probability=output_keep_probability, - name='MultiValueRnnRegressor') - - return estimator.Estimator( - model_fn=rnn_model_fn, model_dir=model_dir, config=config, - feature_engineering_fn=feature_engineering_fn) + feature_engineering_fn=feature_engineering_fn, + num_threads=num_threads, + queue_capacity=queue_capacity, + seed=seed) -@experimental +@deprecated('2017-04-01', 'multi_value_rnn_classifier is deprecated. ' + 'Please construct a StateSavingRnnEstimator directly.') def multi_value_rnn_classifier(num_classes, num_units, num_unroll, batch_size, - input_key_column_name, sequence_feature_columns, context_feature_columns=None, num_rnn_layers=1, @@ -824,7 +956,8 @@ def multi_value_rnn_classifier(num_classes, config=None, feature_engineering_fn=None, num_threads=3, - queue_capacity=1000): + queue_capacity=1000, + seed=None): """Creates a RNN `Estimator` that predicts sequences of labels. Args: @@ -834,9 +967,6 @@ def multi_value_rnn_classifier(num_classes, The input sequences of length `k` are then split into `k / num_unroll` many segments. batch_size: Python integer, the size of the minibatch. - input_key_column_name: Python string, the name of the feature column - containing a string scalar `Tensor` that serves as a unique key to - identify input sequence across minibatches. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -872,35 +1002,29 @@ def multi_value_rnn_classifier(num_classes, Needs to be at least `batch_size`. Defaults to 1000. When iterating over the same input example multiple times reusing their keys the `queue_capacity` must be smaller than the number of examples. + seed: Fixes the random seed used for generating input keys by the SQSS. Returns: An initialized `Estimator`. """ - cell = lstm_cell(num_units, num_rnn_layers) - target_column = layers.multi_class_target(n_classes=num_classes) - if optimizer_type == 'Momentum': - optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum) - rnn_model_fn = _get_rnn_model_fn( - cell=cell, - target_column=target_column, - problem_type=ProblemType.CLASSIFICATION, - optimizer=optimizer_type, - num_unroll=num_unroll, - num_layers=num_rnn_layers, - num_threads=num_threads, - queue_capacity=queue_capacity, - batch_size=batch_size, - input_key_column_name=input_key_column_name, - sequence_feature_columns=sequence_feature_columns, + return StateSavingRnnEstimator( + constants.ProblemType.CLASSIFICATION, + num_units, + num_unroll, + batch_size, + sequence_feature_columns, context_feature_columns=context_feature_columns, - predict_probabilities=predict_probabilities, + num_classes=num_classes, + num_rnn_layers=num_rnn_layers, + optimizer_type=optimizer_type, learning_rate=learning_rate, + predict_probabilities=predict_probabilities, + momentum=momentum, gradient_clipping_norm=gradient_clipping_norm, input_keep_probability=input_keep_probability, output_keep_probability=output_keep_probability, - name='MultiValueRnnClassifier') - - return estimator.Estimator( - model_fn=rnn_model_fn, model_dir=model_dir, config=config, - feature_engineering_fn=feature_engineering_fn) + feature_engineering_fn=feature_engineering_fn, + num_threads=num_threads, + queue_capacity=queue_capacity, + seed=seed) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 3e8bc302d6..4ad6c01fee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -31,7 +31,9 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import target_column as target_column_lib +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre from tensorflow.python.framework import constant_op @@ -42,7 +44,6 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -249,13 +250,10 @@ class StateSavingRnnEstimatorTest(test.TestCase): seq_feature_name = 'seq_feature' sparse_seq_feature_name = 'wire_cast' ctx_feature_name = 'ctx_feature' - input_key_column_name = 'input_key_column' sequence_length = 4 embedding_dimension = 8 features = { - input_key_column_name: - constant_op.constant('input0'), sparse_seq_feature_name: sparse_tensor.SparseTensor( indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1], @@ -289,8 +287,6 @@ class StateSavingRnnEstimatorTest(test.TestCase): ctx_feature_name, dimension=1) ] - expected_input_key = b'input0' - expected_sequence = { ssre.RNNKeys.LABELS_KEY: np.array([5., 5., 5., 5.]), @@ -309,8 +305,8 @@ class StateSavingRnnEstimatorTest(test.TestCase): expected_context = {ctx_feature_name: 2.} - input_key, sequence, context = ssre._prepare_features_for_sqss( - features, labels, mode, input_key_column_name, sequence_feature_columns, + sequence, context = ssre._prepare_features_for_sqss( + features, labels, mode, sequence_feature_columns, context_feature_columns) def assert_equal(expected, got): @@ -326,9 +322,8 @@ class StateSavingRnnEstimatorTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(data_flow_ops.initialize_all_tables()) - actual_input_key, actual_sequence, actual_context = sess.run( - [input_key, sequence, context]) - self.assertEqual(expected_input_key, actual_input_key) + actual_sequence, actual_context = sess.run( + [sequence, context]) assert_equal(expected_sequence, actual_sequence) assert_equal(expected_context, actual_context) @@ -396,7 +391,6 @@ class StateSavingRnnEstimatorTest(test.TestCase): ] features = { 'inputs': constant_op.constant([1., 2., 3.]), - 'input_key_column': constant_op.constant('input0') } labels = constant_op.constant([1., 0., 1.]) model_fn = ssre._get_rnn_model_fn( @@ -408,9 +402,8 @@ class StateSavingRnnEstimatorTest(test.TestCase): num_threads=1, queue_capacity=10, batch_size=1, - input_key_column_name='input_key_column', # Only CLASSIFICATION yields eval metrics to test for. - problem_type=ssre.ProblemType.CLASSIFICATION, + problem_type=constants.ProblemType.CLASSIFICATION, sequence_feature_columns=seq_columns, context_feature_columns=None, learning_rate=0.1) @@ -444,7 +437,6 @@ class StateSavingRnnEstimatorTest(test.TestCase): self.assertFalse(model_fn_ops.eval_metric_ops) def testExport(self): - input_key_column_name = 'input0' input_feature_key = 'magic_input_feature_key' batch_size = 8 cell_size = 4 @@ -460,22 +452,13 @@ class StateSavingRnnEstimatorTest(test.TestCase): def get_input_fn(mode, seed): def input_fn(): - input_key = string_ops.string_join([ - 'key_', string_ops.as_string( - random_ops.random_uniform( - (), - minval=0, - maxval=10000000, - dtype=dtypes.int32, - seed=seed)) - ]) features = {} random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) inputs = math_ops.to_float( array_ops.slice(random_sequence, [1], [sequence_length])) - features = {'inputs': inputs, input_key_column_name: input_key} + features = {'inputs': inputs} if mode == model_fn_lib.ModeKeys.INFER: input_examples = array_ops.placeholder(dtypes.string) @@ -488,16 +471,17 @@ class StateSavingRnnEstimatorTest(test.TestCase): model_dir = tempfile.mkdtemp() def estimator_fn(): - return ssre.multi_value_rnn_classifier( - num_classes=num_classes, + return ssre.StateSavingRnnEstimator( + constants.ProblemType.CLASSIFICATION, num_units=cell_size, num_unroll=num_unroll, batch_size=batch_size, - input_key_column_name=input_key_column_name, sequence_feature_columns=seq_columns, + num_classes=num_classes, predict_probabilities=True, model_dir=model_dir, - queue_capacity=2 + batch_size) + queue_capacity=2 + batch_size, + seed=1234) # Train a bit to create an exportable checkpoint. estimator_fn().fit(input_fn=get_input_fn( @@ -516,6 +500,72 @@ class StateSavingRnnEstimatorTest(test.TestCase): input_feature_key=input_feature_key) +# Smoke tests to ensure deprecated constructor functions still work. +class LegacyConstructorTest(test.TestCase): + + def _get_input_fn(self, + sequence_length, + seed=None): + def input_fn(): + random_sequence = random_ops.random_uniform( + [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) + labels = array_ops.slice(random_sequence, [0], [sequence_length]) + inputs = math_ops.to_float( + array_ops.slice(random_sequence, [1], [sequence_length])) + return {'inputs': inputs}, labels + return input_fn + + def testClassifierConstructor(self): + batch_size = 16 + num_classes = 2 + num_unroll = 32 + sequence_length = 32 + num_units = 4 + learning_rate = 0.5 + steps = 100 + input_fn = self._get_input_fn(sequence_length, + seed=1234) + model_dir = tempfile.mkdtemp() + seq_columns = [ + feature_column.real_valued_column( + 'inputs', dimension=num_units) + ] + estimator = ssre.multi_value_rnn_classifier(num_classes, + num_units, + num_unroll, + batch_size, + seq_columns, + learning_rate=learning_rate, + model_dir=model_dir, + queue_capacity=batch_size+2, + seed=1234) + estimator.fit(input_fn=input_fn, steps=steps) + + def testRegressorConstructor(self): + batch_size = 16 + num_unroll = 32 + sequence_length = 32 + num_units = 4 + learning_rate = 0.5 + steps = 100 + input_fn = self._get_input_fn(sequence_length, + seed=4321) + model_dir = tempfile.mkdtemp() + seq_columns = [ + feature_column.real_valued_column( + 'inputs', dimension=num_units) + ] + estimator = ssre.multi_value_rnn_regressor(num_units, + num_unroll, + batch_size, + seq_columns, + learning_rate=learning_rate, + model_dir=model_dir, + queue_capacity=batch_size+2, + seed=1234) + estimator.fit(input_fn=input_fn, steps=steps) + + # TODO(jtbates): move all tests below to a benchmark test. class StateSavingRNNEstimatorLearningTest(test.TestCase): """Learning tests for state saving RNN Estimators.""" @@ -530,7 +580,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): num_units = 4 learning_rate = 0.3 loss_threshold = 0.035 - input_key_column_name = 'input_key_column' def get_sin_input_fn(sequence_length, increment, seed=None): @@ -542,11 +591,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): sequence_length + 1)) inputs = array_ops.slice(sin_curves, [0], [sequence_length]) labels = array_ops.slice(sin_curves, [1], [sequence_length]) - input_key = string_ops.string_join([ - 'key_', - string_ops.as_string(math_ops.cast(10000 * start, dtypes.int32)) - ]) - return {'inputs': inputs, input_key_column_name: input_key}, labels + return {'inputs': inputs}, labels return input_fn @@ -555,17 +600,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): 'inputs', dimension=num_units) ] config = run_config.RunConfig(tf_random_seed=1234) - sequence_estimator = ssre.multi_value_rnn_regressor( + sequence_estimator = ssre.StateSavingRnnEstimator( + constants.ProblemType.LINEAR_REGRESSION, num_units=num_units, num_unroll=num_unroll, batch_size=batch_size, - input_key_column_name=input_key_column_name, sequence_feature_columns=seq_columns, learning_rate=learning_rate, input_keep_probability=0.9, output_keep_probability=0.9, config=config, - queue_capacity=2 * batch_size) + queue_capacity=2 * batch_size, + seed=1234) train_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=1234) eval_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=4321) @@ -592,7 +638,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): num_units = 4 learning_rate = 0.5 accuracy_threshold = 0.9 - input_key_column_name = 'input_key_column' def get_shift_input_fn(sequence_length, seed=None): @@ -602,16 +647,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): labels = array_ops.slice(random_sequence, [0], [sequence_length]) inputs = math_ops.to_float( array_ops.slice(random_sequence, [1], [sequence_length])) - input_key = string_ops.string_join([ - 'key_', string_ops.as_string( - random_ops.random_uniform( - (), - minval=0, - maxval=10000000, - dtype=dtypes.int32, - seed=seed)) - ]) - return {'inputs': inputs, input_key_column_name: input_key}, labels + return {'inputs': inputs}, labels return input_fn @@ -620,17 +656,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): 'inputs', dimension=num_units) ] config = run_config.RunConfig(tf_random_seed=21212) - sequence_estimator = ssre.multi_value_rnn_classifier( - num_classes=num_classes, + sequence_estimator = ssre.StateSavingRnnEstimator( + constants.ProblemType.CLASSIFICATION, num_units=num_units, num_unroll=num_unroll, batch_size=batch_size, - input_key_column_name=input_key_column_name, sequence_feature_columns=seq_columns, + num_classes=num_classes, learning_rate=learning_rate, config=config, predict_probabilities=True, - queue_capacity=2 + batch_size) + queue_capacity=2 + batch_size, + seed=1234) train_input_fn = get_shift_input_fn(sequence_length, seed=12321) eval_input_fn = get_shift_input_fn(sequence_length, seed=32123) @@ -650,11 +687,11 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): self.assertListEqual( sorted(list(prediction_dict.keys())), sorted([ - ssre.RNNKeys.PREDICTIONS_KEY, ssre.RNNKeys.PROBABILITIES_KEY, - ssre._get_state_name(0) + prediction_key.PredictionKey.CLASSES, + prediction_key.PredictionKey.PROBABILITIES, ssre._get_state_name(0) ])) - predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY] - probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY] + predictions = prediction_dict[prediction_key.PredictionKey.CLASSES] + probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES] self.assertListEqual(list(predictions.shape), [batch_size, sequence_length]) self.assertListEqual( list(probabilities.shape), [batch_size, sequence_length, 2]) @@ -672,7 +709,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): num_units = 4 learning_rate = 0.4 accuracy_threshold = 0.70 - input_key_column_name = 'input_key_column' def get_lyrics_input_fn(seed): @@ -692,16 +728,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): mapping=list(vocab), default_value=-1, name='lookup') labels = table.lookup( array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length])) - input_key = string_ops.string_join([ - 'key_', string_ops.as_string( - random_ops.random_uniform( - (), - minval=0, - maxval=10000000, - dtype=dtypes.int32, - seed=seed)) - ]) - return {'lyrics': inputs, input_key_column_name: input_key}, labels + return {'lyrics': inputs}, labels return input_fn @@ -711,17 +738,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): dimension=8) ] config = run_config.RunConfig(tf_random_seed=21212) - sequence_estimator = ssre.multi_value_rnn_classifier( - num_classes=num_classes, + sequence_estimator = ssre.StateSavingRnnEstimator( + constants.ProblemType.CLASSIFICATION, num_units=num_units, num_unroll=num_unroll, batch_size=batch_size, - input_key_column_name=input_key_column_name, sequence_feature_columns=sequence_feature_columns, + num_classes=num_classes, learning_rate=learning_rate, config=config, predict_probabilities=True, - queue_capacity=2 + batch_size) + queue_capacity=2 + batch_size, + seed=1234) train_input_fn = get_lyrics_input_fn(seed=12321) eval_input_fn = get_lyrics_input_fn(seed=32123) |