aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py
blob: 718da1367ce69285f37269c5631fa0be2b050c97 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SavedModelEstimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

from tensorflow.contrib.estimator.python.estimator import export as contrib_export
from tensorflow.contrib.estimator.python.estimator import saved_model_estimator
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training


def dummy_input_fn():
  """Builds an endlessly repeating dataset holding one (features, labels) batch.

  Returns:
    A `Dataset` whose single element is repeated forever: features
    `{'x': int64 [[1], [-2]]}` and labels `float32 [[4], [-3]]`.
  """
  features = {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)}
  labels = constant_op.constant([[4], [-3]], dtype=dtypes.float32)
  single_batch = dataset_ops.Dataset.from_tensors((features, labels))
  return single_batch.repeat()


def dummy_input_fn_features_only():
  """Builds an endlessly repeating, label-free dataset for prediction.

  Returns:
    A `Dataset` repeating the single features dict `{'x': int64 [[5], [6]]}`.
  """
  x_values = constant_op.constant([[5], [6]], dtype=dtypes.int64)
  only_features = dataset_ops.Dataset.from_tensors({'x': x_values})
  return only_features.repeat()


def dummy_supervised_receiver_fn():
  """Creates a raw supervised input receiver fn for TRAIN/EVAL exports.

  Despite the name, calling this returns the receiver fn produced by
  `export.build_raw_supervised_input_receiver_fn`; callers pass that result
  into the export map.

  Returns:
    The receiver fn built from an int64 `(2, 1)` feature placeholder keyed by
    'x' (op name 'feature_x') and a float32 `[2, 1]` label placeholder
    (op name 'truth').
  """
  x_placeholder = array_ops.placeholder(
      dtype=dtypes.int64, shape=(2, 1), name='feature_x')
  truth_placeholder = array_ops.placeholder(
      dtype=dtypes.float32, shape=[2, 1], name='truth')
  return export.build_raw_supervised_input_receiver_fn(
      {'x': x_placeholder}, truth_placeholder)


def dummy_serving_receiver_fn():
  """Creates a raw serving input receiver fn for PREDICT exports.

  Returns:
    The receiver fn built by `export.build_raw_serving_input_receiver_fn`
    from a single int64 `(2, 1)` placeholder keyed by 'x' (op name
    'feature_x').
  """
  x_placeholder = array_ops.placeholder(
      dtype=dtypes.int64, shape=(2, 1), name='feature_x')
  return export.build_raw_serving_input_receiver_fn({'x': x_placeholder})


def model_fn_diff_modes(features, labels, mode):
  """Model fn whose loss and predictions are a distinct constant per mode.

  TRAIN: loss 105, predictions [501], and the train op adds 1 to the global
  step and 3 to `some_var`. EVAL: loss 106, predictions [502]. Any other mode
  (e.g. PREDICT): loss 107, predictions [503]. In every mode the 'abs_err'
  metric is `metrics_lib.mean_absolute_error(0, predictions)`.

  Args:
    features: Ignored.
    labels: Ignored.
    mode: A `model_fn_lib.ModeKeys` value selecting the branch.

  Returns:
    A `model_fn_lib.EstimatorSpec` for `mode`.
  """
  _, _ = features, labels
  v = variables.Variable(21, name='some_var')
  # Only the TRAIN branch provides a train op.
  train_op = None
  # Every branch below assigns `loss`, so no default value is needed.
  if mode == model_fn_lib.ModeKeys.TRAIN:
    loss = constant_op.constant(105)
    predictions = constant_op.constant([501])
    train_op = control_flow_ops.group(
        state_ops.assign_add(training.get_global_step(), 1),
        state_ops.assign_add(v, 3))
  elif mode == model_fn_lib.ModeKeys.EVAL:
    loss = constant_op.constant(106)
    predictions = constant_op.constant([502])
  else:
    loss = constant_op.constant(107)
    predictions = constant_op.constant([503])
  return model_fn_lib.EstimatorSpec(
      mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          'abs_err': metrics_lib.mean_absolute_error(
              constant_op.constant(0), predictions)},
      predictions=predictions)


