aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/python/kernel_tests/rnn_cell_test.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index e4e239169a..cc60e796ba 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -23,9 +23,10 @@ import functools
import numpy as np
import tensorflow as tf
+from tensorflow.python.ops import rnn_cell_impl
# TODO(ebrevdo): Remove once _linear is fully deprecated.
# pylint: disable=protected-access
-from tensorflow.python.ops.rnn_cell import _linear as linear
+from tensorflow.python.ops.rnn_cell_impl import _linear as linear
# pylint: enable=protected-access
@@ -367,7 +368,7 @@ class SlimRNNCellTest(tf.test.TestCase):
m = tf.zeros([1, 2])
my_cell = functools.partial(basic_rnn_cell, num_units=2)
# pylint: disable=protected-access
- g, _ = tf.nn.rnn_cell._SlimRNNCell(my_cell)(x, m)
+ g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
# pylint: enable=protected-access
sess.run([tf.global_variables_initializer()])
res = sess.run([g], {x.name: np.array([[1., 1.]]),
@@ -384,7 +385,7 @@ class SlimRNNCellTest(tf.test.TestCase):
_, initial_state = basic_rnn_cell(inputs, None, num_units)
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
# pylint: disable=protected-access
- slim_cell = tf.nn.rnn_cell._SlimRNNCell(my_cell)
+ slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
# pylint: enable=protected-access
slim_outputs, slim_state = slim_cell(inputs, initial_state)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)