diff options
author | Jianwei Xie <xiejw@google.com> | 2016-12-08 08:07:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-08 08:23:39 -0800 |
commit | f03bd900ad8316ce0e12eabada3735778d4145e9 (patch) | |
tree | 33229aa83c4151c124b474cddd989a2c920ee949 | |
parent | 1e8902e555b027f79ebaff8e8c31681d541afa8a (diff) |
Import tf.contrib.rnn for cells instead of tf.nn.rnn_cell.
Change: 141440475
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py | 3 |
2 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index dc708ec0c5..71df2129ef 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -22,6 +22,7 @@ import functools from tensorflow.contrib import layers from tensorflow.contrib import metrics +from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -33,7 +34,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell from tensorflow.python.training import momentum as momentum_opt @@ -54,9 +54,9 @@ class RNNKeys(object): PROBABILITIES_KEY = 'probabilities' FINAL_STATE_KEY = 'final_state' -_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell, - 'lstm': rnn_cell.LSTMCell, - 'gru': rnn_cell.GRUCell,} +_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell, + 'lstm': contrib_rnn.LSTMCell, + 'gru': contrib_rnn.GRUCell,} def mask_activations_and_labels(activations, labels, sequence_lengths): @@ -475,7 +475,7 @@ def apply_dropout( input_keep_probability = 1.0 if output_prob_none: output_keep_probability = 1.0 - return rnn_cell.DropoutWrapper( + return contrib_rnn.DropoutWrapper( cell, input_keep_probability, output_keep_probability, random_seed) @@ -643,20 +643,20 @@ def _to_rnn_cell(cell_or_type, num_units, num_layers): ValueError: `cell_or_type` is an invalid `RNNCell` name. TypeError: `cell_or_type` is not a string or a subclass of `RNNCell`. """ - if isinstance(cell_or_type, rnn_cell.RNNCell): + if isinstance(cell_or_type, contrib_rnn.RNNCell): return cell_or_type if isinstance(cell_or_type, str): cell_or_type = _CELL_TYPES.get(cell_or_type) if cell_or_type is None: raise ValueError('The supported cell types are {}; got {}'.format( list(_CELL_TYPES.keys()), cell_or_type)) - if not issubclass(cell_or_type, rnn_cell.RNNCell): + if not issubclass(cell_or_type, contrib_rnn.RNNCell): raise TypeError( 'cell_or_type must be a subclass of RNNCell or one of {}.'.format( list(_CELL_TYPES.keys()))) cell = cell_or_type(num_units=num_units) if num_layers > 1: - cell = rnn_cell.MultiRNNCell( + cell = contrib_rnn.MultiRNNCell( [cell] * num_layers, state_is_tuple=True) return cell diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index ab869958d0..67f540d410 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -24,7 +24,6 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator -from tensorflow.python.ops import rnn_cell class IdentityRNNCell(tf.contrib.rnn.RNNCell): @@ -87,7 +86,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): def setUp(self): super(DynamicRnnEstimatorTest, self).setUp() - self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS) + self.rnn_cell = tf.contrib.rnn.BasicRNNCell(self.NUM_RNN_CELL_UNITS) self.mock_target_column = MockTargetColumn( num_label_columns=self.NUM_LABEL_COLUMNS) |