aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-10-01 11:29:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 11:39:18 -0700
commitf0f301f05fb1f1965c966ef57cc390e48d966f12 (patch)
tree29be99009e738828ee13cbd6d72ebee0b8e772c9 /tensorflow/python/kernel_tests
parenta9b01e8a31a02188bc81349c103f136095f322ac (diff)
Add deprecation notice for BasicRNNCell, which will be replaced by keras.SimpleRNNCell.
PiperOrigin-RevId: 215249611
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 05ad9f6336..2f6963f6b8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -535,6 +535,45 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_out, k_out)
self.assertAllClose(tf_state, k_state)
+ def testSimpleRNNCellAndBasicRNNCellComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
+ # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
+ # zipped [kernel, recurrent_kernel] in SimpleRNNCell.
+ keras_weights = fix_weights_generator.get_weights()
+ kernel, recurrent_kernel, bias = keras_weights
+ tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ k_out, k_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(keras_weights)
+ [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = rnn_cell_impl.BasicRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(tf_weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)