diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_cell_test.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index e4e239169a..cc60e796ba 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -23,9 +23,10 @@ import functools import numpy as np import tensorflow as tf +from tensorflow.python.ops import rnn_cell_impl # TODO(ebrevdo): Remove once _linear is fully deprecated. # pylint: disable=protected-access -from tensorflow.python.ops.rnn_cell import _linear as linear +from tensorflow.python.ops.rnn_cell_impl import _linear as linear # pylint: enable=protected-access @@ -367,7 +368,7 @@ class SlimRNNCellTest(tf.test.TestCase): m = tf.zeros([1, 2]) my_cell = functools.partial(basic_rnn_cell, num_units=2) # pylint: disable=protected-access - g, _ = tf.nn.rnn_cell._SlimRNNCell(my_cell)(x, m) + g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) # pylint: enable=protected-access sess.run([tf.global_variables_initializer()]) res = sess.run([g], {x.name: np.array([[1., 1.]]), @@ -384,7 +385,7 @@ class SlimRNNCellTest(tf.test.TestCase): _, initial_state = basic_rnn_cell(inputs, None, num_units) my_cell = functools.partial(basic_rnn_cell, num_units=num_units) # pylint: disable=protected-access - slim_cell = tf.nn.rnn_cell._SlimRNNCell(my_cell) + slim_cell = rnn_cell_impl._SlimRNNCell(my_cell) # pylint: enable=protected-access slim_outputs, slim_state = slim_cell(inputs, initial_state) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units) |