aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/training.py
blob: e38b765da52a7b6957a4fb8a02087c5d1fd5a781 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions related to train_and_evaluate."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
import time

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# Upper bound, in seconds, on the start delay applied to a training worker
# (see `_TrainingExecutor._start_distributed_training`).
_MAX_DELAY_SECS = 60
# Start delay unit, in seconds; a worker waits (task_id + 1) times this value,
# capped by `_MAX_DELAY_SECS`.
_DELAY_SECS_PER_WORKER = 5
# Name of the environment variable whose JSON content `_is_google_env` parses.
_TF_CONFIG_ENV = 'TF_CONFIG'
# Key looked up in the parsed TF_CONFIG, and the value it is compared against,
# when detecting the google environment.
_ENVIRONMENT_KEY = 'environment'
_ENVIRONMENT_GOOGLE_VALUE = 'google'
# Task types for which `_start_std_server` may skip starting a server when the
# cluster consists of a single node.
_TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
                 run_config_lib.TaskType.WORKER)


def _validate_input_fn(input_fn):
  """Validates the `input_fn`."""
  if not callable(input_fn):
    raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))


def _validate_hooks(hooks):
  """Validates the `hooks`."""
  hooks = tuple(hooks or [])
  for hook in hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError(
          'All hooks must be `SessionRunHook` instances, given: {}'.format(
              hook))
  return hooks


def _validate_exporters(exporters):
  """Validates `exporters` and returns them as a tuple."""
  if not exporters:
    return ()

  if isinstance(exporters, exporter_lib.Exporter):
    exporters = [exporters]

  unique_names = []  # `Exporter`s should have unique names.
  try:
    for exporter in exporters:
      if not isinstance(exporter, exporter_lib.Exporter):
        # Error message will be printed out by the outer try/except.
        raise TypeError

      if not exporter.name:
        full_list_of_names = [e.name for e in exporters]
        raise ValueError('An Exporter cannot have a name that is `None` or'
                         ' empty. All exporter names:'
                         ' {}'.format(full_list_of_names))

      if not isinstance(exporter.name, six.string_types):
        raise ValueError('An Exporter must have a string name. Given: '
                         '{}'.format(type(exporter.name)))

      if exporter.name in unique_names:
        full_list_of_names = [e.name for e in exporters]
        raise ValueError(
            '`exporters` must have unique names. Such a name cannot be `None`.'
            ' All exporter names: {}'.format(full_list_of_names))
      unique_names.append(exporter.name)
  except TypeError:
    # Two possibilities:
    # - `exporters` is neither `Exporter` nor iterable.  Python has
    #   raised a `TypeError` when iterating over `exporters`.
    # - an `exporter` was None or not of type `Exporter`, so we raised a
    #   `TypeError`.
    raise TypeError('`exporters` must be an Exporter,'
                    ' an iterable of Exporter, or `None`,'
                    ' found %s.' % exporters)

  return tuple(exporters)


def _is_google_env():
  """Reports whether TF_CONFIG declares the google environment.

  A warning is logged when TF_CONFIG is unset or parses to an empty value.

  Returns:
    True if the parsed TF_CONFIG maps the environment key to the google value,
    False otherwise.
  """
  raw_config = os.environ.get(_TF_CONFIG_ENV)
  parsed_config = json.loads(raw_config if raw_config else '{}')
  if not parsed_config:
    logging.warn('TF_CONFIG should not be empty in distributed environment.')
  environment = parsed_config.get(_ENVIRONMENT_KEY)
  return environment == _ENVIRONMENT_GOOGLE_VALUE


@tf_export('estimator.TrainSpec')
class TrainSpec(
    collections.namedtuple('TrainSpec', ('input_fn', 'max_steps', 'hooks'))):
  """Training half of the configuration consumed by `train_and_evaluate`.

  A `TrainSpec` names the input pipeline used for training, bounds how long
  training runs, and optionally carries hooks that run during training.
  """

  def __new__(cls, input_fn, max_steps=None, hooks=None):
    """Validates the arguments and builds a `TrainSpec`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See @{$get_started/premade_estimators#create_input_functions} for more
        information. The function should construct and return one of
        the following:
          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
            tuple (features, labels) with same constraints as below.
          * A tuple (features, labels): Where features is a `Tensor` or a
            dictionary of string feature name to `Tensor` and labels is a
            `Tensor` or a dictionary of string label name to `Tensor`.

      max_steps: Int. Positive total number of steps to train the model for.
        `None` means train forever. The training `input_fn` is not expected to
        raise `OutOfRangeError` or `StopIteration`; see the stop condition
        section of `train_and_evaluate` for details.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        on all workers (including chief) during training.

    Returns:
      A validated `TrainSpec` object, with `hooks` stored as a tuple.

    Raises:
      ValueError: If `max_steps` is given and is not positive.
      TypeError: If `input_fn` is not callable or a hook has the wrong type.
    """
    # The checks run in field order so the first invalid argument is reported.
    _validate_input_fn(input_fn)

    if max_steps is not None:
      if max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    validated_hooks = _validate_hooks(hooks)

    return super(TrainSpec, cls).__new__(
        cls, input_fn=input_fn, max_steps=max_steps, hooks=validated_hooks)


