diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-01 04:35:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-01 04:53:03 -0800 |
commit | 1617ffbc3df4066ff76325e889ce79d12e4b1c0f (patch) | |
tree | 522d3f520551e1268bcde521338232ba57ac380d | |
parent | ec86b037893fb00be8e9c366a5a6196d89a6dd72 (diff) |
Move dynamic_rnn_estimator._construct_rnn_cell() into rnn_common and make it
public.
Change: 148875698
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 62 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/rnn_common.py | 61 |
2 files changed, 65 insertions, 58 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 2c0799a9b7..35f1c3cc63 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -50,11 +50,6 @@ class RNNKeys(object): STATE_PREFIX = 'rnn_cell_state' -_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell, - 'lstm': contrib_rnn.LSTMCell, - 'gru': contrib_rnn.GRUCell,} - - def _get_state_name(i): """Constructs the name string for state component `i`.""" return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i) @@ -532,7 +527,8 @@ def _get_dynamic_rnn_model_fn( dropout = (dropout_keep_probabilities if mode == model_fn.ModeKeys.TRAIN else None) - cell = _construct_rnn_cell(cell_type, num_units, dropout) + # This class promises to use the cell type selected by that function. + cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout) initial_state = dict_to_state_tuple(features, cell) rnn_activations, final_state = construct_rnn( initial_state, @@ -589,58 +585,6 @@ def _get_dynamic_rnn_model_fn( return _dynamic_rnn_model_fn -def _get_single_cell(cell_type, num_units): - """Constructs and return an single `RNNCell`. - - Args: - cell_type: Either a string identifying the `RNNCell` type, a subclass of - `RNNCell` or an instance of an `RNNCell`. - num_units: The number of units in the `RNNCell`. - Returns: - An initialized `RNNCell`. - Raises: - ValueError: `cell_type` is an invalid `RNNCell` name. - TypeError: `cell_type` is not a string or a subclass of `RNNCell`. - """ - if isinstance(cell_type, contrib_rnn.RNNCell): - return cell_type - if isinstance(cell_type, str): - cell_type = _CELL_TYPES.get(cell_type) - if cell_type is None: - raise ValueError('The supported cell types are {}; got {}'.format( - list(_CELL_TYPES.keys()), cell_type)) - if not issubclass(cell_type, contrib_rnn.RNNCell): - raise TypeError( - 'cell_type must be a subclass of RNNCell or one of {}.'.format( - list(_CELL_TYPES.keys()))) - return cell_type(num_units=num_units) - - -def _construct_rnn_cell(cell_type, num_units, dropout_keep_probabilities): - """Constructs cells, applies dropout and assembles a `MultiRNNCell`. - - Args: - cell_type: A string identifying the `RNNCell` type, a subclass of - `RNNCell` or an instance of an `RNNCell`. - num_units: A single `int` or a list/tuple of `int`s. The size of the - `RNNCell`s. - dropout_keep_probabilities: a list of dropout probabilities or `None`. If a - list is given, it must have length `len(cell_type) + 1`. - - Returns: - An initialized `RNNCell`. - """ - if not isinstance(num_units, (list, tuple)): - num_units = (num_units,) - - cells = [_get_single_cell(cell_type, n) for n in num_units] - if dropout_keep_probabilities: - cells = rnn_common.apply_dropout(cells, dropout_keep_probabilities) - if len(cells) == 1: - return cells[0] - return contrib_rnn.MultiRNNCell(cells) - - def _get_dropout_and_num_units(cell_type, num_units, num_rnn_layers, @@ -695,6 +639,8 @@ class DynamicRnnEstimator(estimator.Estimator): all state components or none of them. If none are included, then the default (zero) state is used as an initial state. See the documentation for `dict_to_state_tuple` and `state_tuple_to_dict` for further details. + The input function can call rnn_common.construct_rnn_cell() to obtain the + same cell type that this class will select from arguments to __init__. The `predict()` method of the `Estimator` returns a dictionary with keys `STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index b3aea01ffa..51ecfe13dd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -40,6 +40,67 @@ class RNNKeys(object): STATE_PREFIX = 'rnn_cell_state' +_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell, + 'lstm': contrib_rnn.LSTMCell, + 'gru': contrib_rnn.GRUCell,} + + +def _get_single_cell(cell_type, num_units): + """Constructs and return a single `RNNCell`. + + Args: + cell_type: Either a string identifying the `RNNCell` type, a subclass of + `RNNCell` or an instance of an `RNNCell`. + num_units: The number of units in the `RNNCell`. + Returns: + An initialized `RNNCell`. + Raises: + ValueError: `cell_type` is an invalid `RNNCell` name. + TypeError: `cell_type` is not a string or a subclass of `RNNCell`. + """ + if isinstance(cell_type, contrib_rnn.RNNCell): + return cell_type + if isinstance(cell_type, str): + cell_type = _CELL_TYPES.get(cell_type) + if cell_type is None: + raise ValueError('The supported cell types are {}; got {}'.format( + list(_CELL_TYPES.keys()), cell_type)) + if not issubclass(cell_type, contrib_rnn.RNNCell): + raise TypeError( + 'cell_type must be a subclass of RNNCell or one of {}.'.format( + list(_CELL_TYPES.keys()))) + return cell_type(num_units=num_units) + + +def construct_rnn_cell(num_units, cell_type='basic_rnn', + dropout_keep_probabilities=None): + """Constructs cells, applies dropout and assembles a `MultiRNNCell`. + + The cell type chosen by DynamicRNNEstimator.__init__() is the same as + returned by this function when called with the same arguments. + + Args: + num_units: A single `int` or a list/tuple of `int`s. The size of the + `RNNCell`s. + cell_type: A string identifying the `RNNCell` type, a subclass of + `RNNCell` or an instance of an `RNNCell`. + dropout_keep_probabilities: a list of dropout probabilities or `None`. If a + list is given, it must have length `len(cell_type) + 1`. + + Returns: + An initialized `RNNCell`. + """ + if not isinstance(num_units, (list, tuple)): + num_units = (num_units,) + + cells = [_get_single_cell(cell_type, n) for n in num_units] + if dropout_keep_probabilities: + cells = apply_dropout(cells, dropout_keep_probabilities) + if len(cells) == 1: + return cells[0] + return contrib_rnn.MultiRNNCell(cells) + + def apply_dropout(cells, dropout_keep_probabilities, random_seed=None): """Applies dropout to the outputs and inputs of `cell`. |