diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/rnn_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index e32d7c4e67..c72ada11da 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -301,14 +301,12 @@ class RNNTest(test.TestCase): self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias)) def testRNNCellSerialization(self): - for cell in [ + for cell in [ rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True), rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32), - # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras. - # rnn_cell_impl.BasicRNNCell( - # 32, activation="relu", dtype=dtypes.float32), - # rnn_cell_impl.GRUCell( - # 32, kernel_initializer="ones", dtype=dtypes.float32) + rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32), + rnn_cell_impl.GRUCell( + 32, kernel_initializer="ones", dtype=dtypes.float32) ]: with self.test_session(): x = keras.Input((None, 5)) @@ -326,11 +324,13 @@ class RNNTest(test.TestCase): # not visible as a Keras layer, and also has a name conflict with # keras.LSTMCell and GRUCell. layer = keras.layers.RNN.from_config( - config, custom_objects={ - # "BasicRNNCell": rnn_cell_impl.BasicRNNCell, - # "GRUCell": rnn_cell_impl.GRUCell, + config, + custom_objects={ + "BasicRNNCell": rnn_cell_impl.BasicRNNCell, + "GRUCell": rnn_cell_impl.GRUCell, "LSTMCell": rnn_cell_impl.LSTMCell, - "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell}) + "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell + }) y = layer(x) model = keras.models.Model(x, y) model.set_weights(weights) |