@tf_export('estimator.EvalSpec')
class EvalSpec(
    collections.namedtuple('EvalSpec', (
        'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
        'throttle_secs'
    ))):
  """Evaluation half of the configuration consumed by `train_and_evaluate`.

  An `EvalSpec` describes both how the trained model is evaluated (metrics are
  computed to judge its performance) and how it is exported (the trained model
  is written out to external storage).
  """

  def __new__(cls,
              input_fn,
              steps=100,
              name=None,
              hooks=None,
              exporters=None,
              start_delay_secs=120,
              throttle_secs=600):
    """Validates the arguments and builds an `EvalSpec`.

    Args:
      input_fn: A function that constructs the input data for evaluation.
        See @{$get_started/premade_estimators#create_input_functions} for more
        information. The function should construct and return one of
        the following:
          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
            tuple (features, labels) with same constraints as below.
          * A tuple (features, labels): Where features is a `Tensor` or a
            dictionary of string feature name to `Tensor` and labels is a
            `Tensor` or a dictionary of string label name to `Tensor`.

      steps: Int. Positive number of steps to evaluate the model for. With
        `None`, evaluation runs until `input_fn` raises an end-of-input
        exception. See `Estimator.evaluate` for details.
      name: String. Name of the evaluation, for users who run several
        evaluations on different data sets. Metrics for different evaluations
        are saved in separate folders, and appear separately in tensorboard.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        during evaluation.
      exporters: Iterable of `Exporter`s, or a single one, or `None`.
        `exporters` will be invoked after each evaluation.
      start_delay_secs: Int. Start evaluating after waiting for this many
        seconds. Must be non-negative.
      throttle_secs: Int. Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. Evaluation does not occur if
        no new checkpoints are available, so this is a minimum. Must be
        non-negative.

    Returns:
      A validated `EvalSpec` object, with `hooks` and `exporters` stored as
      tuples.

    Raises:
      ValueError: If any of the input arguments is invalid.
      TypeError: If any of the arguments is not of the expected type.
    """
    # The checks run in field order so the first invalid argument is reported.
    _validate_input_fn(input_fn)

    if steps is not None:
      if steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))

    if not (name is None or isinstance(name, six.string_types)):
      raise TypeError('`name` must be string, given: {}'.format(name))

    validated_hooks = _validate_hooks(hooks)
    validated_exporters = _validate_exporters(exporters)

    if start_delay_secs < 0:
      raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format(
          start_delay_secs))

    if throttle_secs < 0:
      raise ValueError(
          'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs))

    return super(EvalSpec, cls).__new__(
        cls,
        input_fn=input_fn,
        steps=steps,
        name=name,
        hooks=validated_hooks,
        exporters=validated_exporters,
        start_delay_secs=start_delay_secs,
        throttle_secs=throttle_secs)


@tf_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
  """Train and evaluate the `estimator`.

  This utility trains, evaluates, and (optionally) exports the model using the
  given `estimator`. Everything related to training lives in `train_spec`
  (training `input_fn`, training max steps, etc.); everything related to
  evaluation and export lives in `eval_spec` (evaluation `input_fn`, steps,
  etc.).

  The behavior is consistent between local (non-distributed) and distributed
  configurations. Currently, the only supported distributed training
  configuration is between-graph replication.

  Overfitting: To avoid overfitting, set up the training `input_fn` so that it
  shuffles the training data properly. It is also recommended to train the
  model a little longer, say multiple epochs, before performing evaluation,
  because the input pipeline starts from scratch for each training. This
  matters particularly for local training and evaluation.

  Stop condition: To support both distributed and non-distributed
  configurations reliably, the only supported stop condition for model
  training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the
  model is trained forever. *Use with care* if the model stop condition is
  different. For example, assume the model is expected to be trained with one
  epoch of training data, and the training `input_fn` is configured to throw
  `OutOfRangeError` after going through one epoch, which stops the
  `Estimator.train`. In a distributed configuration with three training
  workers, each worker is likely to go through the whole epoch independently,
  so the model is trained with three epochs of training data instead of one.

  Example of local (non-distributed) training:
  ```python
  # Set up feature columns.
  categorical_feature_a = categorical_column_with_hash_bucket(...)
  categorical_feature_a_emb = embedding_column(
      categorical_column=categorical_feature_a, ...)
  ...  # other feature columns

  estimator = DNNClassifier(
      feature_columns=[categorical_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  # Or set up the model directory
  #   estimator = DNNClassifier(
  #       config=tf.estimator.RunConfig(
  #           model_dir='/my_model', save_summary_steps=100),
  #       feature_columns=[categorical_feature_a_emb, ...],
  #       hidden_units=[1024, 512, 256])

  # Input pipeline for train and evaluate.
  def train_input_fn():  # returns x, y
    # please shuffle the data.
    pass
  def eval_input_fn():  # returns x, y
    pass

  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  ```

  Example of distributed training:

  The code above can be used for distributed training without a change (please
  do make sure that the `RunConfig.model_dir` for all workers is set to the
  same directory, i.e., a shared file system all workers can read and write).
  The only extra work is to set the environment variable `TF_CONFIG` properly
  for each worker.

  Also see: https://www.tensorflow.org/deploy/distributed

  How to set an environment variable depends on the platform. For example, on
  Linux, it can be done as follows (`$` is the shell prompt):
  ```
  $ TF_CONFIG='<replace_with_real_content>' python train_model.py
  ```

  For the content in `TF_CONFIG`, assume that the training cluster spec looks
  like:
  ```
  cluster = {"chief": ["host0:2222"],
             "worker": ["host1:2222", "host2:2222", "host3:2222"],
             "ps": ["host4:2222", "host5:2222"]}
  ```

  Example of `TF_CONFIG` for chief training worker (must have one and only one):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "chief", "index": 0}
  }'
  ```
  Note that the chief worker also trains the model, like the other non-chief
  training workers (see next paragraph). In addition to model training, it
  manages some extra work, e.g., checkpoint saving and restoring, writing
  summaries, etc.

  Example of `TF_CONFIG` for non-chief training worker (optional, could be
  multiple):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 0}
  }'
  ```
  where the `task.index` should be set to 0, 1, and 2, in this example, for
  the respective non-chief training workers.

  Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "ps", "index": 0}
  }'
  ```
  where the `task.index` should be set to 0 and 1, in this example, for the
  respective parameter servers.

  Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is
  not part of the training cluster. There can be only one. It is used for
  model evaluation.
  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "evaluator", "index": 0}
  }'
  ```

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: A `EvalSpec` instance to specify the evaluation and export
      specification.

  Raises:
    ValueError: if environment variable `TF_CONFIG` is incorrectly set.
  """
  # Building the executor first lets it type-check all three arguments before
  # the task configuration is inspected.
  training_executor = _TrainingExecutor(
      estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)

  run_config = estimator.config
  is_evaluator = run_config.task_type == run_config_lib.TaskType.EVALUATOR
  if is_evaluator and run_config.task_id > 0:
    raise ValueError(
        'For distributed training, there can only be one `evaluator` task '
        '(with task id 0).  Given task id {}'.format(run_config.task_id))

  training_executor.run()


class _StopAtSecsHook(session_run_hook.SessionRunHook):
  """Requests a stop once a time budget, counted from `begin`, is used up."""

  def __init__(self, stop_after_secs):
    """Stores the budget in seconds; timing starts when `begin` is called."""
    self._start_time = None
    self._stop_after_secs = stop_after_secs

  def begin(self):
    """Records the current wall-clock time as the start of the budget."""
    self._start_time = time.time()

  def after_run(self, run_context, run_values):
    """Asks `run_context` to stop when the budget has elapsed."""
    del run_values  # Unused.
    elapsed_secs = time.time() - self._start_time
    if elapsed_secs >= self._stop_after_secs:
      run_context.request_stop()


class _TrainingExecutor(object):
  """The executor to run `Estimator` training and evaluation.

  This implementation supports both distributed and non-distributed (aka local)
  training and evaluation based on the setting in `tf.estimator.RunConfig`.
  """

  def __init__(self,
               estimator,
               train_spec,
               eval_spec,
               train_hooks=None,
               continuous_eval_listener=None):
    """Type-checks and stores the collaborators used by the `run_*` methods.

    Args:
      estimator: The `tf.estimator.Estimator` to train and evaluate.
      train_spec: The `TrainSpec` describing training.
      eval_spec: The `EvalSpec` describing evaluation and export.
      train_hooks: Optional iterable of `SessionRunHook`s; validated with
        `_validate_hooks` and stored as a tuple.
      continuous_eval_listener: Optional `_ContinuousEvalListener`. A default
        instance is created when the argument is falsy.

    Raises:
      TypeError: If any argument is not of the expected type.
    """
    if not isinstance(estimator, estimator_lib.Estimator):
      estimator_error = (
          '`estimator` must have type `tf.estimator.Estimator`. '
          'Got: {}'.format(type(estimator)))
      raise TypeError(estimator_error)
    self._estimator = estimator

    if not isinstance(train_spec, TrainSpec):
      train_spec_error = (
          '`train_spec` must have type `tf.estimator.TrainSpec`. '
          'Got: {}'.format(type(train_spec)))
      raise TypeError(train_spec_error)
    self._train_spec = train_spec

    if not isinstance(eval_spec, EvalSpec):
      eval_spec_error = (
          '`eval_spec` must have type `tf.estimator.EvalSpec`. '
          'Got: {}'.format(type(eval_spec)))
      raise TypeError(eval_spec_error)
    self._eval_spec = eval_spec

    self._train_hooks = _validate_hooks(train_hooks)

    if continuous_eval_listener:
      if not isinstance(continuous_eval_listener, _ContinuousEvalListener):
        raise TypeError('`continuous_eval_listener` must have type '
                        '`_ContinuousEvalListener`.')
      self._continuous_eval_listener = continuous_eval_listener
    else:
      # No (truthy) listener supplied: fall back to the default listener.
      self._continuous_eval_listener = _ContinuousEvalListener()

  @property
  def estimator(self):
    """The `Estimator` instance that was passed to the constructor."""
    return self._estimator

  def run(self):
    """Executes the run_foo for task type `foo`.

    `_TrainingExecutor` predefines the procedure for task type 'chief',
    'worker', 'ps', and 'evaluator'. For task type `foo`, the corresponding
    procedure is `run_foo'. This `run` method invoke the procedure base on the
    `RunConfig.task_type`.

    Raises:
      ValueError: if the estimator.config is mis-configured.
    """
    config = self._estimator.config

    if (not config.cluster_spec and
        config.task_type != run_config_lib.TaskType.EVALUATOR):
      logging.info('Running training and evaluation locally (non-distributed).')
      self.run_local()
      return

    # Distributed case.
    if not config.task_type:
      # TODO(xiejw): Improve the error message about how to set the TF_CONFIG
      # correctly.
      raise ValueError(
          '`estimator.config` must have task_type set. This usually means '
          'TF_CONFIG environment is not set correctly.')

    if config.task_type == 'local':
      raise ValueError(
          '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
          '`task` properties in TF_CONFIG absent triggers train and evaluate '
          '`Estimator` locally (non-distributed).')

    # For task type foo, call executor.run_foo.
    available_tasks = [
        x for x in dir(self)
        if x.startswith('run_') and x != 'run_local' and
        callable(getattr(self, x))
    ]
    task_to_run = 'run_' + config.task_type
    if task_to_run not in available_tasks:
      raise ValueError(
          'Task type {} is not supported. Supported task types are {}'.format(
              config.task_type, [x[len('run_'):] for x in available_tasks]))
    getattr(self, task_to_run)()

  def run_chief(self):
    """Runs task chief.

    Delegates to `_start_distributed_training` without saving listeners.

    Returns:
      The result of `_start_distributed_training`.
    """
    # TODO(xiejw): To allow execution framework to add train hooks.
    return self._start_distributed_training()

  def run_worker(self):
    """Runs task (training) worker.

    Delegates to `_start_distributed_training` without saving listeners.

    Returns:
      The result of `_start_distributed_training`.
    """
    # TODO(xiejw): To allow execution framework to add train hooks.
    return self._start_distributed_training()

  def run_master(self):
    """Runs task master.

    Trains in the distributed setting with a checkpoint-saving listener that
    evaluates and exports, then runs one more evaluation if the final export
    has not been triggered by the time training returns.
    """

    class NewCheckpointListener(
        basic_session_run_hooks.CheckpointSaverListener):
      """Evaluates and exports after a checkpoint save, unless throttled."""

      def __init__(self, evaluator, eval_throttle_secs):
        self._eval_throttle_secs = eval_throttle_secs
        self._evaluator = evaluator

      def begin(self):
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_secs=self._eval_throttle_secs)

      def after_save(self, session, global_step_value):
        del session  # unused; required by signature.

        if not self._timer.should_trigger_for_step(global_step_value):
          logging.info('Skip the current checkpoint eval due to throttle secs '
                       '({} secs).'.format(self._eval_throttle_secs))
          return

        self._timer.update_last_triggered_step(global_step_value)
        self._evaluator.evaluate_and_export()

    # Final export signal: for any eval result with global_step >= train
    # max_steps, the evaluator sends the final export signal. There is a small
    # chance that the Estimator.train stopping logic sees a different
    # global_step value (due to the global step race condition and the fact
    # that the saver sees a larger value for checkpoint saving), in which case
    # training does not end. When training does end, a new checkpoint is
    # generated, which triggers the listener again. So the final export could
    # be triggered twice.
    #
    # However, throttle_secs will skip the next intermediate checkpoint, so the
    # chance of a double final export is very small.
    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    # Register the listener so that evaluation and export are triggered
    # whenever the underlying `Estimator` saves a new checkpoint.
    checkpoint_listener = NewCheckpointListener(
        evaluator, self._eval_spec.throttle_secs)
    self._start_distributed_training(saving_listeners=[checkpoint_listener])

    if evaluator.is_final_export_triggered:
      return

    logging.info('Training has already ended. But the last eval is skipped '
                 'due to eval throttle_secs. Now evaluating the final '
                 'checkpoint.')
    evaluator.evaluate_and_export()

  def run_evaluator(self):
    """Runs task evaluator.

    Delegates to `_start_continuous_evaluation`.

    Returns:
      The result of `_start_continuous_evaluation`.
    """
    # TODO(xiejw): To allow execution framework to add continuous eval listener.
    return self._start_continuous_evaluation()

  def run_ps(self):
    """Runs task parameter server (in training cluster spec)."""
    config = self._estimator.config
    server = self._start_std_server(config)
    server.join()

  def run_local(self):
    """Runs training and evaluation locally (non-distributed)."""

    def _should_stop_local_train(global_step):
      if self._train_spec.max_steps is None:
        return False
      if global_step >= self._train_spec.max_steps:
        return True
      return False

    if self._eval_spec.throttle_secs <= 0:
      raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
                       'It is used do determine how long each training '
                       'iteration should go when train and evaluate '
                       'locally.'.format(self._eval_spec.throttle_secs))

    stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
    train_hooks = (
        list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
    logging.info('Start train and evaluate loop. The evaluate will happen '
                 'after {} secs (eval_spec.throttle_secs) or training is '
                 'finished.'.format(self._eval_spec.throttle_secs))

    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    while True:
      self._estimator.train(
          input_fn=self._train_spec.input_fn,
          max_steps=self._train_spec.max_steps,
          hooks=train_hooks)

      # Final export signal: For any eval result with global_step >= train
      # max_steps, the evaluator will send the final export signal. The
      # _should_stop_local_train will then end the while True as the stopping
      # condition is satisfied (both checks use the same global_step value,
      # i.e., no race condition)
      eval_result = evaluator.evaluate_and_export()

      if eval_result.status != _EvalStatus.EVALUATED:
        #  This is unexpected; should never happen.
        #  Training should always end with a new checkpoint.
        raise RuntimeError('There was no new checkpoint after the training. '
                           'Eval status: {}'.format(eval_result.status))

      if _should_stop_local_train(
          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
        break

  def _start_std_server(self, config):
    """Creates, starts, and returns a server_lib.Server."""
    if (not config.cluster_spec or not config.task_type or
        config.task_id is None):
      raise RuntimeError('Could not start server; be sure to specify '
                         'cluster_spec, task_type, and task in '
                         'RunConfig or set the TF_CONFIG environment variable.')

    if not config.master:
      jobs = config.cluster_spec.jobs
      if (len(jobs) == 1 and
          len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
          config.task_type in _TRAINER_JOBS):
        # For distributed training, config.master is empty if and only if it has
        # a single node in the cluster spec. In this case, we should not start
        # the server.
        logging.info('Skip starting Tensorflow server as there is only one '
                     'node in the cluster.')
        return
      else:
        raise RuntimeError(
            'Could not start server; be sure to specify master in '
            'RunConfig or set the TF_CONFIG environment variable.')

    logging.info('Start Tensorflow server.')

    if config.session_config is None:
      session_config = config_pb2.ConfigProto(log_device_placement=False)
    else:
      session_config = config_pb2.ConfigProto(
          log_device_placement=False,
          gpu_options=config.session_config.gpu_options)

    server = server_lib.Server(
        config.cluster_spec,
        job_name=config.task_type,
        task_index=config.task_id,
        config=session_config,
        start=False)
    server.start()
    return server

  def _start_distributed_training(self, saving_listeners=None):
    """Calls `Estimator` train in a distributed setting.

    Args:
      saving_listeners: Optional listeners forwarded unchanged to
        `Estimator.train`.
    """
    run_config = self._estimator.config

    # Start in-process TensorFlow server if needed. It's important to start the
    # server before we (optionally) sleep. Otherwise, the servers will wait to
    # connect to each other before starting to train.
    if not _is_google_env():
      self._start_std_server(run_config)

    # Delay worker to start. For asynchronous training, this usually helps model
    # to converge faster.  Chief starts the training immediately, so, worker
    # with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER.
    if run_config.task_type == run_config_lib.TaskType.WORKER:
      # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in
      # training cluster.
      start_delay_secs = min(
          _MAX_DELAY_SECS, (run_config.task_id + 1) * _DELAY_SECS_PER_WORKER)
    else:
      start_delay_secs = 0

    if start_delay_secs > 0:
      logging.info('Waiting %d secs before starting training.',
                   start_delay_secs)
      time.sleep(start_delay_secs)

    # Hooks from the train spec run first, followed by the executor's own.
    combined_hooks = list(self._train_spec.hooks) + list(self._train_hooks)
    self._estimator.train(
        input_fn=self._train_spec.input_fn,
        max_steps=self._train_spec.max_steps,
        hooks=combined_hooks,
        saving_listeners=saving_listeners)

  def _start_continuous_evaluation(self):
    """Evaluates and exports repeatedly until training ends.

    The loop exits when the most recent evaluated result reports a global
    step at or beyond the train spec's `max_steps`, or when
    `_execute_evaluator_once` signals an early stop.
    """
    initial_delay = self._eval_spec.start_delay_secs
    if initial_delay:
      logging.info('Waiting %f secs before starting eval.', initial_delay)
      time.sleep(initial_delay)

    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)
    previous_result = None

    while True:
      # Only a result that actually ran an evaluation carries metrics.
      if (previous_result and
          previous_result.status == _EvalStatus.EVALUATED):
        step = previous_result.metrics.get(ops.GraphKeys.GLOBAL_STEP)
        training_finished = (
            step and self._train_spec.max_steps and
            step >= self._train_spec.max_steps)
        if training_finished:
          logging.info(
              'Exiting evaluation, global_step=%s >= train max_steps=%s',
              step, self._train_spec.max_steps)
          return

      previous_result, stop_requested = self._execute_evaluator_once(
          evaluator, self._continuous_eval_listener,
          self._eval_spec.throttle_secs)
      if stop_requested:
        return

  def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
                              throttle_secs):
    """Runs one evaluate-and-export pass, bracketed by listener callbacks.

    Both `before_eval` and `after_eval` are invoked on the
    `continuous_eval_listener` argument, so the listener handed in by the
    caller is the one that decides about early stopping.

    Args:
      evaluator: Object whose `evaluate_and_export()` performs the pass.
      continuous_eval_listener: Listener providing `before_eval()` and
        `after_eval(eval_result)`; a falsy return value from either one
        requests an early stop.
      throttle_secs: Minimum duration of a completed pass in seconds. When the
        pass finishes sooner, the remainder is slept off. No throttling
        happens when an early stop is requested.

    Returns:
      A tuple `(eval_result, should_early_stop)`. `eval_result` is `None` if
      `before_eval()` requested the stop, because no evaluation ran.
    """
    start = time.time()

    eval_result = None
    should_early_stop = False

    if not continuous_eval_listener.before_eval():
      logging.info('Exiting evaluation, as requested by '
                   '_ContinuousEvalListener.before_eval.')
      should_early_stop = True
      return (eval_result, should_early_stop)

    # Final export signal: For any eval result with global_step >= train
    # max_steps, the evaluator will send the final export signal. The next
    # iteration of while loop will end the continuous eval as the stopping
    # condition is satisfied (both checks use the same global_step value,
    # i.e., no race condition)
    eval_result = evaluator.evaluate_and_export()

    # Consult the listener passed in (not the one stored on `self`) so the
    # two callbacks always go to the same object.
    if not continuous_eval_listener.after_eval(eval_result):
      logging.info('Exiting evaluation, as requested by '
                   '_ContinuousEvalListener.after_eval.')
      should_early_stop = True
      return (eval_result, should_early_stop)

    # Throttle if necessary.
    elapsed_time = time.time() - start
    difference = throttle_secs - elapsed_time
    if difference > 0:
      logging.info('Waiting %f secs before starting next eval run.', difference)
      time.sleep(difference)

    return (eval_result, should_early_stop)

  class _Evaluator(object):
    """Evaluates the newest checkpoint of an `Estimator` and runs exporters."""

    def __init__(self, estimator, eval_spec, max_training_steps):
      """Stores the collaborators and resets the bookkeeping state.

      Args:
        estimator: Object providing `latest_checkpoint()`, `evaluate(...)` and
          `model_dir`.
        eval_spec: Object providing `input_fn`, `steps`, `name`, `hooks` and
          `exporters`.
        max_training_steps: Global step at or beyond which an export is
          flagged as the final one; a falsy value disables that flag.
      """
      self._estimator = estimator
      self._eval_spec = eval_spec
      self._max_training_steps = max_training_steps
      # Bookkeeping updated by `evaluate_and_export` and `_log_err_msg`.
      self._previous_ckpt_path = None
      self._is_final_export_triggered = False
      self._last_warning_time = 0

    @property
    def is_final_export_triggered(self):
      """Whether an export with `is_the_final_export=True` has happened."""
      return self._is_final_export_triggered

    def evaluate_and_export(self):
      """Evaluates the latest checkpoint and hands the result to exporters.

      Nothing is evaluated when no checkpoint exists yet or when the latest
      checkpoint is the one evaluated by the previous successful call.

      Returns:
        An `_EvalResult` whose status tells which of the cases applied.

      Raises:
        ValueError: If `_EvalResult` validation rejects the metrics or the
          checkpoint path.
        TypeError: If the metrics returned by the estimator are not a dict.
      """
      ckpt_path = self._estimator.latest_checkpoint()

      if not ckpt_path:
        self._log_err_msg('Estimator is not trained yet. Will start an '
                          'evaluation when a checkpoint is ready.')
        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)

      if ckpt_path == self._previous_ckpt_path:
        self._log_err_msg(
            'No new checkpoint ready for evaluation. Skip the current '
            'evaluation pass as evaluation results are expected to be same '
            'for the same checkpoint.')
        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)

      spec = self._eval_spec
      metrics = self._estimator.evaluate(
          input_fn=spec.input_fn,
          steps=spec.steps,
          name=spec.name,
          checkpoint_path=ckpt_path,
          hooks=spec.hooks)

      # Constructing the result also validates `metrics`.
      eval_result = _EvalResult(
          status=_EvalStatus.EVALUATED,
          metrics=metrics,
          checkpoint_path=ckpt_path)

      # Without a step limit no export is ever flagged as the final one.
      if self._max_training_steps:
        is_the_final_export = (
            eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
            self._max_training_steps)
      else:
        is_the_final_export = False
      self._export_eval_result(eval_result, is_the_final_export)

      if is_the_final_export:
        logging.debug('Calling exporter with the `is_the_final_export=True`.')
        self._is_final_export_triggered = True

      # Zeroing the timestamp keeps the rate limit in `_log_err_msg` from
      # suppressing the next warning.
      self._previous_ckpt_path = ckpt_path
      self._last_warning_time = 0
      return eval_result

    def _log_err_msg(self, message):
      """Logs `message` as a warning, at most once every 600 seconds."""
      now = time.time()
      if now - self._last_warning_time <= 600:
        return
      logging.warning(message)
      self._last_warning_time = now

    def _export_eval_result(self, eval_result, is_the_final_export):
      """Calls every exporter of the `EvalSpec` with `eval_result`.

      Each exporter writes below `<model_dir>/export/<exporter.name>`.

      Args:
        eval_result: The `_EvalResult` supplying checkpoint path and metrics.
        is_the_final_export: Passed through to each exporter.
      """
      export_root = os.path.join(
          compat.as_str_any(self._estimator.model_dir),
          compat.as_str_any('export'))

      for exporter in self._eval_spec.exporters:
        destination = os.path.join(
            compat.as_str_any(export_root),
            compat.as_str_any(exporter.name))
        exporter.export(
            estimator=self._estimator,
            export_path=destination,
            checkpoint_path=eval_result.checkpoint_path,
            eval_result=eval_result.metrics,
            is_the_final_export=is_the_final_export)


