aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/rnn_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/rnn_test.py')
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index e32d7c4e67..c72ada11da 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -301,14 +301,12 @@ class RNNTest(test.TestCase):
self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))
def testRNNCellSerialization(self):
- for cell in [
+ for cell in [
rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True),
rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32),
- # TODO(scottzhu): GRU and BasicRNN cell are not compatible with Keras.
- # rnn_cell_impl.BasicRNNCell(
- # 32, activation="relu", dtype=dtypes.float32),
- # rnn_cell_impl.GRUCell(
- # 32, kernel_initializer="ones", dtype=dtypes.float32)
+ rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32),
+ rnn_cell_impl.GRUCell(
+ 32, kernel_initializer="ones", dtype=dtypes.float32)
]:
with self.test_session():
x = keras.Input((None, 5))
@@ -326,11 +324,13 @@ class RNNTest(test.TestCase):
# not visible as a Keras layer, and also has a name conflict with
# keras.LSTMCell and GRUCell.
layer = keras.layers.RNN.from_config(
- config, custom_objects={
- # "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
- # "GRUCell": rnn_cell_impl.GRUCell,
+ config,
+ custom_objects={
+ "BasicRNNCell": rnn_cell_impl.BasicRNNCell,
+ "GRUCell": rnn_cell_impl.GRUCell,
"LSTMCell": rnn_cell_impl.LSTMCell,
- "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell})
+ "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell
+ })
y = layer(x)
model = keras.models.Model(x, y)
model.set_weights(weights)