diff options
author | Scott Zhu <scottzhu@google.com> | 2018-10-01 11:29:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 11:39:18 -0700 |
commit | f0f301f05fb1f1965c966ef57cc390e48d966f12 (patch) | |
tree | 29be99009e738828ee13cbd6d72ebee0b8e772c9 /tensorflow/python/kernel_tests | |
parent | a9b01e8a31a02188bc81349c103f136095f322ac (diff) |
Add deprecation notice for BasicRNNCell, which will be replaced by keras.SimpleRNNCell.
PiperOrigin-RevId: 215249611
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 05ad9f6336..2f6963f6b8 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -535,6 +535,45 @@ class RNNTest(test.TestCase): self.assertAllClose(tf_out, k_out) self.assertAllClose(tf_state, k_state) + def testSimpleRNNCellAndBasicRNNCellComparison(self): + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 20 + (x_train, _), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + fix_weights_generator = keras.layers.SimpleRNNCell(output_shape) + fix_weights_generator.build((None, input_shape)) + # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias + # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is + # zipped [kernel, recurrent_kernel] in SimpleRNNCell. + keras_weights = fix_weights_generator.get_weights() + kernel, recurrent_kernel, bias = keras_weights + tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias] + + with self.test_session(graph=ops_lib.Graph()) as sess: + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + cell = keras.layers.SimpleRNNCell(output_shape) + k_out, k_state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + cell.set_weights(keras_weights) + [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train}) + with self.test_session(graph=ops_lib.Graph()) as sess: + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + cell = rnn_cell_impl.BasicRNNCell(output_shape) + tf_out, tf_state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + cell.set_weights(tf_weights) + [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train}) + + self.assertAllClose(tf_out, k_out) + self.assertAllClose(tf_state, k_state) + def testBasicLSTMCellInterchangeWithLSTMCell(self): with self.session(graph=ops_lib.Graph()) as sess: basic_cell = rnn_cell_impl.BasicLSTMCell(1) |