aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-01 04:35:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 04:53:03 -0800
commit1617ffbc3df4066ff76325e889ce79d12e4b1c0f (patch)
tree522d3f520551e1268bcde521338232ba57ac380d
parentec86b037893fb00be8e9c366a5a6196d89a6dd72 (diff)
Move dynamic_rnn_estimator._construct_rnn_cell() into rnn_common and make it
public. Change: 148875698
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py62
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/rnn_common.py61
2 files changed, 65 insertions, 58 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 2c0799a9b7..35f1c3cc63 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -50,11 +50,6 @@ class RNNKeys(object):
STATE_PREFIX = 'rnn_cell_state'
-_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
- 'lstm': contrib_rnn.LSTMCell,
- 'gru': contrib_rnn.GRUCell,}
-
-
def _get_state_name(i):
"""Constructs the name string for state component `i`."""
return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
@@ -532,7 +527,8 @@ def _get_dynamic_rnn_model_fn(
dropout = (dropout_keep_probabilities
if mode == model_fn.ModeKeys.TRAIN
else None)
- cell = _construct_rnn_cell(cell_type, num_units, dropout)
+ # This class promises to use the cell type selected by that function.
+ cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
initial_state = dict_to_state_tuple(features, cell)
rnn_activations, final_state = construct_rnn(
initial_state,
@@ -589,58 +585,6 @@ def _get_dynamic_rnn_model_fn(
return _dynamic_rnn_model_fn
-def _get_single_cell(cell_type, num_units):
- """Constructs and return an single `RNNCell`.
-
- Args:
- cell_type: Either a string identifying the `RNNCell` type, a subclass of
- `RNNCell` or an instance of an `RNNCell`.
- num_units: The number of units in the `RNNCell`.
- Returns:
- An initialized `RNNCell`.
- Raises:
- ValueError: `cell_type` is an invalid `RNNCell` name.
- TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
- """
- if isinstance(cell_type, contrib_rnn.RNNCell):
- return cell_type
- if isinstance(cell_type, str):
- cell_type = _CELL_TYPES.get(cell_type)
- if cell_type is None:
- raise ValueError('The supported cell types are {}; got {}'.format(
- list(_CELL_TYPES.keys()), cell_type))
- if not issubclass(cell_type, contrib_rnn.RNNCell):
- raise TypeError(
- 'cell_type must be a subclass of RNNCell or one of {}.'.format(
- list(_CELL_TYPES.keys())))
- return cell_type(num_units=num_units)
-
-
-def _construct_rnn_cell(cell_type, num_units, dropout_keep_probabilities):
- """Constructs cells, applies dropout and assembles a `MultiRNNCell`.
-
- Args:
- cell_type: A string identifying the `RNNCell` type, a subclass of
- `RNNCell` or an instance of an `RNNCell`.
- num_units: A single `int` or a list/tuple of `int`s. The size of the
- `RNNCell`s.
- dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
- list is given, it must have length `len(cell_type) + 1`.
-
- Returns:
- An initialized `RNNCell`.
- """
- if not isinstance(num_units, (list, tuple)):
- num_units = (num_units,)
-
- cells = [_get_single_cell(cell_type, n) for n in num_units]
- if dropout_keep_probabilities:
- cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities)
- if len(cells) == 1:
- return cells[0]
- return contrib_rnn.MultiRNNCell(cells)
-
-
def _get_dropout_and_num_units(cell_type,
num_units,
num_rnn_layers,
@@ -695,6 +639,8 @@ class DynamicRnnEstimator(estimator.Estimator):
all state components or none of them. If none are included, then the default
(zero) state is used as an initial state. See the documentation for
`dict_to_state_tuple` and `state_tuple_to_dict` for further details.
+ The input function can call rnn_common.construct_rnn_cell() to obtain the
+ same cell type that this class will select from arguments to __init__.
The `predict()` method of the `Estimator` returns a dictionary with keys
`STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
index b3aea01ffa..51ecfe13dd 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py
@@ -40,6 +40,67 @@ class RNNKeys(object):
STATE_PREFIX = 'rnn_cell_state'
+_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
+ 'lstm': contrib_rnn.LSTMCell,
+ 'gru': contrib_rnn.GRUCell,}
+
+
+def _get_single_cell(cell_type, num_units):
+ """Constructs and return a single `RNNCell`.
+
+ Args:
+ cell_type: Either a string identifying the `RNNCell` type, a subclass of
+ `RNNCell` or an instance of an `RNNCell`.
+ num_units: The number of units in the `RNNCell`.
+ Returns:
+ An initialized `RNNCell`.
+ Raises:
+ ValueError: `cell_type` is an invalid `RNNCell` name.
+ TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
+ """
+ if isinstance(cell_type, contrib_rnn.RNNCell):
+ return cell_type
+ if isinstance(cell_type, str):
+ cell_type = _CELL_TYPES.get(cell_type)
+ if cell_type is None:
+ raise ValueError('The supported cell types are {}; got {}'.format(
+ list(_CELL_TYPES.keys()), cell_type))
+ if not issubclass(cell_type, contrib_rnn.RNNCell):
+ raise TypeError(
+ 'cell_type must be a subclass of RNNCell or one of {}.'.format(
+ list(_CELL_TYPES.keys())))
+ return cell_type(num_units=num_units)
+
+
+def construct_rnn_cell(num_units, cell_type='basic_rnn',
+ dropout_keep_probabilities=None):
+ """Constructs cells, applies dropout and assembles a `MultiRNNCell`.
+
+ The cell type chosen by DynamicRNNEstimator.__init__() is the same as
+ returned by this function when called with the same arguments.
+
+ Args:
+ num_units: A single `int` or a list/tuple of `int`s. The size of the
+ `RNNCell`s.
+ cell_type: A string identifying the `RNNCell` type, a subclass of
+ `RNNCell` or an instance of an `RNNCell`.
+ dropout_keep_probabilities: a list of dropout probabilities or `None`. If a
+ list is given, it must have length `len(cell_type) + 1`.
+
+ Returns:
+ An initialized `RNNCell`.
+ """
+ if not isinstance(num_units, (list, tuple)):
+ num_units = (num_units,)
+
+ cells = [_get_single_cell(cell_type, n) for n in num_units]
+ if dropout_keep_probabilities:
+ cells = apply_dropout(cells, dropout_keep_probabilities)
+ if len(cells) == 1:
+ return cells[0]
+ return contrib_rnn.MultiRNNCell(cells)
+
+
def apply_dropout(cells, dropout_keep_probabilities, random_seed=None):
"""Applies dropout to the outputs and inputs of `cell`.