class SavedModelEstimatorTest(test.TestCase):
  """Tests for `saved_model_estimator.SavedModelEstimator`."""

  def setUp(self):
    # Chain to the base TestCase so its own per-test setup still runs; the
    # original override silently skipped it.
    super(SavedModelEstimatorTest, self).setUp()
    # Directories created by `_get_tmp_dir`; removed in `tearDown`.
    self.tmpdirs = []

  def tearDown(self):
    for tmpdir in self.tmpdirs:
      # gfile.DeleteRecursively fails in the windows cmake test, so use shutil.
      shutil.rmtree(tmpdir, ignore_errors=True)
    self.tmpdirs = []
    # Chain to the base TestCase so its own per-test teardown still runs.
    super(SavedModelEstimatorTest, self).tearDown()

  def _get_tmp_dir(self):
    """Creates a fresh temporary directory that `tearDown` will delete."""
    tmpdir = tempfile.mkdtemp()
    self.tmpdirs.append(tmpdir)
    return tmpdir

  def _export_estimator(self, train=True, evaluate=True, predict=True,
                        model_fn=model_fn_diff_modes):
    """Trains an Estimator for 10 steps and exports SavedModels for it.

    Args:
      train: Whether to export a TRAIN-mode graph.
      evaluate: Whether to export an EVAL-mode graph.
      predict: Whether to export a PREDICT-mode graph.
      model_fn: Model fn used to build the Estimator.

    Returns:
      The export directory returned by
      `contrib_export.export_all_saved_models`.
    """
    est = estimator.Estimator(model_fn, self._get_tmp_dir())
    est.train(input_fn=dummy_input_fn, steps=10)

    input_receiver_fn_map = {}
    if train:
      input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = (
          dummy_supervised_receiver_fn())
    if evaluate:
      input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = (
          dummy_supervised_receiver_fn())
    if predict:
      input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = (
          dummy_serving_receiver_fn())

    export_base_path = self._get_tmp_dir()
    export_dir = contrib_export.export_all_saved_models(
        est, export_base_path, input_receiver_fn_map)
    return export_dir

  def test_load_all_modes(self):
    """Train, evaluate and predict all work from a fully exported model."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())
    sme.train(input_fn=dummy_input_fn, steps=1)
    sme.train(input_fn=dummy_input_fn, steps=2)
    # 10 export-time steps + 3 steps here. `some_var` starts at 21 and the
    # TRAIN op adds 3 per step: 21 + 13 * 3 = 60.
    self.assertEqual(13, sme.get_variable_value('global_step'))
    self.assertEqual(60, sme.get_variable_value('some_var'))

    eval_results = sme.evaluate(dummy_input_fn, steps=5)

    self.assertEqual(13, eval_results['global_step'])
    self.assertEqual(106, eval_results['loss'])
    self.assertEqual(502, eval_results['metrics/abs_err'])

    predictions = next(sme.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

  def test_load_all_modes_no_train(self):
    """Ensure that all functions can be used without requiring a ckpt."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())
    eval_results = sme.evaluate(dummy_input_fn, steps=5)
    self.assertEqual(10, eval_results['global_step'])
    self.assertEqual(106, eval_results['loss'])
    self.assertEqual(502, eval_results['metrics/abs_err'])

    predictions = next(sme.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

  def test_partial_exported_estimator(self):
    """Modes absent from the export raise RuntimeError; exported ones work."""
    sme1 = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(train=False, predict=False), self._get_tmp_dir())
    sme1.evaluate(dummy_input_fn, steps=5)
    with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):
      sme1.train(input_fn=dummy_input_fn, steps=1)
    with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):
      next(sme1.predict(dummy_input_fn_features_only))

    sme2 = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(evaluate=False), self._get_tmp_dir())
    sme2.train(input_fn=dummy_input_fn, steps=1)
    next(sme2.predict(dummy_input_fn_features_only))
    with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):
      sme2.evaluate(dummy_input_fn, steps=5)

  def test_with_incorrect_input(self):
    """Inputs not matching the exported receiver shape/dtype are rejected."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())

    # Rank-1 tensors, whereas the receivers use (2, 1) placeholders.
    def bad_shape_input_fn():
      return dataset_ops.Dataset.from_tensors((
          {'x': constant_op.constant([1, 2], dtype=dtypes.int64)},
          constant_op.constant([1, 2], dtype=dtypes.float32)))

    with self.assertRaisesRegexp(ValueError, 'Expected shape'):
      sme.train(bad_shape_input_fn, steps=1)

    # int32 features / int64 labels, whereas the receivers use int64 / float32.
    def bad_dtype_input_fn():
      return dataset_ops.Dataset.from_tensors((
          {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)},
          constant_op.constant([[1], [1]], dtype=dtypes.int64)))

    with self.assertRaisesRegexp(ValueError, 'Expected dtype'):
      sme.train(bad_dtype_input_fn, steps=1)

  def test_input_fn_with_global_step(self):
    """An input_fn that creates a global step must be rejected."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())

    def bad_input_fn():
      training.get_or_create_global_step()
      return dataset_ops.Dataset.from_tensors((
          {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)},
          constant_op.constant([[1], [1]], dtype=dtypes.float32)))

    with self.assertRaisesRegexp(RuntimeError,
                                 'Graph must not contain a global step tensor'):
      sme.train(bad_input_fn, steps=1)

  def test_re_export_saved_model_serving_only(self):
    """A trained SavedModelEstimator can be re-exported for serving."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())
    sme.train(dummy_input_fn, steps=3)
    self.assertEqual(13, sme.get_variable_value('global_step'))
    self.assertEqual(60, sme.get_variable_value('some_var'))

    predictions = next(sme.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

    # Export SavedModel, and test that the variable and prediction values are
    # the same.
    sme_export_dir = sme.export_savedmodel(
        self._get_tmp_dir(), dummy_serving_receiver_fn())

    sme2 = saved_model_estimator.SavedModelEstimator(
        sme_export_dir, self._get_tmp_dir())
    # NOTE(review): these assertions read from `sme`, not `sme2`, so they
    # repeat the checks above. `sme2` has not been trained here, so reading
    # its variables may need a checkpoint it does not have — confirm intent.
    self.assertEqual(60, sme.get_variable_value('some_var'))
    self.assertEqual(13, sme.get_variable_value('global_step'))

    predictions = next(sme2.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

  def test_re_export_saved_model(self):
    """A trained SavedModelEstimator can be re-exported for all modes."""
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(), self._get_tmp_dir())
    self.assertDictEqual(
        {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10},
        sme.evaluate(dummy_input_fn, steps=1))

    sme.train(dummy_input_fn, steps=3)
    self.assertDictEqual(
        {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
        sme.evaluate(dummy_input_fn, steps=1))
    self.assertEqual(60, sme.get_variable_value('some_var'))

    predictions = next(sme.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

    # Export SavedModel for all modes
    input_receiver_fn_map = {
        model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(),
        model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(),
        model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()}
    sme_export_dir = contrib_export.export_all_saved_models(
        sme, self._get_tmp_dir(), input_receiver_fn_map)

    sme2 = saved_model_estimator.SavedModelEstimator(
        sme_export_dir, self._get_tmp_dir())
    # NOTE(review): the evaluate/train checks below still use `sme`; only the
    # final predict exercises the re-exported `sme2`. Confirm whether
    # exercising `sme2`'s EVAL/TRAIN graphs was intended.
    self.assertDictEqual(
        {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
        sme.evaluate(dummy_input_fn, steps=1))
    self.assertEqual(60, sme.get_variable_value('some_var'))

    sme.train(dummy_input_fn, steps=7)
    self.assertEqual(20, sme.get_variable_value('global_step'))

    predictions = next(sme2.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'output': 503}, predictions)

  def test_load_saved_model_from_serving_only(self):
    """A SavedModel exported via `export_savedmodel` supports prediction."""
    def model_fn(features, labels, mode):
      _, _ = features, labels
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=constant_op.constant([103]),
          train_op=state_ops.assign_add(training.get_global_step(), 1),
          predictions=constant_op.constant([502]),
          export_outputs={'test': export_output.ClassificationOutput(
              constant_op.constant([[32.]]))})

    est = estimator.Estimator(model_fn, self._get_tmp_dir())
    est.train(input_fn=dummy_input_fn, steps=10)

    def serving_input_receiver_fn():
      return export.ServingInputReceiver(
          {'test-features': constant_op.constant([[1], [1]])},
          array_ops.placeholder(dtype=dtypes.string))

    export_dir = est.export_savedmodel(
        self._get_tmp_dir(), serving_input_receiver_fn)

    sme = saved_model_estimator.SavedModelEstimator(
        export_dir, self._get_tmp_dir())

    def input_fn():
      return {'inputs': constant_op.constant('someinputstr')}

    prediction = next(sme.predict(input_fn))
    self.assertDictEqual({'scores': 32}, prediction)

  def test_with_local_init_op(self):
    """The exported scaffold's local_init_op is honored by the loaded model."""
    def model_fn(features, labels, mode):
      _, _ = features, labels
      v = variables.Variable(21, name='some_var')
      scaffold = monitored_session.Scaffold(
          local_init_op=state_ops.assign_add(v, -3).op
      )
      return model_fn_lib.EstimatorSpec(
          mode,
          scaffold=scaffold,
          train_op=state_ops.assign_add(training.get_global_step(), 1),
          loss=array_ops.identity(v))
    export_dir = self._export_estimator(predict=False, model_fn=model_fn)
    sme = saved_model_estimator.SavedModelEstimator(
        export_dir, self._get_tmp_dir())

    eval_results1 = sme.evaluate(dummy_input_fn, steps=2)
    self.assertEqual(15, eval_results1['loss'])

    sme.train(dummy_input_fn, steps=1)
    self.assertEqual(15, sme.get_variable_value('some_var'))

    eval_results2 = sme.evaluate(dummy_input_fn, steps=5)
    self.assertEqual(12, eval_results2['loss'])

  def test_with_working_input_fn(self):
    """Features and labels from the input_fn reach the loaded model."""
    def model_fn(features, labels, mode):
      loss = None
      if labels is not None:
        loss = labels[0][0] + labels[1][0]
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=loss,
          train_op=state_ops.assign_add(training.get_global_step(), 1),
          predictions={'features_0': array_ops.identity([features['x'][0][0]]),
                       'features_1': array_ops.identity([features['x'][1][0]])})

    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(model_fn=model_fn), self._get_tmp_dir())
    eval_results = sme.evaluate(dummy_input_fn, steps=1)
    # dummy_input_fn labels are [[4], [-3]], so the loss is 4 + (-3) = 1.
    self.assertEqual(1, eval_results['loss'])

    predictions = next(sme.predict(dummy_input_fn_features_only))
    self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)

  def test_control_dependency(self):
    # Control dependencies are saved with "^" appended to the start of the input
    # name. The input map must include control dependencies as well.
    def model_fn(features, labels, mode):
      _ = labels
      with ops.control_dependencies([features['x']]):
        loss = features['x'][1][0]
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=loss,
          train_op=state_ops.assign_add(training.get_global_step(), 1))
    sme = saved_model_estimator.SavedModelEstimator(
        self._export_estimator(train=False, predict=False, model_fn=model_fn),
        self._get_tmp_dir())
    sme.evaluate(dummy_input_fn, steps=1)  # Should run without error


if __name__ == '__main__':
  # Run the tests in this module when executed as a script.
  test.main()