author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-15 12:00:42 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-02-15 12:08:43 -0800
commit     57fba1aa11a9196dc24e5ee5116dd54f048297de (patch)
tree       4d3f3d8f6ce037f9f33e4067717f887811ff4762
parent     9b8b1213a3276ebb8a9bbd2e6fd23b5499869f4f (diff)
Refactor state_saving_rnn_estimator.py to use inheritance from Estimator.
Change: 147626256
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py       354
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py  172
2 files changed, 339 insertions, 187 deletions
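
This change folds the experimental multi_value_rnn_regressor and multi_value_rnn_classifier factories into deprecated wrappers around a new StateSavingRnnEstimator class that inherits from Estimator, drops the input_key_column_name plumbing (the SQSS now generates unique keys itself, optionally fixed by a new seed argument), and reports predictions under prediction_key.PredictionKey instead of RNNKeys. A minimal usage sketch of the new constructor, modeled on the updated tests further down; the feature column named 'inputs', the hyperparameter values, and my_input_fn are illustrative, not part of this commit:

    from tensorflow.contrib.layers.python.layers import feature_column
    from tensorflow.contrib.learn.python.learn.estimators import constants
    from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre

    num_units = 4
    seq_columns = [
        feature_column.real_valued_column('inputs', dimension=num_units)
    ]
    # Construct the estimator directly rather than through the deprecated
    # multi_value_rnn_classifier() wrapper.
    estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        num_unroll=16,
        batch_size=8,
        sequence_feature_columns=seq_columns,
        num_classes=2,
        predict_probabilities=True,
        queue_capacity=8 + 2,  # must be at least batch_size
        seed=1234)             # fixes the input keys generated by the SQSS
    # estimator.fit(input_fn=my_input_fn, steps=100)  # my_input_fn is hypothetical
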
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index 8d77684726..f15c717c91 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -23,12 +23,14 @@ import functools
from tensorflow.contrib import layers
from tensorflow.contrib import metrics
from tensorflow.contrib import rnn as rnn_cell
-from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import metric_spec
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import dtypes
@@ -40,19 +42,21 @@ from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest
-class ProblemType(object):
- REGRESSION = 1
- CLASSIFICATION = 2
-
-
+# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been
+# removed and replaced with values from `prediction_key.PredictionKey`. The key
+# `RNNKeys.PREDICTIONS_KEY` has been replaced by
+# `prediction_key.PredictionKey.SCORES` for regression and
+# `prediction_key.PredictionKey.CLASSES` for classification. The key
+# `RNNKeys.PROBABILITIES_KEY` has been replaced by
+# `prediction_key.PredictionKey.PROBABILITIES`.
class RNNKeys(object):
- PREDICTIONS_KEY = 'predictions'
- PROBABILITIES_KEY = 'probabilities'
FINAL_STATE_KEY = 'final_state'
LABELS_KEY = '__labels__'
STATE_PREFIX = 'rnn_cell_state'
+# TODO(b/34272579): mask_activations_and_labels is shared with
+# dynamic_rnn_estimator.py. Move it to a common library.
def mask_activations_and_labels(activations, labels, sequence_lengths):
"""Remove entries outside `sequence_lengths` and returned flattened results.
@@ -134,6 +138,8 @@ def construct_state_saving_rnn(cell,
return activations, final_state
+# TODO(jtbates): As per cl/14156248, remove this function and switch from
+# MetricSpec to metric ops.
def _mask_multivalue(sequence_length, metric):
"""Wrapper function that masks values by `sequence_length`.
@@ -160,7 +166,8 @@ def _get_default_metrics(problem_type, sequence_length):
"""Returns default `MetricSpec`s for `problem_type`.
Args:
- problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`.
+ problem_type: `ProblemType.CLASSIFICATION` or
+ `ProblemType.LINEAR_REGRESSION`.
sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
containing the length of each sequence in the batch. If `None`, sequences
are assumed to be unpadded.
@@ -168,22 +175,29 @@ def _get_default_metrics(problem_type, sequence_length):
A `dict` mapping strings to `MetricSpec`s.
"""
default_metrics = {}
- if problem_type == ProblemType.CLASSIFICATION:
+ if problem_type == constants.ProblemType.CLASSIFICATION:
default_metrics['accuracy'] = metric_spec.MetricSpec(
metric_fn=_mask_multivalue(sequence_length, metrics.streaming_accuracy),
- prediction_key=RNNKeys.PREDICTIONS_KEY)
- elif problem_type == ProblemType.REGRESSION:
+ prediction_key=prediction_key.PredictionKey.CLASSES)
+ elif problem_type == constants.ProblemType.LINEAR_REGRESSION:
pass
return default_metrics
+# TODO(b/34272579): _multi_value_predictions is shared with
+# dynamic_rnn_estimator.py. Move it to a common library.
def _multi_value_predictions(
- activations, target_column, predict_probabilities):
+ activations, target_column, problem_type, predict_probabilities):
"""Maps `activations` from the RNN to predictions for multi value models.
If `predict_probabilities` is `False`, this function returns a `dict`
- containing single entry with key `PREDICTIONS_KEY`. If `predict_probabilities`
- is `True`, it will contain a second entry with key `PROBABILITIES_KEY`. The
+  containing a single entry with key `prediction_key.PredictionKey.CLASSES` for
+ `problem_type` `ProblemType.CLASSIFICATION` or
+  `prediction_key.PredictionKey.SCORES` for `problem_type`
+ `ProblemType.LINEAR_REGRESSION`.
+
+ If `predict_probabilities` is `True`, it will contain a second entry with key
+ `prediction_key.PredictionKey.PROBABILITIES`. The
value of this entry is a `Tensor` of probabilities with shape
`[batch_size, padded_length, num_classes]`.
@@ -195,6 +209,8 @@ def _multi_value_predictions(
activations: Output from an RNN. Should have dtype `float32` and shape
`[batch_size, padded_length, ?]`.
    target_column: An initialized `TargetColumn`, used to calculate predictions.
+ problem_type: Either `ProblemType.CLASSIFICATION` or
+ `ProblemType.LINEAR_REGRESSION`.
predict_probabilities: A Python boolean, indicating whether probabilities
should be returned. Should only be set to `True` for
classification/logistic regression problems.
@@ -215,15 +231,20 @@ def _multi_value_predictions(
else:
probability_shape = activations_shape
probabilities = array_ops.reshape(
- flat_probabilities, probability_shape, name=RNNKeys.PROBABILITIES_KEY)
- prediction_dict[RNNKeys.PROBABILITIES_KEY] = probabilities
+ flat_probabilities, probability_shape,
+ name=prediction_key.PredictionKey.PROBABILITIES)
+ prediction_dict[
+ prediction_key.PredictionKey.PROBABILITIES] = probabilities
else:
flat_predictions = target_column.logits_to_predictions(
flattened_activations, proba=False)
+ predictions_name = (prediction_key.PredictionKey.CLASSES
+ if problem_type == constants.ProblemType.CLASSIFICATION
+ else prediction_key.PredictionKey.SCORES)
predictions = array_ops.reshape(
flat_predictions, [activations_shape[0], activations_shape[1]],
- name=RNNKeys.PREDICTIONS_KEY)
- prediction_dict[RNNKeys.PREDICTIONS_KEY] = predictions
+ name=predictions_name)
+ prediction_dict[predictions_name] = predictions
return prediction_dict
@@ -267,7 +288,7 @@ def _get_name_or_parent_names(column):
return [column.name]
-def _prepare_features_for_sqss(features, labels, mode, input_key_column_name,
+def _prepare_features_for_sqss(features, labels, mode,
sequence_feature_columns,
context_feature_columns):
"""Prepares features for batching by the SQSS.
@@ -283,9 +304,6 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name,
labels: An iterable of `Tensor`.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
- input_key_column_name: Python string, the name of the feature column
- containing a string scalar `Tensor` that serves as a unique key to
- identify the input sequence across minibatches.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -295,25 +313,15 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name,
`FeatureColumn`.
Returns:
- input_key: The string scalar `Tensor` that serves as a unique key to
- identify the input sequence across minibatches.
sequence_features: A dict mapping feature names to sequence features.
context_features: A dict mapping feature names to context features.
Raises:
- ValueError: If `features` does not contain a value for
- `input_key_column_name`.
ValueError: If `features` does not contain a value for every key in
`sequence_feature_columns` or `context_feature_columns`.
"""
- # Pop the input key from the features dict.
- input_key = features.pop(input_key_column_name, None)
- if input_key is None:
- raise ValueError('No key in features for input_key_column_name: ' +
- input_key_column_name)
# Extract sequence features.
-
feature_column_ops._check_supported_sequence_columns(sequence_feature_columns) # pylint: disable=protected-access
sequence_features = {}
for column in sequence_feature_columns:
@@ -337,7 +345,7 @@ def _prepare_features_for_sqss(features, labels, mode, input_key_column_name,
if mode != model_fn.ModeKeys.INFER:
sequence_features[RNNKeys.LABELS_KEY] = labels
- return input_key, sequence_features, context_features
+ return sequence_features, context_features
def _read_batch(cell,
@@ -347,11 +355,11 @@ def _read_batch(cell,
num_unroll,
num_layers,
batch_size,
- input_key_column_name,
sequence_feature_columns,
context_feature_columns=None,
num_threads=3,
- queue_capacity=1000):
+ queue_capacity=1000,
+ seed=None):
"""Reads a batch from a state saving sequence queue.
Args:
@@ -367,9 +375,6 @@ def _read_batch(cell,
many segments.
num_layers: Python integer, number of layers in the RNN.
batch_size: Python integer, the size of the minibatch produced by the SQSS.
- input_key_column_name: Python string, the name of the feature column
- containing a string scalar `Tensor` that serves as a unique key to
- identify input sequence across minibatches.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -383,6 +388,7 @@ def _read_batch(cell,
Needs to be at least `batch_size`. Defaults to 1000. When iterating
over the same input example multiple times reusing their keys the
`queue_capacity` must be smaller than the number of examples.
+ seed: Fixes the random seed used for generating input keys by the SQSS.
Returns:
batch: A `NextQueuedSequenceBatch` containing batch_size `SequenceExample`
@@ -398,12 +404,12 @@ def _read_batch(cell,
states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0)
states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0)
- input_key, sequences, context = _prepare_features_for_sqss(
- features, labels, mode, input_key_column_name, sequence_feature_columns,
+ sequences, context = _prepare_features_for_sqss(
+ features, labels, mode, sequence_feature_columns,
context_feature_columns)
return sqss.batch_sequences_with_states(
- input_key=input_key,
+ input_key='key',
input_sequences=sequences,
input_context=context,
input_length=None, # infer sequence lengths
@@ -411,6 +417,8 @@ def _read_batch(cell,
num_unroll=num_unroll,
batch_size=batch_size,
pad=True, # pad to a multiple of num_unroll
+ make_keys_unique=True,
+ make_keys_unique_seed=seed,
num_threads=num_threads,
capacity=queue_capacity)
@@ -532,7 +540,6 @@ def _get_rnn_model_fn(cell,
num_threads,
queue_capacity,
batch_size,
- input_key_column_name,
sequence_feature_columns,
context_feature_columns=None,
predict_probabilities=False,
@@ -540,14 +547,16 @@ def _get_rnn_model_fn(cell,
gradient_clipping_norm=None,
input_keep_probability=None,
output_keep_probability=None,
- name='StateSavingRNNModel'):
+ name='StateSavingRNNModel',
+ seed=None):
"""Creates a state saving RNN model function for an `Estimator`.
Args:
cell: An initialized `RNNCell` to be used in the RNN.
target_column: An initialized `TargetColumn`, used to calculate prediction
and loss.
- problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`.
+ problem_type: `ProblemType.CLASSIFICATION` or
+ `ProblemType.LINEAR_REGRESSION`.
optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
string.
num_unroll: Python integer, how many time steps to unroll at a time.
@@ -561,9 +570,6 @@ def _get_rnn_model_fn(cell,
example multiple times reusing their keys the `queue_capacity` must be
smaller than the number of examples.
batch_size: Python integer, the size of the minibatch produced by the SQSS.
- input_key_column_name: Python string, the name of the feature column
- containing a string scalar `Tensor` that serves as a unique key to
- identify input sequence across minibatches.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -572,7 +578,8 @@ def _get_rnn_model_fn(cell,
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
predict_probabilities: A boolean indicating whether to predict probabilities
- for all classes. Must only be used with `ProblemType.CLASSIFICATION`.
+ for all classes.
+ Must only be used with `ProblemType.CLASSIFICATION`.
learning_rate: Learning rate used for optimization. This argument has no
effect if `optimizer` is an instance of an `Optimizer`.
gradient_clipping_norm: A float. Gradients will be clipped to this value.
@@ -581,31 +588,32 @@ def _get_rnn_model_fn(cell,
output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
name: A string that will be used to create a scope for the RNN.
+ seed: Fixes the random seed used for generating input keys by the SQSS.
Returns:
A model function to be passed to an `Estimator`.
Raises:
- ValueError: `problem_type` is not one of `ProblemType.REGRESSION` or
- `ProblemType.CLASSIFICATION`.
+ ValueError: `problem_type` is not one of
+ `ProblemType.LINEAR_REGRESSION`
+ or `ProblemType.CLASSIFICATION`.
ValueError: `predict_probabilities` is `True` for `problem_type` other
than `ProblemType.CLASSIFICATION`.
ValueError: `num_unroll` is not positive.
- ValueError: `input_key_column_name` is empty.
"""
- if problem_type not in (ProblemType.CLASSIFICATION, ProblemType.REGRESSION):
+ if problem_type not in (constants.ProblemType.CLASSIFICATION,
+ constants.ProblemType.LINEAR_REGRESSION):
raise ValueError(
- 'problem_type must be ProblemType.REGRESSION or '
+ 'problem_type must be ProblemType.LINEAR_REGRESSION or '
'ProblemType.CLASSIFICATION; got {}'.
format(problem_type))
- if problem_type != ProblemType.CLASSIFICATION and predict_probabilities:
+ if (problem_type != constants.ProblemType.CLASSIFICATION and
+ predict_probabilities):
raise ValueError(
'predict_probabilities can only be set to True for problem_type'
' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))
if num_unroll <= 0:
raise ValueError('num_unroll must be positive; got {}.'.format(num_unroll))
- if not input_key_column_name:
- raise ValueError('input_key_column_name must not be empty')
def _rnn_model_fn(features, labels, mode):
"""The model to be passed to an `Estimator`."""
@@ -624,11 +632,11 @@ def _get_rnn_model_fn(cell,
num_unroll=num_unroll,
num_layers=num_layers,
batch_size=batch_size,
- input_key_column_name=input_key_column_name,
sequence_feature_columns=sequence_feature_columns,
context_feature_columns=context_feature_columns,
num_threads=num_threads,
- queue_capacity=queue_capacity)
+ queue_capacity=queue_capacity,
+ seed=seed)
sequence_features = batch.sequences
context_features = batch.context
if mode != model_fn.ModeKeys.INFER:
@@ -644,7 +652,9 @@ def _get_rnn_model_fn(cell,
state_name=state_name)
loss = None # Created below for modes TRAIN and EVAL.
- prediction_dict = _multi_value_predictions(rnn_activations, target_column,
+ prediction_dict = _multi_value_predictions(rnn_activations,
+ target_column,
+ problem_type,
predict_probabilities)
if mode != model_fn.ModeKeys.INFER:
loss = _multi_value_loss(rnn_activations, labels, batch.length,
@@ -707,11 +717,140 @@ def lstm_cell(num_units, num_layers):
])
-@experimental
+class StateSavingRnnEstimator(estimator.Estimator):
+
+ def __init__(self,
+ problem_type,
+ num_units,
+ num_unroll,
+ batch_size,
+ sequence_feature_columns,
+ context_feature_columns=None,
+ num_classes=None,
+ num_rnn_layers=1,
+ optimizer_type='SGD',
+ learning_rate=0.1,
+ predict_probabilities=False,
+ momentum=None,
+ gradient_clipping_norm=5.0,
+ # TODO(jtbates): Support lists of input_keep_probability.
+ input_keep_probability=None,
+ output_keep_probability=None,
+ model_dir=None,
+ config=None,
+ feature_engineering_fn=None,
+ num_threads=3,
+ queue_capacity=1000,
+ seed=None):
+ """Initializes a StateSavingRnnEstimator.
+
+ Args:
+ problem_type: `ProblemType.CLASSIFICATION` or
+ `ProblemType.LINEAR_REGRESSION`.
+ num_units: The size of the RNN cells.
+ num_unroll: Python integer, how many time steps to unroll at a time.
+ The input sequences of length `k` are then split into `k / num_unroll`
+ many segments.
+ batch_size: Python integer, the size of the minibatch.
+ sequence_feature_columns: An iterable containing all the feature columns
+ describing sequence features. All items in the set should be instances
+ of classes derived from `FeatureColumn`.
+ context_feature_columns: An iterable containing all the feature columns
+        describing context features, i.e., features that apply across all time
+ steps. All items in the set should be instances of classes derived from
+ `FeatureColumn`.
+ num_classes: The number of classes for categorization. Used only and
+ required if `problem_type` is `ProblemType.CLASSIFICATION`
+ num_rnn_layers: Number of RNN layers.
+ optimizer_type: The type of optimizer to use. Either a subclass of
+ `Optimizer`, an instance of an `Optimizer` or a string. Strings must be
+        one of 'Adagrad', 'Adam', 'Ftrl', 'Momentum', 'RMSProp', or 'SGD'.
+ learning_rate: Learning rate. This argument has no effect if `optimizer`
+ is an instance of an `Optimizer`.
+ predict_probabilities: A boolean indicating whether to predict
+ probabilities for all classes. Used only if `problem_type` is
+ `ProblemType.CLASSIFICATION`.
+ momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
+ gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
+ then no clipping is performed.
+ input_keep_probability: Probability to keep inputs to `cell`. If `None`,
+ no dropout is applied.
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
+ no dropout is applied.
+ model_dir: The directory in which to save and restore the model graph,
+ parameters, etc.
+ config: A `RunConfig` instance.
+ feature_engineering_fn: Takes features and labels which are the output of
+ `input_fn` and returns features and labels which will be fed into
+ `model_fn`. Please check `model_fn` for a definition of features and
+ labels.
+ num_threads: The Python integer number of threads enqueuing input examples
+ into a queue. Defaults to 3.
+ queue_capacity: The max capacity of the queue in number of examples.
+ Needs to be at least `batch_size`. Defaults to 1000. When iterating
+ over the same input example multiple times reusing their keys the
+ `queue_capacity` must be smaller than the number of examples.
+ seed: Fixes the random seed used for generating input keys by the SQSS.
+
+ Raises:
+ ValueError: `problem_type` is not one of
+ `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
+ ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
+ `num_classes` is not specified.
+ """
+
+ name = 'MultiValueStateSavingRNN'
+ if problem_type == constants.ProblemType.LINEAR_REGRESSION:
+ name += 'Regressor'
+ target_column = layers.regression_target()
+ elif problem_type == constants.ProblemType.CLASSIFICATION:
+ if not num_classes:
+ raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
+ 'specified.')
+ target_column = layers.multi_class_target(n_classes=num_classes)
+ name += 'Classifier'
+ else:
+ raise ValueError(
+ 'problem_type must be either ProblemType.LINEAR_REGRESSION '
+ 'or ProblemType.CLASSIFICATION; got {}'.format(
+ problem_type))
+
+ if optimizer_type == 'Momentum':
+ optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)
+
+ cell = lstm_cell(num_units, num_rnn_layers)
+ rnn_model_fn = _get_rnn_model_fn(
+ cell=cell,
+ target_column=target_column,
+ problem_type=problem_type,
+ optimizer=optimizer_type,
+ num_unroll=num_unroll,
+ num_layers=num_rnn_layers,
+ num_threads=num_threads,
+ queue_capacity=queue_capacity,
+ batch_size=batch_size,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ predict_probabilities=predict_probabilities,
+ learning_rate=learning_rate,
+ gradient_clipping_norm=gradient_clipping_norm,
+ input_keep_probability=input_keep_probability,
+ output_keep_probability=output_keep_probability,
+ name=name,
+ seed=seed)
+
+ super(StateSavingRnnEstimator, self).__init__(
+ model_fn=rnn_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
+
+
+@deprecated('2017-04-01', 'multi_value_rnn_regressor is deprecated. '
+ 'Please construct a StateSavingRnnEstimator directly.')
def multi_value_rnn_regressor(num_units,
num_unroll,
batch_size,
- input_key_column_name,
sequence_feature_columns,
context_feature_columns=None,
num_rnn_layers=1,
@@ -725,7 +864,8 @@ def multi_value_rnn_regressor(num_units,
config=None,
feature_engineering_fn=None,
num_threads=3,
- queue_capacity=1000):
+ queue_capacity=1000,
+ seed=None):
"""Creates a RNN `Estimator` that predicts sequences of values.
Args:
@@ -734,9 +874,6 @@ def multi_value_rnn_regressor(num_units,
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
batch_size: Python integer, the size of the minibatch.
- input_key_column_name: Python string, the name of the feature column
- containing a string scalar `Tensor` that serves as a unique key to
- identify input sequence across minibatches.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -771,45 +908,40 @@ def multi_value_rnn_regressor(num_units,
Needs to be at least `batch_size`. Defaults to 1000. When iterating
over the same input example multiple times reusing their keys the
`queue_capacity` must be smaller than the number of examples.
+ seed: Fixes the random seed used for generating input keys by the SQSS.
Returns:
An initialized `Estimator`.
"""
- cell = lstm_cell(num_units, num_rnn_layers)
- target_column = layers.regression_target()
- if optimizer_type == 'Momentum':
- optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)
- rnn_model_fn = _get_rnn_model_fn(
- cell=cell,
- target_column=target_column,
- problem_type=ProblemType.REGRESSION,
- optimizer=optimizer_type,
- num_unroll=num_unroll,
- num_layers=num_rnn_layers,
- num_threads=num_threads,
- queue_capacity=queue_capacity,
- batch_size=batch_size,
- input_key_column_name=input_key_column_name,
- sequence_feature_columns=sequence_feature_columns,
+ return StateSavingRnnEstimator(
+ constants.ProblemType.LINEAR_REGRESSION,
+ num_units,
+ num_unroll,
+ batch_size,
+ sequence_feature_columns,
context_feature_columns=context_feature_columns,
+ num_classes=None,
+ num_rnn_layers=num_rnn_layers,
+ optimizer_type=optimizer_type,
learning_rate=learning_rate,
+ predict_probabilities=False,
+ momentum=momentum,
gradient_clipping_norm=gradient_clipping_norm,
input_keep_probability=input_keep_probability,
output_keep_probability=output_keep_probability,
- name='MultiValueRnnRegressor')
-
- return estimator.Estimator(
- model_fn=rnn_model_fn,
model_dir=model_dir,
config=config,
- feature_engineering_fn=feature_engineering_fn)
+ feature_engineering_fn=feature_engineering_fn,
+ num_threads=num_threads,
+ queue_capacity=queue_capacity,
+ seed=seed)
-@experimental
+@deprecated('2017-04-01', 'multi_value_rnn_classifier is deprecated. '
+ 'Please construct a StateSavingRnnEstimator directly.')
def multi_value_rnn_classifier(num_classes,
num_units,
num_unroll,
batch_size,
- input_key_column_name,
sequence_feature_columns,
context_feature_columns=None,
num_rnn_layers=1,
@@ -824,7 +956,8 @@ def multi_value_rnn_classifier(num_classes,
config=None,
feature_engineering_fn=None,
num_threads=3,
- queue_capacity=1000):
+ queue_capacity=1000,
+ seed=None):
"""Creates a RNN `Estimator` that predicts sequences of labels.
Args:
@@ -834,9 +967,6 @@ def multi_value_rnn_classifier(num_classes,
The input sequences of length `k` are then split into `k / num_unroll`
many segments.
batch_size: Python integer, the size of the minibatch.
- input_key_column_name: Python string, the name of the feature column
- containing a string scalar `Tensor` that serves as a unique key to
- identify input sequence across minibatches.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -872,35 +1002,29 @@ def multi_value_rnn_classifier(num_classes,
Needs to be at least `batch_size`. Defaults to 1000. When iterating
over the same input example multiple times reusing their keys the
`queue_capacity` must be smaller than the number of examples.
+ seed: Fixes the random seed used for generating input keys by the SQSS.
Returns:
An initialized `Estimator`.
"""
- cell = lstm_cell(num_units, num_rnn_layers)
- target_column = layers.multi_class_target(n_classes=num_classes)
- if optimizer_type == 'Momentum':
- optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)
- rnn_model_fn = _get_rnn_model_fn(
- cell=cell,
- target_column=target_column,
- problem_type=ProblemType.CLASSIFICATION,
- optimizer=optimizer_type,
- num_unroll=num_unroll,
- num_layers=num_rnn_layers,
- num_threads=num_threads,
- queue_capacity=queue_capacity,
- batch_size=batch_size,
- input_key_column_name=input_key_column_name,
- sequence_feature_columns=sequence_feature_columns,
+ return StateSavingRnnEstimator(
+ constants.ProblemType.CLASSIFICATION,
+ num_units,
+ num_unroll,
+ batch_size,
+ sequence_feature_columns,
context_feature_columns=context_feature_columns,
- predict_probabilities=predict_probabilities,
+ num_classes=num_classes,
+ num_rnn_layers=num_rnn_layers,
+ optimizer_type=optimizer_type,
learning_rate=learning_rate,
+ predict_probabilities=predict_probabilities,
+ momentum=momentum,
gradient_clipping_norm=gradient_clipping_norm,
input_keep_probability=input_keep_probability,
output_keep_probability=output_keep_probability,
- name='MultiValueRnnClassifier')
-
- return estimator.Estimator(
- model_fn=rnn_model_fn,
model_dir=model_dir,
config=config,
- feature_engineering_fn=feature_engineering_fn)
+ feature_engineering_fn=feature_engineering_fn,
+ num_threads=num_threads,
+ queue_capacity=queue_capacity,
+ seed=seed)
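
Because batch_sequences_with_states is now called with a constant input_key and make_keys_unique=True, input functions no longer have to synthesize a key tensor or carry an input_key_column feature. A sketch of the simplified input_fn, mirroring the test changes below; the sequence length, seed, and feature name 'inputs' are illustrative:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import random_ops

    sequence_length = 32

    def input_fn():
      # Random 0/1 sequence; inputs are the sequence shifted one step ahead of
      # the labels, as in the tests.
      random_sequence = random_ops.random_uniform(
          [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=1234)
      labels = array_ops.slice(random_sequence, [0], [sequence_length])
      inputs = math_ops.to_float(
          array_ops.slice(random_sequence, [1], [sequence_length]))
      # No 'input_key_column' entry is needed any more; the SQSS keys batches
      # on its own.
      return {'inputs': inputs}, labels
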
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 3e8bc302d6..4ad6c01fee 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -31,7 +31,9 @@ import numpy as np
from tensorflow.contrib import lookup
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.layers.python.layers import target_column as target_column_lib
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre
from tensorflow.python.framework import constant_op
@@ -42,7 +44,6 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -249,13 +250,10 @@ class StateSavingRnnEstimatorTest(test.TestCase):
seq_feature_name = 'seq_feature'
sparse_seq_feature_name = 'wire_cast'
ctx_feature_name = 'ctx_feature'
- input_key_column_name = 'input_key_column'
sequence_length = 4
embedding_dimension = 8
features = {
- input_key_column_name:
- constant_op.constant('input0'),
sparse_seq_feature_name:
sparse_tensor.SparseTensor(
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1],
@@ -289,8 +287,6 @@ class StateSavingRnnEstimatorTest(test.TestCase):
ctx_feature_name, dimension=1)
]
- expected_input_key = b'input0'
-
expected_sequence = {
ssre.RNNKeys.LABELS_KEY:
np.array([5., 5., 5., 5.]),
@@ -309,8 +305,8 @@ class StateSavingRnnEstimatorTest(test.TestCase):
expected_context = {ctx_feature_name: 2.}
- input_key, sequence, context = ssre._prepare_features_for_sqss(
- features, labels, mode, input_key_column_name, sequence_feature_columns,
+ sequence, context = ssre._prepare_features_for_sqss(
+ features, labels, mode, sequence_feature_columns,
context_feature_columns)
def assert_equal(expected, got):
@@ -326,9 +322,8 @@ class StateSavingRnnEstimatorTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.initialize_all_tables())
- actual_input_key, actual_sequence, actual_context = sess.run(
- [input_key, sequence, context])
- self.assertEqual(expected_input_key, actual_input_key)
+ actual_sequence, actual_context = sess.run(
+ [sequence, context])
assert_equal(expected_sequence, actual_sequence)
assert_equal(expected_context, actual_context)
@@ -396,7 +391,6 @@ class StateSavingRnnEstimatorTest(test.TestCase):
]
features = {
'inputs': constant_op.constant([1., 2., 3.]),
- 'input_key_column': constant_op.constant('input0')
}
labels = constant_op.constant([1., 0., 1.])
model_fn = ssre._get_rnn_model_fn(
@@ -408,9 +402,8 @@ class StateSavingRnnEstimatorTest(test.TestCase):
num_threads=1,
queue_capacity=10,
batch_size=1,
- input_key_column_name='input_key_column',
# Only CLASSIFICATION yields eval metrics to test for.
- problem_type=ssre.ProblemType.CLASSIFICATION,
+ problem_type=constants.ProblemType.CLASSIFICATION,
sequence_feature_columns=seq_columns,
context_feature_columns=None,
learning_rate=0.1)
@@ -444,7 +437,6 @@ class StateSavingRnnEstimatorTest(test.TestCase):
self.assertFalse(model_fn_ops.eval_metric_ops)
def testExport(self):
- input_key_column_name = 'input0'
input_feature_key = 'magic_input_feature_key'
batch_size = 8
cell_size = 4
@@ -460,22 +452,13 @@ class StateSavingRnnEstimatorTest(test.TestCase):
def get_input_fn(mode, seed):
def input_fn():
- input_key = string_ops.string_join([
- 'key_', string_ops.as_string(
- random_ops.random_uniform(
- (),
- minval=0,
- maxval=10000000,
- dtype=dtypes.int32,
- seed=seed))
- ])
features = {}
random_sequence = random_ops.random_uniform(
[sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
labels = array_ops.slice(random_sequence, [0], [sequence_length])
inputs = math_ops.to_float(
array_ops.slice(random_sequence, [1], [sequence_length]))
- features = {'inputs': inputs, input_key_column_name: input_key}
+ features = {'inputs': inputs}
if mode == model_fn_lib.ModeKeys.INFER:
input_examples = array_ops.placeholder(dtypes.string)
@@ -488,16 +471,17 @@ class StateSavingRnnEstimatorTest(test.TestCase):
model_dir = tempfile.mkdtemp()
def estimator_fn():
- return ssre.multi_value_rnn_classifier(
- num_classes=num_classes,
+ return ssre.StateSavingRnnEstimator(
+ constants.ProblemType.CLASSIFICATION,
num_units=cell_size,
num_unroll=num_unroll,
batch_size=batch_size,
- input_key_column_name=input_key_column_name,
sequence_feature_columns=seq_columns,
+ num_classes=num_classes,
predict_probabilities=True,
model_dir=model_dir,
- queue_capacity=2 + batch_size)
+ queue_capacity=2 + batch_size,
+ seed=1234)
# Train a bit to create an exportable checkpoint.
estimator_fn().fit(input_fn=get_input_fn(
@@ -516,6 +500,72 @@ class StateSavingRnnEstimatorTest(test.TestCase):
input_feature_key=input_feature_key)
+# Smoke tests to ensure deprecated constructor functions still work.
+class LegacyConstructorTest(test.TestCase):
+
+ def _get_input_fn(self,
+ sequence_length,
+ seed=None):
+ def input_fn():
+ random_sequence = random_ops.random_uniform(
+ [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
+ labels = array_ops.slice(random_sequence, [0], [sequence_length])
+ inputs = math_ops.to_float(
+ array_ops.slice(random_sequence, [1], [sequence_length]))
+ return {'inputs': inputs}, labels
+ return input_fn
+
+ def testClassifierConstructor(self):
+ batch_size = 16
+ num_classes = 2
+ num_unroll = 32
+ sequence_length = 32
+ num_units = 4
+ learning_rate = 0.5
+ steps = 100
+ input_fn = self._get_input_fn(sequence_length,
+ seed=1234)
+ model_dir = tempfile.mkdtemp()
+ seq_columns = [
+ feature_column.real_valued_column(
+ 'inputs', dimension=num_units)
+ ]
+ estimator = ssre.multi_value_rnn_classifier(num_classes,
+ num_units,
+ num_unroll,
+ batch_size,
+ seq_columns,
+ learning_rate=learning_rate,
+ model_dir=model_dir,
+ queue_capacity=batch_size+2,
+ seed=1234)
+ estimator.fit(input_fn=input_fn, steps=steps)
+
+ def testRegressorConstructor(self):
+ batch_size = 16
+ num_unroll = 32
+ sequence_length = 32
+ num_units = 4
+ learning_rate = 0.5
+ steps = 100
+ input_fn = self._get_input_fn(sequence_length,
+ seed=4321)
+ model_dir = tempfile.mkdtemp()
+ seq_columns = [
+ feature_column.real_valued_column(
+ 'inputs', dimension=num_units)
+ ]
+ estimator = ssre.multi_value_rnn_regressor(num_units,
+ num_unroll,
+ batch_size,
+ seq_columns,
+ learning_rate=learning_rate,
+ model_dir=model_dir,
+ queue_capacity=batch_size+2,
+ seed=1234)
+ estimator.fit(input_fn=input_fn, steps=steps)
+
+
# TODO(jtbates): move all tests below to a benchmark test.
class StateSavingRNNEstimatorLearningTest(test.TestCase):
"""Learning tests for state saving RNN Estimators."""
@@ -530,7 +580,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
num_units = 4
learning_rate = 0.3
loss_threshold = 0.035
- input_key_column_name = 'input_key_column'
def get_sin_input_fn(sequence_length, increment, seed=None):
@@ -542,11 +591,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
sequence_length + 1))
inputs = array_ops.slice(sin_curves, [0], [sequence_length])
labels = array_ops.slice(sin_curves, [1], [sequence_length])
- input_key = string_ops.string_join([
- 'key_',
- string_ops.as_string(math_ops.cast(10000 * start, dtypes.int32))
- ])
- return {'inputs': inputs, input_key_column_name: input_key}, labels
+ return {'inputs': inputs}, labels
return input_fn
@@ -555,17 +600,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
'inputs', dimension=num_units)
]
config = run_config.RunConfig(tf_random_seed=1234)
- sequence_estimator = ssre.multi_value_rnn_regressor(
+ sequence_estimator = ssre.StateSavingRnnEstimator(
+ constants.ProblemType.LINEAR_REGRESSION,
num_units=num_units,
num_unroll=num_unroll,
batch_size=batch_size,
- input_key_column_name=input_key_column_name,
sequence_feature_columns=seq_columns,
learning_rate=learning_rate,
input_keep_probability=0.9,
output_keep_probability=0.9,
config=config,
- queue_capacity=2 * batch_size)
+ queue_capacity=2 * batch_size,
+ seed=1234)
train_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=1234)
eval_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=4321)
@@ -592,7 +638,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
num_units = 4
learning_rate = 0.5
accuracy_threshold = 0.9
- input_key_column_name = 'input_key_column'
def get_shift_input_fn(sequence_length, seed=None):
@@ -602,16 +647,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
labels = array_ops.slice(random_sequence, [0], [sequence_length])
inputs = math_ops.to_float(
array_ops.slice(random_sequence, [1], [sequence_length]))
- input_key = string_ops.string_join([
- 'key_', string_ops.as_string(
- random_ops.random_uniform(
- (),
- minval=0,
- maxval=10000000,
- dtype=dtypes.int32,
- seed=seed))
- ])
- return {'inputs': inputs, input_key_column_name: input_key}, labels
+ return {'inputs': inputs}, labels
return input_fn
@@ -620,17 +656,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
'inputs', dimension=num_units)
]
config = run_config.RunConfig(tf_random_seed=21212)
- sequence_estimator = ssre.multi_value_rnn_classifier(
- num_classes=num_classes,
+ sequence_estimator = ssre.StateSavingRnnEstimator(
+ constants.ProblemType.CLASSIFICATION,
num_units=num_units,
num_unroll=num_unroll,
batch_size=batch_size,
- input_key_column_name=input_key_column_name,
sequence_feature_columns=seq_columns,
+ num_classes=num_classes,
learning_rate=learning_rate,
config=config,
predict_probabilities=True,
- queue_capacity=2 + batch_size)
+ queue_capacity=2 + batch_size,
+ seed=1234)
train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)
@@ -650,11 +687,11 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
self.assertListEqual(
sorted(list(prediction_dict.keys())),
sorted([
- ssre.RNNKeys.PREDICTIONS_KEY, ssre.RNNKeys.PROBABILITIES_KEY,
- ssre._get_state_name(0)
+ prediction_key.PredictionKey.CLASSES,
+ prediction_key.PredictionKey.PROBABILITIES, ssre._get_state_name(0)
]))
- predictions = prediction_dict[ssre.RNNKeys.PREDICTIONS_KEY]
- probabilities = prediction_dict[ssre.RNNKeys.PROBABILITIES_KEY]
+ predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
+ probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
self.assertListEqual(
list(probabilities.shape), [batch_size, sequence_length, 2])
@@ -672,7 +709,6 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
num_units = 4
learning_rate = 0.4
accuracy_threshold = 0.70
- input_key_column_name = 'input_key_column'
def get_lyrics_input_fn(seed):
@@ -692,16 +728,7 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
mapping=list(vocab), default_value=-1, name='lookup')
labels = table.lookup(
array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
- input_key = string_ops.string_join([
- 'key_', string_ops.as_string(
- random_ops.random_uniform(
- (),
- minval=0,
- maxval=10000000,
- dtype=dtypes.int32,
- seed=seed))
- ])
- return {'lyrics': inputs, input_key_column_name: input_key}, labels
+ return {'lyrics': inputs}, labels
return input_fn
@@ -711,17 +738,18 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase):
dimension=8)
]
config = run_config.RunConfig(tf_random_seed=21212)
- sequence_estimator = ssre.multi_value_rnn_classifier(
- num_classes=num_classes,
+ sequence_estimator = ssre.StateSavingRnnEstimator(
+ constants.ProblemType.CLASSIFICATION,
num_units=num_units,
num_unroll=num_unroll,
batch_size=batch_size,
- input_key_column_name=input_key_column_name,
sequence_feature_columns=sequence_feature_columns,
+ num_classes=num_classes,
learning_rate=learning_rate,
config=config,
predict_probabilities=True,
- queue_capacity=2 + batch_size)
+ queue_capacity=2 + batch_size,
+ seed=1234)
train_input_fn = get_lyrics_input_fn(seed=12321)
eval_input_fn = get_lyrics_input_fn(seed=32123)