aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-03-30 11:35:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-30 12:47:55 -0700
commit052e207e4f1076ac6b7464c01a67ecf4f986766d (patch)
treed84e0d28d9e1f8a424458ac41b284779b78627f9
parentd82f8f818cec7a1099a7e680a30ba8f5f8ed2589 (diff)
Revert DynamicRNNEstimator's ability to accept an instance of RNNCell.
As we move to make RNNCells subclasses of Layers, several things become clear: 1. Layers are stateful - once called, they keep track of their variables. 2. Estimators create a new graph for predict/fit/etc. 3. Items #1 and #2 are not compatible, because a Layer modified in .fit() cannot be used in .predict(): its variables and internal state tensors are associated with the graph created in fit() -- and cannot be used with the graph being created in predict(). Strange errors occur. As a result, Estimators may never accept a Layer instance given this API conflict. This PR unblocks the move of RNNCell to be subclasses of tf.layers.Layer. Change: 151735868
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py59
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/rnn_common.py23
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py12
4 files changed, 29 insertions, 71 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 63ef058caa..525f84d511 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
-from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
@@ -541,21 +540,17 @@ def _get_dynamic_rnn_model_fn(
return _dynamic_rnn_model_fn
-def _get_dropout_and_num_units(cell_type,
- num_units,
+def _get_dropout_and_num_units(num_units,
num_rnn_layers,
input_keep_probability,
output_keep_probability):
"""Helper function for deprecated factory functions."""
dropout_keep_probabilities = None
- if isinstance(cell_type, contrib_rnn.RNNCell):
- num_units = None
- else:
- num_units = [num_units for _ in range(num_rnn_layers)]
- if input_keep_probability or output_keep_probability:
- dropout_keep_probabilities = ([input_keep_probability]
- + [1.0] * (num_rnn_layers - 1)
- + [output_keep_probability])
+ num_units = [num_units for _ in range(num_rnn_layers)]
+ if input_keep_probability or output_keep_probability:
+ dropout_keep_probabilities = ([input_keep_probability]
+ + [1.0] * (num_rnn_layers - 1)
+ + [output_keep_probability])
return dropout_keep_probabilities, num_units
@@ -629,10 +624,8 @@ class DynamicRnnEstimator(estimator.Estimator):
num_classes: the number of classes for a classification problem. Only
used when `problem_type=ProblemType.CLASSIFICATION`.
num_units: A list of integers indicating the number of units in the
- `RNNCell`s in each layer. Either `num_units` is specified or `cell_type`
- is an instance of `RNNCell`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ `RNNCell`s in each layer.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
optimizer: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer`, a callback that returns an
optimizer, or a string. Strings must be one of 'Adagrad', 'Adam',
@@ -658,8 +651,6 @@ class DynamicRnnEstimator(estimator.Estimator):
config: A `RunConfig` instance.
Raises:
- ValueError: Both or neither of the following are true: (a) `num_units` is
- specified and (b) `cell_type` is an instance of `RNNCell`.
ValueError: `problem_type` is not one of
`ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
@@ -667,12 +658,6 @@ class DynamicRnnEstimator(estimator.Estimator):
ValueError: `prediction_type` is not one of
`PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`.
"""
- if (num_units is not None) == isinstance(cell_type, contrib_rnn.RNNCell):
- raise ValueError(
- 'Either num_units is specified OR cell_type is an instance of '
- 'RNNCell. Got num_units = {} and cell_type = {}.'.format(
- num_units, cell_type))
-
if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
name = 'MultiValueDynamicRNN'
elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
@@ -744,8 +729,7 @@ def multi_value_rnn_regressor(num_units,
recurrent network and outputs a sequence of continuous values.
Args:
- num_units: The size of the RNN cells. This argument has no effect
- if `cell_type` is an instance of `RNNCell`.
+ num_units: The size of the RNN cells.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -753,8 +737,7 @@ def multi_value_rnn_regressor(num_units,
describing context features, i.e., features that apply accross all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
num_rnn_layers: Number of RNN layers. Leave this at its default value 1
if passing a `cell_type` that is already a MultiRNNCell.
optimizer_type: The type of optimizer to use. Either a subclass of
@@ -782,7 +765,6 @@ def multi_value_rnn_regressor(num_units,
An initialized `Estimator`.
"""
dropout_keep_probabilities, num_units = _get_dropout_and_num_units(
- cell_type,
num_units,
num_rnn_layers,
input_keep_probability,
@@ -831,8 +813,7 @@ def multi_value_rnn_classifier(num_classes,
Args:
num_classes: The number of classes for categorization.
- num_units: The size of the RNN cells. This argument has no effect
- if `cell_type` is an instance of `RNNCell`.
+ num_units: The size of the RNN cells.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -840,8 +821,7 @@ def multi_value_rnn_classifier(num_classes,
describing context features, i.e., features that apply accross all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
num_rnn_layers: Number of RNN layers. Leave this at its default value 1
if passing a `cell_type` that is already a MultiRNNCell.
optimizer_type: The type of optimizer to use. Either a subclass of
@@ -871,7 +851,6 @@ def multi_value_rnn_classifier(num_classes,
An initialized `Estimator`.
"""
dropout_keep_probabilities, num_units = _get_dropout_and_num_units(
- cell_type,
num_units,
num_rnn_layers,
input_keep_probability,
@@ -918,8 +897,7 @@ def single_value_rnn_regressor(num_units,
recurrent network and outputs a single continuous values.
Args:
- num_units: The size of the RNN cells. This argument has no effect
- if `cell_type` is an instance of `RNNCell`.
+ num_units: The size of the RNN cells.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -927,8 +905,7 @@ def single_value_rnn_regressor(num_units,
describing context features, i.e., features that apply accross all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
num_rnn_layers: Number of RNN layers. Leave this at its default value 1
if passing a `cell_type` that is already a MultiRNNCell.
optimizer_type: The type of optimizer to use. Either a subclass of
@@ -956,7 +933,6 @@ def single_value_rnn_regressor(num_units,
An initialized `Estimator`.
"""
dropout_keep_probabilities, num_units = _get_dropout_and_num_units(
- cell_type,
num_units,
num_rnn_layers,
input_keep_probability,
@@ -1005,8 +981,7 @@ def single_value_rnn_classifier(num_classes,
Args:
num_classes: The number of classes for categorization.
- num_units: The size of the RNN cells. This argument has no effect
- if `cell_type` is an instance of `RNNCell`.
+ num_units: The size of the RNN cells.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
@@ -1014,8 +989,7 @@ def single_value_rnn_classifier(num_classes,
describing context features, i.e., features that apply accross all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
num_rnn_layers: Number of RNN layers. Leave this at its default value 1
if passing a `cell_type` that is already a MultiRNNCell.
optimizer_type: The type of optimizer to use. Either a subclass of
@@ -1045,7 +1019,6 @@ def single_value_rnn_classifier(num_classes,
An initialized `Estimator`.
"""
dropout_keep_probabilities, num_units = _get_dropout_and_num_units(
- cell_type,
num_units,
num_rnn_layers,
input_keep_probability,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index ef8358b35e..43b3d2a78f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -387,14 +387,14 @@ class DynamicRnnEstimatorTest(test.TestCase):
seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
config = run_config.RunConfig(tf_random_seed=21212)
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.BasicLSTMCell(size) for size in cell_sizes])
+ cell_type = 'lstm'
sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
problem_type=constants.ProblemType.CLASSIFICATION,
prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
num_classes=2,
+ num_units=cell_sizes,
sequence_feature_columns=seq_columns,
- cell_type=cell,
+ cell_type=cell_type,
learning_rate=learning_rate,
config=config,
predict_probabilities=True)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
index ebede1e9f1..f20dc78834 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
@@ -57,8 +57,8 @@ def _get_single_cell(cell_type, num_units):
"""Constructs and return a single `RNNCell`.
Args:
- cell_type: Either a string identifying the `RNNCell` type, a subclass of
- `RNNCell` or an instance of an `RNNCell`.
+ cell_type: Either a string identifying the `RNNCell` type or a subclass of
+ `RNNCell`.
num_units: The number of units in the `RNNCell`.
Returns:
An initialized `RNNCell`.
@@ -66,17 +66,10 @@ def _get_single_cell(cell_type, num_units):
ValueError: `cell_type` is an invalid `RNNCell` name.
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
"""
- if isinstance(cell_type, contrib_rnn.RNNCell):
- return cell_type
- if isinstance(cell_type, str):
- cell_type = _CELL_TYPES.get(cell_type)
- if cell_type is None:
- raise ValueError('The supported cell types are {}; got {}'.format(
- list(_CELL_TYPES.keys()), cell_type))
- if not issubclass(cell_type, contrib_rnn.RNNCell):
- raise TypeError(
- 'cell_type must be a subclass of RNNCell or one of {}.'.format(
- list(_CELL_TYPES.keys())))
+ cell_type = _CELL_TYPES.get(cell_type)
+ if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell):
+ raise ValueError('The supported cell types are {}; got {}'.format(
+ list(_CELL_TYPES.keys()), cell_type))
return cell_type(num_units=num_units)
@@ -90,8 +83,8 @@ def construct_rnn_cell(num_units, cell_type='basic_rnn',
Args:
num_units: A single `int` or a list/tuple of `int`s. The size of the
`RNNCell`s.
- cell_type: A string identifying the `RNNCell` type, a subclass of
- `RNNCell` or an instance of an `RNNCell`.
+ cell_type: A string identifying the `RNNCell` type or a subclass of
+ `RNNCell`.
dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
list is given, it must have length `len(cell_type) + 1`.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index 2707a347b7..e09278bc63 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -399,8 +399,7 @@ def _get_rnn_model_fn(cell_type,
"""Creates a state saving RNN model function for an `Estimator`.
Args:
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
target_column: An initialized `TargetColumn`, used to calculate prediction
and loss.
problem_type: `ProblemType.CLASSIFICATION` or
@@ -573,8 +572,7 @@ class StateSavingRnnEstimator(estimator.Estimator):
num_units: A list of integers indicating the number of units in the
`RNNCell`s in each layer. Either `num_units` is specified or `cell_type`
is an instance of `RNNCell`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
- 'basic_rnn,' 'lstm' or 'gru'.
+ cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Adam', 'Ftrl', Momentum', 'RMSProp', or 'SGD'.
@@ -611,12 +609,6 @@ class StateSavingRnnEstimator(estimator.Estimator):
ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
`num_classes` is not specified.
"""
- if (num_units is not None) == isinstance(cell_type, rnn_cell.RNNCell):
- raise ValueError(
- 'Either num_units is specified OR cell_type is an instance of '
- 'RNNCell. Got num_units = {} and cell_type = {}.'.format(
- num_units, cell_type))
-
name = 'MultiValueStateSavingRNN'
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
name += 'Regressor'