aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-08 08:07:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 08:23:39 -0800
commitf03bd900ad8316ce0e12eabada3735778d4145e9 (patch)
tree33229aa83c4151c124b474cddd989a2c920ee949
parent1e8902e555b027f79ebaff8e8c31681d541afa8a (diff)
Import tf.contrib.rnn for cells instead of tf.nn.rnn_cell.
Change: 141440475
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py16
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py3
2 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index dc708ec0c5..71df2129ef 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -22,6 +22,7 @@ import functools
from tensorflow.contrib import layers
from tensorflow.contrib import metrics
+from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import metric_spec
@@ -33,7 +34,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
-from tensorflow.python.ops import rnn_cell
from tensorflow.python.training import momentum as momentum_opt
@@ -54,9 +54,9 @@ class RNNKeys(object):
PROBABILITIES_KEY = 'probabilities'
FINAL_STATE_KEY = 'final_state'
-_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell,
- 'lstm': rnn_cell.LSTMCell,
- 'gru': rnn_cell.GRUCell,}
+_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
+ 'lstm': contrib_rnn.LSTMCell,
+ 'gru': contrib_rnn.GRUCell,}
def mask_activations_and_labels(activations, labels, sequence_lengths):
@@ -475,7 +475,7 @@ def apply_dropout(
input_keep_probability = 1.0
if output_prob_none:
output_keep_probability = 1.0
- return rnn_cell.DropoutWrapper(
+ return contrib_rnn.DropoutWrapper(
cell, input_keep_probability, output_keep_probability, random_seed)
@@ -643,20 +643,20 @@ def _to_rnn_cell(cell_or_type, num_units, num_layers):
ValueError: `cell_or_type` is an invalid `RNNCell` name.
TypeError: `cell_or_type` is not a string or a subclass of `RNNCell`.
"""
- if isinstance(cell_or_type, rnn_cell.RNNCell):
+ if isinstance(cell_or_type, contrib_rnn.RNNCell):
return cell_or_type
if isinstance(cell_or_type, str):
cell_or_type = _CELL_TYPES.get(cell_or_type)
if cell_or_type is None:
raise ValueError('The supported cell types are {}; got {}'.format(
list(_CELL_TYPES.keys()), cell_or_type))
- if not issubclass(cell_or_type, rnn_cell.RNNCell):
+ if not issubclass(cell_or_type, contrib_rnn.RNNCell):
raise TypeError(
'cell_or_type must be a subclass of RNNCell or one of {}.'.format(
list(_CELL_TYPES.keys())))
cell = cell_or_type(num_units=num_units)
if num_layers > 1:
- cell = rnn_cell.MultiRNNCell(
+ cell = contrib_rnn.MultiRNNCell(
[cell] * num_layers, state_is_tuple=True)
return cell
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index ab869958d0..67f540d410 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -24,7 +24,6 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
-from tensorflow.python.ops import rnn_cell
class IdentityRNNCell(tf.contrib.rnn.RNNCell):
@@ -87,7 +86,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
def setUp(self):
super(DynamicRnnEstimatorTest, self).setUp()
- self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
+ self.rnn_cell = tf.contrib.rnn.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
self.mock_target_column = MockTargetColumn(
num_label_columns=self.NUM_LABEL_COLUMNS)