aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
blob: 12dd72a95b24db968695699b32028899ac9d6a94 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Tests for saving/loading function for keras Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from absl.testing import parameterized
import numpy as np

from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as training_module


class TestModelSavingandLoading(test.TestCase):
  """Round-trip tests for `save_keras_model` and `load_keras_model`.

  Most tests build a small Keras model, record its predictions, export it with
  `keras_saved_model.save_keras_model`, reload it with
  `keras_saved_model.load_keras_model` and check that the reloaded model
  predicts the same values (within `atol=1e-05`).
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns a path called `dirname` inside this test's temp directory.

    The temp directory is registered for removal during test cleanup. The
    returned path itself is not created here.

    Args:
      dirname: Name of the final path component.

    Returns:
      The joined path `<temp_dir>/<dirname>`.
    """
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def test_saving_sequential_model(self):
    """A compiled, trained Sequential model predicts the same after reload."""
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(lr=0.0001),
          metrics=[keras.metrics.categorical_accuracy],
          sample_weight_mode='temporal')
      x = np.random.random((1, 3))
      y = np.random.random((1, 3, 3))
      model.train_on_batch(x, y)

      ref_y = model.predict(x)

      temp_saved_model = self._save_model_dir()
      output_path = keras_saved_model.save_keras_model(model, temp_saved_model)

      loaded_model = keras_saved_model.load_keras_model(output_path)
      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

  @test_util.run_in_graph_and_eager_modes
  def test_saving_sequential_model_without_compile(self):
    """An uncompiled Sequential model predicts the same after reload."""
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))

      x = np.random.random((1, 3))
      ref_y = model.predict(x)

      temp_saved_model = self._save_model_dir()
      output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
      loaded_model = keras_saved_model.load_keras_model(output_path)

      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

  def test_saving_functional_model(self):
    """A compiled, trained functional model predicts the same after reload."""
    with self.cached_session():
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(x)

      model = keras.models.Model(inputs, output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(lr=0.0001),
          metrics=[keras.metrics.categorical_accuracy])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      ref_y = model.predict(x)

      temp_saved_model = self._save_model_dir()
      output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
      loaded_model = keras_saved_model.load_keras_model(output_path)

      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

  @test_util.run_in_graph_and_eager_modes
  def test_saving_functional_model_without_compile(self):
    """An uncompiled functional model predicts the same after reload."""
    with self.cached_session():
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(x)

      model = keras.models.Model(inputs, output)

      # The model is never trained here, so only an input batch is needed.
      x = np.random.random((1, 3))

      ref_y = model.predict(x)

      temp_saved_model = self._save_model_dir()
      output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
      loaded_model = keras_saved_model.load_keras_model(output_path)

      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

  @test_util.run_in_graph_and_eager_modes
  def test_saving_with_tf_optimizer(self):
    """Round-trips a model compiled with a TensorFlow (non-Keras) optimizer.

    After reloading, the model is compiled again with a fresh
    `RMSPropOptimizer`, then checked for matching predictions, a matching
    loss on one further training step, and matching predictions after a
    second save/load cycle.
    """
    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc'])

      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)
      model.train_on_batch(x, y)

      ref_y = model.predict(x)

      temp_saved_model = self._save_model_dir()
      output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
      loaded_model = keras_saved_model.load_keras_model(output_path)
      loaded_model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc'])
      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

      # test that new updates are the same with both models
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))

      ref_loss = model.train_on_batch(x, y)
      loss = loaded_model.train_on_batch(x, y)
      self.assertAllClose(ref_loss, loss, atol=1e-05)

      ref_y = model.predict(x)
      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

      # test saving/loading again
      temp_saved_model2 = self._save_model_dir('saved_model_2')
      output_path2 = keras_saved_model.save_keras_model(
          loaded_model, temp_saved_model2)
      loaded_model = keras_saved_model.load_keras_model(output_path2)
      y = loaded_model.predict(x)
      self.assertAllClose(ref_y, y, atol=1e-05)

  def test_saving_subclassed_model_raise_error(self):
    """Saving a subclassed `Model` raises `NotImplementedError`."""
    # For now, saving subclassed model should raise an error. It should be
    # avoided later with loading from SavedModel.pb.

    class SubclassedModel(training.Model):
      """Minimal subclassed model chaining two Dense layers."""

      def __init__(self):
        super(SubclassedModel, self).__init__()
        self.layer1 = keras.layers.Dense(3)
        self.layer2 = keras.layers.Dense(1)

      def call(self, inp):
        return self.layer2(self.layer1(inp))

    model = SubclassedModel()

    temp_saved_model = self._save_model_dir()
    with self.assertRaises(NotImplementedError):
      keras_saved_model.save_keras_model(model, temp_saved_model)