class _EvalStatus(object):
  """String constants naming the possible outcomes of an evaluation attempt.

  In local training and evaluation only `EVALUATED` occurs, because
  `Estimator.train` always generates a new checkpoint.

  In distributed training and evaluation a separate evaluator keeps looking
  for new checkpoints, so any of the following may be reported:

  - EVALUATED: A checkpoint newer than the last evaluated one was found and
      `Estimator.evaluate` is invoked on it.
  - MISSING_CHECKPOINT: No checkpoint can be found at all. Typically the
      trainer has not yet produced any checkpoint.
  - NO_NEW_CHECKPOINT: No checkpoint newer than the last evaluated one can be
      found. Typically the trainer has not yet produced a new checkpoint.
  """

  # A new checkpoint was evaluated.
  EVALUATED = 'evaluated'
  # There is no checkpoint yet.
  MISSING_CHECKPOINT = 'missing checkpoint'
  # The latest checkpoint was already evaluated.
  NO_NEW_CHECKPOINT = 'no new checkpoint'


class _EvalResult(
    collections.namedtuple('EvalResult',
                           ['status', 'metrics', 'checkpoint_path'])):
  """Immutable record of one evaluation attempt, validated on construction."""

  def __new__(cls, status, metrics=None, checkpoint_path=None):
    """Validates the arguments and builds the `_EvalResult`.

    Args:
      status: One of the `_EvalStatus` constants.
      metrics: The dict returned by `Estimator.evaluate`. Must be non-empty
          and contain the global step when status is `EVALUATED`; must be
          empty or `None` for any other status.
      checkpoint_path: The checkpoint the `metrics` were computed from. Must
          be non-empty when status is `EVALUATED`; must be empty or `None`
          for any other status.

    Returns:
      A validated `_EvalResult` object.

    Raises:
      ValueError: If a value is present that the status forbids, or a value
          the status requires is missing.
      TypeError: If status is `EVALUATED` and `metrics` is not a dict.
    """
    if status != _EvalStatus.EVALUATED:
      # No evaluation ran, so there is nothing that may be attached.
      if metrics:
        raise ValueError(
            'metrics must be `None` if status is not {}; got status {},'
            ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))
      if checkpoint_path:
        raise ValueError(
            'checkpoint must be `None` if status is not {}; got status {}, '
            'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
                                        checkpoint_path))
    else:
      # An evaluation ran: metrics are checked first, then the checkpoint.
      if not metrics:
        raise ValueError(
            'Internal error: `Estimator.evaluate` should never return empty '
            'metrics.')
      if not isinstance(metrics, dict):
        raise TypeError(
            '`Estimator.evaluate` should return dict. Given {}.'.format(
                type(metrics)))
      if ops.GraphKeys.GLOBAL_STEP not in metrics:
        raise ValueError(
            'Internal error: `Estimator.evaluate` result should have '
            '`global_step` in result. Given {}'.format(metrics))
      if not checkpoint_path:
        raise ValueError(
            'Internal error: `checkpoint_path` should never be empty.')

    return super(_EvalResult, cls).__new__(cls, status, metrics,
                                           checkpoint_path)


class _ContinuousEvalListener(object):
  """Hook interface called around each pass of continuous evaluation.

  Both methods of this base implementation return `True`, i.e. never request
  a stop.
  """

  def before_eval(self):
    """Hook invoked right before an evaluation.

    Returns:
      `True` to go ahead with the evaluation; `False` to skip it and end the
      continuous evaluation early.
    """
    return True

  def after_eval(self, eval_result):
    """Hook invoked right after an evaluation has been executed.

    Args:
      eval_result: An `_EvalResult` instance.

    Returns:
      `True` to keep evaluating; `False` to end the continuous evaluation
      early.
    """
    del eval_result  # Unused by the default implementation.
    return True