aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
blob: f43d90fec8fb4325d808e992060a48562db224a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LSTM layer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.platform import test


class LSTMLayerTest(test.TestCase):
  """Unit tests for `keras.layers.LSTM`.

  Covers output shapes, dropout, implementation modes, stateful behaviour,
  masking, regularizers, constraints, config round-trips, and passing in or
  returning initial states.
  """

  def test_return_sequences_LSTM(self):
    """Runs the generic layer test with `return_sequences=True`."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.LSTM,
          kwargs={'units': units,
                  'return_sequences': True},
          input_shape=(num_samples, timesteps, embedding_dim))

  def test_dynamic_behavior_LSTM(self):
    """Trains one batch with the time dimension left unspecified (None)."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim))
      model = keras.models.Sequential()
      model.add(layer)
      model.compile('sgd', 'mse')
      x = np.random.random((num_samples, timesteps, embedding_dim))
      y = np.random.random((num_samples, units))
      model.train_on_batch(x, y)

  def test_dropout_LSTM(self):
    """Runs the generic layer test with input and recurrent dropout."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.LSTM,
          kwargs={'units': units,
                  'dropout': 0.1,
                  'recurrent_dropout': 0.1},
          input_shape=(num_samples, timesteps, embedding_dim))

  def test_implementation_mode_LSTM(self):
    """Runs the generic layer test for each `implementation` value 0, 1, 2."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      for mode in [0, 1, 2]:
        testing_utils.layer_test(
            keras.layers.LSTM,
            kwargs={'units': units,
                    'implementation': mode},
            input_shape=(num_samples, timesteps, embedding_dim))

  def test_statefulness_LSTM(self):
    """Checks state carry-over, state resets, and masked (padded) inputs."""
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    layer_class = keras.layers.LSTM
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.Embedding(
              4,
              embedding_dim,
              mask_zero=True,
              input_length=timesteps,
              batch_input_shape=(num_samples, timesteps)))
      layer = layer_class(
          units, return_sequences=False, stateful=True, weights=None)
      model.add(layer)
      model.compile(optimizer='sgd', loss='mse')
      out1 = model.predict(np.ones((num_samples, timesteps)))
      self.assertEqual(out1.shape, (num_samples, units))

      # train once so that the states change
      model.train_on_batch(
          np.ones((num_samples, timesteps)), np.ones((num_samples, units)))
      out2 = model.predict(np.ones((num_samples, timesteps)))

      # if the state is not reset, output should be different
      self.assertNotEqual(out1.max(), out2.max())

      # check that output changes after states are reset
      # (even though the model itself didn't change)
      layer.reset_states()
      out3 = model.predict(np.ones((num_samples, timesteps)))
      self.assertNotEqual(out2.max(), out3.max())

      # check that container-level reset_states() works
      model.reset_states()
      out4 = model.predict(np.ones((num_samples, timesteps)))
      self.assertAllClose(out3, out4, atol=1e-5)

      # check that the call to `predict` updated the states
      out5 = model.predict(np.ones((num_samples, timesteps)))
      self.assertNotEqual(out4.max(), out5.max())

      # Check masking: the Embedding uses `mask_zero=True`, so token 0 is
      # masked. Each sample below keeps the same number of `1` tokens and is
      # only padded on opposite sides, so both predictions should agree.
      layer.reset_states()

      left_padded_input = np.ones((num_samples, timesteps))
      left_padded_input[0, :1] = 0
      left_padded_input[1, :2] = 0
      out6 = model.predict(left_padded_input)

      layer.reset_states()

      right_padded_input = np.ones((num_samples, timesteps))
      right_padded_input[0, -1:] = 0
      right_padded_input[1, -2:] = 0
      out7 = model.predict(right_padded_input)

      self.assertAllClose(out7, out6, atol=1e-5)

  def test_regularizers_LSTM(self):
    """Checks that weight and activity regularizers register losses."""
    embedding_dim = 4
    layer_class = keras.layers.LSTM
    with self.test_session():
      layer = layer_class(
          5,
          return_sequences=False,
          weights=None,
          input_shape=(None, embedding_dim),
          kernel_regularizer=keras.regularizers.l1(0.01),
          recurrent_regularizer=keras.regularizers.l1(0.01),
          bias_regularizer='l2',
          activity_regularizer='l1')
      # Build and call with the same feature size declared in `input_shape`.
      layer.build((None, None, embedding_dim))
      # The kernel, recurrent and bias regularizers are expected to yield
      # three losses once the weights exist.
      self.assertEqual(len(layer.losses), 3)
      layer(keras.backend.variable(np.ones((2, 3, embedding_dim))))
      # Calling the layer is expected to add the activity-regularizer loss.
      self.assertEqual(len(layer.losses), 4)

  def test_constraints_LSTM(self):
    """Checks constraints are attached to kernel, recurrent kernel and bias."""
    embedding_dim = 4
    layer_class = keras.layers.LSTM
    with self.test_session():
      k_constraint = keras.constraints.max_norm(0.01)
      r_constraint = keras.constraints.max_norm(0.01)
      b_constraint = keras.constraints.max_norm(0.01)
      layer = layer_class(
          5,
          return_sequences=False,
          weights=None,
          input_shape=(None, embedding_dim),
          kernel_constraint=k_constraint,
          recurrent_constraint=r_constraint,
          bias_constraint=b_constraint)
      layer.build((None, None, embedding_dim))
      self.assertEqual(layer.kernel.constraint, k_constraint)
      self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
      self.assertEqual(layer.bias.constraint, b_constraint)

  def test_with_masking_layer_LSTM(self):
    """Fits a model whose LSTM consumes the output of a `Masking` layer."""
    layer_class = keras.layers.LSTM
    with self.test_session():
      inputs = np.random.random((2, 3, 4))
      targets = np.abs(np.random.random((2, 3, 5)))
      # Normalise each target row so it is a valid categorical distribution.
      targets /= targets.sum(axis=-1, keepdims=True)
      model = keras.models.Sequential()
      model.add(keras.layers.Masking(input_shape=(3, 4)))
      model.add(layer_class(units=5, return_sequences=True, unroll=False))
      model.compile(loss='categorical_crossentropy', optimizer='adam')
      model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)

  def test_from_config_LSTM(self):
    """Checks get_config/from_config round-trips, stateless and stateful."""
    layer_class = keras.layers.LSTM
    for stateful in (False, True):
      l1 = layer_class(units=1, stateful=stateful)
      l2 = layer_class.from_config(l1.get_config())
      self.assertEqual(l1.get_config(), l2.get_config())

  def test_specify_initial_state_keras_tensor(self):
    """Trains with initial states supplied as Keras `Input` tensors."""
    num_states = 2
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      # Test with Keras tensor
      inputs = keras.Input((timesteps, embedding_dim))
      initial_state = [keras.Input((units,)) for _ in range(num_states)]
      layer = keras.layers.LSTM(units)
      output = layer(inputs, initial_state=initial_state)
      # Keras-tensor initial states should appear among the node's inputs.
      self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)

      model = keras.models.Model([inputs] + initial_state, output)
      model.compile(loss='categorical_crossentropy', optimizer='adam')

      inputs = np.random.random((num_samples, timesteps, embedding_dim))
      initial_state = [np.random.random((num_samples, units))
                       for _ in range(num_states)]
      targets = np.random.random((num_samples, units))
      model.train_on_batch([inputs] + initial_state, targets)

  def test_specify_initial_state_non_keras_tensor(self):
    """Trains with initial states supplied as backend variables."""
    num_states = 2
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      # Test with non-Keras tensor
      inputs = keras.Input((timesteps, embedding_dim))
      initial_state = [keras.backend.random_normal_variable(
          (num_samples, units), 0, 1)
                       for _ in range(num_states)]
      layer = keras.layers.LSTM(units)
      output = layer(inputs, initial_state=initial_state)

      model = keras.models.Model(inputs, output)
      model.compile(loss='categorical_crossentropy', optimizer='adam')

      inputs = np.random.random((num_samples, timesteps, embedding_dim))
      targets = np.random.random((num_samples, units))
      model.train_on_batch(inputs, targets)

  def test_reset_states_with_values(self):
    """Checks reset_states() with no values, explicit values, and bad input."""
    num_states = 2
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      layer = keras.layers.LSTM(units, stateful=True)
      layer.build((num_samples, timesteps, embedding_dim))
      layer.reset_states()
      self.assertEqual(len(layer.states), num_states)
      self.assertIsNotNone(layer.states[0])
      # Without values, every state should be reset to zeros.
      for state in layer.states:
        self.assertAllClose(
            keras.backend.eval(state),
            np.zeros(keras.backend.int_shape(state)),
            atol=1e-4)

      # With one array per state, each state should take that value.
      values = [np.ones(keras.backend.int_shape(state))
                for state in layer.states]
      layer.reset_states(values)
      for state in layer.states:
        self.assertAllClose(
            keras.backend.eval(state),
            np.ones(keras.backend.int_shape(state)),
            atol=1e-4)

      # Test with invalid data: more values than there are states.
      with self.assertRaises(ValueError):
        layer.reset_states([1] * (len(layer.states) + 1))

  def test_specify_state_with_masking(self):
    """Trains with Keras-tensor initial states and a masked input sequence."""
    num_states = 2
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      inputs = keras.Input((timesteps, embedding_dim))
      # Feed the *masked* tensor to the LSTM so the mask actually reaches it.
      masked_inputs = keras.layers.Masking()(inputs)
      initial_state = [keras.Input((units,)) for _ in range(num_states)]
      output = keras.layers.LSTM(units)(masked_inputs,
                                        initial_state=initial_state)

      model = keras.models.Model([inputs] + initial_state, output)
      model.compile(loss='categorical_crossentropy', optimizer='adam')

      inputs = np.random.random((num_samples, timesteps, embedding_dim))
      initial_state = [np.random.random((num_samples, units))
                       for _ in range(num_states)]
      targets = np.random.random((num_samples, units))
      model.train_on_batch([inputs] + initial_state, targets)

  def test_return_state(self):
    """Checks the returned state matches the stateful layer's stored state."""
    num_states = 2
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
      layer = keras.layers.LSTM(units, return_state=True, stateful=True)
      outputs = layer(inputs)
      # With return_state=True, outputs[0] is the output; the rest are states.
      state = outputs[1:]
      self.assertEqual(len(state), num_states)
      model = keras.models.Model(inputs, state[0])

      inputs = np.random.random((num_samples, timesteps, embedding_dim))
      state = model.predict(inputs)
      self.assertAllClose(keras.backend.eval(layer.states[0]), state, atol=1e-4)

  def test_state_reuse(self):
    """Feeds one LSTM's returned states as another LSTM's initial state."""
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2

    with self.test_session():
      inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
      layer = keras.layers.LSTM(units, return_state=True, return_sequences=True)
      outputs = layer(inputs)
      output, state = outputs[0], outputs[1:]
      output = keras.layers.LSTM(units)(output, initial_state=state)
      model = keras.models.Model(inputs, output)

      inputs = np.random.random((num_samples, timesteps, embedding_dim))
      outputs = model.predict(inputs)
      # The second LSTM returns only its last output: (batch, units).
      self.assertEqual(outputs.shape, (num_samples, units))

  def test_initial_states_as_other_inputs(self):
    """Passes initial states as extra elements of the layer's input list."""
    timesteps = 3
    embedding_dim = 4
    units = 3
    num_samples = 2
    num_states = 2
    layer_class = keras.layers.LSTM

    with self.test_session():
      # Test with Keras tensor
      main_inputs = keras.Input((timesteps, embedding_dim))
      initial_state = [keras.Input((units,)) for _ in range(num_states)]
      inputs = [main_inputs] + initial_state

      layer = layer_class(units)
      output = layer(inputs)
      self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)

      model = keras.models.Model(inputs, output)
      model.compile(loss='categorical_crossentropy', optimizer='adam')

      main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
      initial_state = [np.random.random((num_samples, units))
                       for _ in range(num_states)]
      targets = np.random.random((num_samples, units))
      model.train_on_batch([main_inputs] + initial_state, targets)


if __name__ == '__main__':
  # Run the test cases in this module through the imported `test` module's
  # `main()` (from tensorflow.python.platform).
  test.main()