diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-03-30 11:35:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-30 12:47:55 -0700 |
commit | 052e207e4f1076ac6b7464c01a67ecf4f986766d (patch) | |
tree | d84e0d28d9e1f8a424458ac41b284779b78627f9 | |
parent | d82f8f818cec7a1099a7e680a30ba8f5f8ed2589 (diff) |
Revert DynamicRNNEstimator's ability to accept an instance of RNNCell.
As we move to make RNNCells subclasses of Layers, several things become clear:
1. Layers are stateful - once called, they keep track of their variables.
2. Estimators create a new graph for predict/fit/etc.
3. Items #1 and #2 are not compatible, because a Layer modified in .fit()
cannot be used in .predict(): its variables and internal state tensors
are associated with the graph created in fit() -- and cannot be used
with the graph being created in predict(). Strange errors occur.
As a result, Estimators may never accept a Layer instance given this API
conflict.
This PR unblocks the move of RNNCell to be subclasses of tf.layers.Layer.
Change: 151735868
4 files changed, 29 insertions, 71 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 63ef058caa..525f84d511 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.framework.python.framework import deprecated from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import constants @@ -541,21 +540,17 @@ def _get_dynamic_rnn_model_fn( return _dynamic_rnn_model_fn -def _get_dropout_and_num_units(cell_type, - num_units, +def _get_dropout_and_num_units(num_units, num_rnn_layers, input_keep_probability, output_keep_probability): """Helper function for deprecated factory functions.""" dropout_keep_probabilities = None - if isinstance(cell_type, contrib_rnn.RNNCell): - num_units = None - else: - num_units = [num_units for _ in range(num_rnn_layers)] - if input_keep_probability or output_keep_probability: - dropout_keep_probabilities = ([input_keep_probability] - + [1.0] * (num_rnn_layers - 1) - + [output_keep_probability]) + num_units = [num_units for _ in range(num_rnn_layers)] + if input_keep_probability or output_keep_probability: + dropout_keep_probabilities = ([input_keep_probability] + + [1.0] * (num_rnn_layers - 1) + + [output_keep_probability]) return dropout_keep_probabilities, num_units @@ -629,10 +624,8 @@ class DynamicRnnEstimator(estimator.Estimator): num_classes: the number of classes for a classification problem. Only used when `problem_type=ProblemType.CLASSIFICATION`. num_units: A list of integers indicating the number of units in the - `RNNCell`s in each layer. Either `num_units` is specified or `cell_type` - is an instance of `RNNCell`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of - 'basic_rnn,' 'lstm' or 'gru'. + `RNNCell`s in each layer. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. optimizer: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer`, a callback that returns an optimizer, or a string. Strings must be one of 'Adagrad', 'Adam', @@ -658,8 +651,6 @@ class DynamicRnnEstimator(estimator.Estimator): config: A `RunConfig` instance. Raises: - ValueError: Both or neither of the following are true: (a) `num_units` is - specified and (b) `cell_type` is an instance of `RNNCell`. ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but @@ -667,12 +658,6 @@ class DynamicRnnEstimator(estimator.Estimator): ValueError: `prediction_type` is not one of `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`. """ - if (num_units is not None) == isinstance(cell_type, contrib_rnn.RNNCell): - raise ValueError( - 'Either num_units is specified OR cell_type is an instance of ' - 'RNNCell. Got num_units = {} and cell_type = {}.'.format( - num_units, cell_type)) - if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE: name = 'MultiValueDynamicRNN' elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE: @@ -744,8 +729,7 @@ def multi_value_rnn_regressor(num_units, recurrent network and outputs a sequence of continuous values. Args: - num_units: The size of the RNN cells. This argument has no effect - if `cell_type` is an instance of `RNNCell`. + num_units: The size of the RNN cells. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -753,8 +737,7 @@ def multi_value_rnn_regressor(num_units, describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. num_rnn_layers: Number of RNN layers. Leave this at its default value 1 if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of @@ -782,7 +765,6 @@ def multi_value_rnn_regressor(num_units, An initialized `Estimator`. """ dropout_keep_probabilities, num_units = _get_dropout_and_num_units( - cell_type, num_units, num_rnn_layers, input_keep_probability, @@ -831,8 +813,7 @@ def multi_value_rnn_classifier(num_classes, Args: num_classes: The number of classes for categorization. - num_units: The size of the RNN cells. This argument has no effect - if `cell_type` is an instance of `RNNCell`. + num_units: The size of the RNN cells. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -840,8 +821,7 @@ def multi_value_rnn_classifier(num_classes, describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. num_rnn_layers: Number of RNN layers. Leave this at its default value 1 if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of @@ -871,7 +851,6 @@ def multi_value_rnn_classifier(num_classes, An initialized `Estimator`. """ dropout_keep_probabilities, num_units = _get_dropout_and_num_units( - cell_type, num_units, num_rnn_layers, input_keep_probability, @@ -918,8 +897,7 @@ def single_value_rnn_regressor(num_units, recurrent network and outputs a single continuous values. Args: - num_units: The size of the RNN cells. This argument has no effect - if `cell_type` is an instance of `RNNCell`. + num_units: The size of the RNN cells. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -927,8 +905,7 @@ def single_value_rnn_regressor(num_units, describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. num_rnn_layers: Number of RNN layers. Leave this at its default value 1 if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of @@ -956,7 +933,6 @@ def single_value_rnn_regressor(num_units, An initialized `Estimator`. """ dropout_keep_probabilities, num_units = _get_dropout_and_num_units( - cell_type, num_units, num_rnn_layers, input_keep_probability, @@ -1005,8 +981,7 @@ def single_value_rnn_classifier(num_classes, Args: num_classes: The number of classes for categorization. - num_units: The size of the RNN cells. This argument has no effect - if `cell_type` is an instance of `RNNCell`. + num_units: The size of the RNN cells. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. @@ -1014,8 +989,7 @@ def single_value_rnn_classifier(num_classes, describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. num_rnn_layers: Number of RNN layers. Leave this at its default value 1 if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of @@ -1045,7 +1019,6 @@ def single_value_rnn_classifier(num_classes, An initialized `Estimator`. """ dropout_keep_probabilities, num_units = _get_dropout_and_num_units( - cell_type, num_units, num_rnn_layers, input_keep_probability, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index ef8358b35e..43b3d2a78f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -387,14 +387,14 @@ class DynamicRnnEstimatorTest(test.TestCase): seq_columns = [feature_column.real_valued_column('inputs', dimension=1)] config = run_config.RunConfig(tf_random_seed=21212) - cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell(size) for size in cell_sizes]) + cell_type = 'lstm' sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator( problem_type=constants.ProblemType.CLASSIFICATION, prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE, num_classes=2, + num_units=cell_sizes, sequence_feature_columns=seq_columns, - cell_type=cell, + cell_type=cell_type, learning_rate=learning_rate, config=config, predict_probabilities=True) diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index ebede1e9f1..f20dc78834 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -57,8 +57,8 @@ def _get_single_cell(cell_type, num_units): """Constructs and return a single `RNNCell`. Args: - cell_type: Either a string identifying the `RNNCell` type, a subclass of - `RNNCell` or an instance of an `RNNCell`. + cell_type: Either a string identifying the `RNNCell` type or a subclass of + `RNNCell`. num_units: The number of units in the `RNNCell`. Returns: An initialized `RNNCell`. @@ -66,17 +66,10 @@ def _get_single_cell(cell_type, num_units): ValueError: `cell_type` is an invalid `RNNCell` name. TypeError: `cell_type` is not a string or a subclass of `RNNCell`. """ - if isinstance(cell_type, contrib_rnn.RNNCell): - return cell_type - if isinstance(cell_type, str): - cell_type = _CELL_TYPES.get(cell_type) - if cell_type is None: - raise ValueError('The supported cell types are {}; got {}'.format( - list(_CELL_TYPES.keys()), cell_type)) - if not issubclass(cell_type, contrib_rnn.RNNCell): - raise TypeError( - 'cell_type must be a subclass of RNNCell or one of {}.'.format( - list(_CELL_TYPES.keys()))) + cell_type = _CELL_TYPES.get(cell_type) + if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell): + raise ValueError('The supported cell types are {}; got {}'.format( + list(_CELL_TYPES.keys()), cell_type)) return cell_type(num_units=num_units) @@ -90,8 +83,8 @@ def construct_rnn_cell(num_units, cell_type='basic_rnn', Args: num_units: A single `int` or a list/tuple of `int`s. The size of the `RNNCell`s. - cell_type: A string identifying the `RNNCell` type, a subclass of - `RNNCell` or an instance of an `RNNCell`. + cell_type: A string identifying the `RNNCell` type or a subclass of + `RNNCell`. dropout_keep_probabilities: a list of dropout probabilities or `None`. If a list is given, it must have length `len(cell_type) + 1`. diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 2707a347b7..e09278bc63 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -399,8 +399,7 @@ def _get_rnn_model_fn(cell_type, """Creates a state saving RNN model function for an `Estimator`. Args: - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. target_column: An initialized `TargetColumn`, used to calculate prediction and loss. problem_type: `ProblemType.CLASSIFICATION` or @@ -573,8 +572,7 @@ class StateSavingRnnEstimator(estimator.Estimator): num_units: A list of integers indicating the number of units in the `RNNCell`s in each layer. Either `num_units` is specified or `cell_type` is an instance of `RNNCell`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of - 'basic_rnn,' 'lstm' or 'gru'. + cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Adam', 'Ftrl', Momentum', 'RMSProp', or 'SGD'. @@ -611,12 +609,6 @@ class StateSavingRnnEstimator(estimator.Estimator): ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but `num_classes` is not specified. """ - if (num_units is not None) == isinstance(cell_type, rnn_cell.RNNCell): - raise ValueError( - 'Either num_units is specified OR cell_type is an instance of ' - 'RNNCell. Got num_units = {} and cell_type = {}.'.format( - num_units, cell_type)) - name = 'MultiValueStateSavingRNN' if problem_type == constants.ProblemType.LINEAR_REGRESSION: name += 'Regressor' |