class LayerWithLearningPhase(keras.engine.base_layer.Layer):
  """Layer whose output depends on the Keras learning phase.

  When the learning phase evaluates as true the output is the input multiplied
  by zero; otherwise the input is passed through `array_ops.identity`.
  """

  def call(self, x):
    """Applies the phase-dependent branch to `x` and returns the result."""

    def _zeroed():
      return x * 0

    def _passthrough():
      return array_ops.identity(x)

    result = tf_utils.smart_cond(
        keras.backend.learning_phase(), _zeroed, _passthrough)
    if context.executing_eagerly():
      return result
    # Outside eager execution, tag the tensor as using the learning phase.
    result._uses_learning_phase = True  # pylint: disable=protected-access
    return result

  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged."""
    return input_shape


def functional_model(uses_learning_phase):
  """Builds a functional-API model: Input(3) -> Dense(2) -> Dense(3).

  Args:
    uses_learning_phase: If true, a `LayerWithLearningPhase` is appended after
      the last Dense layer.

  Returns:
    The uncompiled `keras.models.Model`.
  """
  model_input = keras.layers.Input(shape=(3,))
  hidden = keras.layers.Dense(2)(model_input)
  model_output = keras.layers.Dense(3)(hidden)
  if uses_learning_phase:
    model_output = LayerWithLearningPhase()(model_output)
  return keras.models.Model(model_input, model_output)


def sequential_model(uses_learning_phase):
  """Builds a Sequential model: Dense(2, input_shape=(3,)) -> Dense(3).

  Args:
    uses_learning_phase: If true, a `LayerWithLearningPhase` is appended after
      the last Dense layer.

  Returns:
    The uncompiled `keras.models.Sequential` model.
  """
  model = keras.models.Sequential()
  stack = [
      keras.layers.Dense(2, input_shape=(3,)),
      keras.layers.Dense(3),
  ]
  if uses_learning_phase:
    stack.append(LayerWithLearningPhase())
  for layer in stack:
    model.add(layer)
  return model


def load_model(sess, path, mode):
  """Loads a SavedModel into `sess` and resolves its signature tensors.

  Args:
    sess: Session to load the SavedModel into; tensors are looked up by name
      in `sess.graph`.
    path: Path passed to `loader_impl.load`.
    mode: Key into `model_fn_lib.EXPORT_TAG_MAP`. For
      `model_fn_lib.ModeKeys.PREDICT` the default serving signature key is
      used; for any other mode the mode itself is the signature key.

  Returns:
    A tuple `(inputs, outputs)` of dicts mapping each signature input/output
    key to the corresponding tensor in `sess.graph`.
  """
  export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
  if mode == model_fn_lib.ModeKeys.PREDICT:
    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  else:
    signature_key = mode

  meta_graph_def = loader_impl.load(sess, export_tags, path)
  signature = meta_graph_def.signature_def[signature_key]

  def _resolve(tensor_infos):
    # Map every signature entry to the live tensor of the same name.
    return {
        key: sess.graph.get_tensor_by_name(info.name)
        for key, info in tensor_infos.items()}

  return _resolve(signature.inputs), _resolve(signature.outputs)


@test_util.run_all_in_graph_and_eager_modes
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
  """Tests the SavedModel graphs written by `save_keras_model`.

  The exported model is loaded through `load_model` (the low-level SavedModel
  loader) rather than `load_keras_model`, so the predict/eval/train signatures
  are checked directly.
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns a path called `dirname` inside this test's temp directory.

    The temp directory is registered for removal during test cleanup. The
    returned path itself is not created here.

    Args:
      dirname: Name of the final path component.

    Returns:
      The joined path `<temp_dir>/<dirname>`.
    """
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  @parameterized.parameters(
      (functional_model, True, training_module.AdadeltaOptimizer(), True),
      (functional_model, True, training_module.AdadeltaOptimizer(), False),
      (functional_model, False, None, False),
      (sequential_model, True, training_module.AdadeltaOptimizer(), True),
      (sequential_model, True, training_module.AdadeltaOptimizer(), False),
      (sequential_model, False, None, False))
  def testSaveAndLoadSavedModelExport(
      self, model_builder, uses_learning_phase, optimizer, train_before_export):
    """Exports a model and checks the loaded predict/eval/train graphs.

    The predict graph is always checked against the original predictions. The
    eval and train graphs are only checked when the model was compiled, i.e.
    when `optimizer` is not None.

    Args:
      model_builder: Callable taking `uses_learning_phase` and returning an
        uncompiled Keras model.
      uses_learning_phase: Forwarded to `model_builder`. When true, the train
        graph's predictions are expected to be all zeros; otherwise they are
        expected not to be.
      optimizer: Optimizer used to compile the model, or None to leave the
        model uncompiled.
      train_before_export: Whether to run one training step before exporting.
        The loaded global step is expected to equal `int(train_before_export)`.
    """
    saved_model_path = self._save_model_dir()
    with self.test_session(graph=ops.Graph()):
      input_arr = np.random.random((1, 3))
      target_arr = np.random.random((1, 3))

      model = model_builder(uses_learning_phase)
      if optimizer is not None:
        model.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['mae'])
        if train_before_export:
          model.train_on_batch(input_arr, target_arr)

        ref_loss, ref_mae = model.evaluate(input_arr, target_arr)

      ref_predict = model.predict(input_arr)

      # Export SavedModel
      output_path = keras_saved_model.save_keras_model(model, saved_model_path)

    input_name = model.input_names[0]
    output_name = model.output_names[0]
    target_name = output_name + '_target'

    # Load predict graph, and test predictions
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs = load_model(sess, output_path,
                                   model_fn_lib.ModeKeys.PREDICT)

      predictions = sess.run(outputs[output_name],
                             {inputs[input_name]: input_arr})
      self.assertAllClose(ref_predict, predictions, atol=1e-05)

    # Same condition as the compile step above: ref_loss and ref_mae only exist
    # when the model was compiled.
    if optimizer is not None:
      # Load eval graph, and test predictions, loss and metric values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs = load_model(sess, output_path,
                                     model_fn_lib.ModeKeys.EVAL)

        eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
                                          inputs[target_name]: target_arr})

        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
        self.assertAllClose(
            ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05)
        self.assertAllClose(
            ref_predict, eval_results['predictions/' + output_name], atol=1e-05)

      # Load train graph, and check for the train op, and prediction values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs = load_model(sess, output_path,
                                     model_fn_lib.ModeKeys.TRAIN)
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertIn('loss', outputs)
        self.assertIn('metrics/mae/update_op', outputs)
        self.assertIn('metrics/mae/value', outputs)
        self.assertIn('predictions/' + output_name, outputs)

        # Train for a step
        train_op = ops.get_collection(constants.TRAIN_OP_KEY)
        train_outputs, _ = sess.run(
            [outputs, train_op], {inputs[input_name]: input_arr,
                                  inputs[target_name]: target_arr})
        self.assertEqual(int(train_before_export) + 1,
                         sess.run(training_module.get_global_step()))

        if uses_learning_phase:
          self.assertAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
        else:
          self.assertNotAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)

  def testSaveAndLoadSavedModelWithCustomObject(self):
    """Exports a model using a custom activation and checks its predictions."""
    saved_model_path = self._save_model_dir()
    # The session is only needed as the default session/graph while building
    # and exporting the model, so it is not bound to a name.
    with session.Session(graph=ops.Graph()):
      def relu6(x):
        return keras.backend.relu(x, max_value=6)
      inputs = keras.layers.Input(shape=(1,))
      outputs = keras.layers.Activation(relu6)(inputs)
      model = keras.models.Model(inputs, outputs)
      output_path = keras_saved_model.save_keras_model(
          model, saved_model_path, custom_objects={'relu6': relu6})
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs = load_model(sess, output_path,
                                   model_fn_lib.ModeKeys.PREDICT)
      input_name = model.input_names[0]
      output_name = model.output_names[0]
      predictions = sess.run(
          outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]})
      self.assertAllEqual([[6], [0], [4]], predictions)

  def testAssertModelCloneSameObjectsIgnoreOptimizer(self):
    """The same-objects check passes when only the optimizers differ."""
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    model_graph = ops.Graph()
    clone_graph = ops.Graph()

    # Create two models with the same layers but different optimizers.
    with session.Session(graph=model_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      model = keras.models.Model(inputs, x)

      model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
      model.train_on_batch(input_arr, target_arr)

    with session.Session(graph=clone_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      clone = keras.models.Model(inputs, x)
      clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
      clone.train_on_batch(input_arr, target_arr)

    keras_saved_model._assert_same_non_optimizer_objects(
        model, model_graph, clone, clone_graph)

  def testAssertModelCloneSameObjectsThrowError(self):
    """The same-objects check raises when the clone has different layers."""
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    model_graph = ops.Graph()
    clone_graph = ops.Graph()

    # Create two models with different layers (the clone has an extra Dense
    # layer) and different optimizers.
    with session.Session(graph=model_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      model = keras.models.Model(inputs, x)

      model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
      model.train_on_batch(input_arr, target_arr)

    with session.Session(graph=clone_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(4)(x)
      x = keras.layers.Dense(3)(x)
      clone = keras.models.Model(inputs, x)
      clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
      clone.train_on_batch(input_arr, target_arr)

    with self.assertRaisesRegexp(
        errors.InternalError, 'Model and clone must use the same variables.'):
      keras_saved_model._assert_same_non_optimizer_objects(
          model, model_graph, clone, clone_graph)


# Run this module's test cases through the TensorFlow test runner when the
# file is executed as a script.
if __name__ == '__main__':
  test.main()