aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator.py
blob: e6d82f0db739f0d8cf02cebd97561cab5963d100 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base Estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import tempfile

import numpy as np
import six

from google.protobuf import message
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.training import warm_starting_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import estimator_export


# Parameter names that a user-supplied `model_fn` signature may declare.
_VALID_MODEL_FN_ARGS = {'features', 'labels', 'mode', 'params', 'self',
                        'config'}


@estimator_export('estimator.Estimator')
class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The `Estimator` object wraps a model which is specified by a `model_fn`,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or predictions.

  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
  subdirectory thereof. If `model_dir` is not set, a temporary directory is
  used.

  The `config` argument can be passed a `tf.estimator.RunConfig` object
  containing information about the execution environment. It is passed on to
  the `model_fn`, if the `model_fn` has a parameter named "config" (and input
  functions in the same manner). If the `config` parameter is not passed, it is
  instantiated by the `Estimator`. Not passing config means that defaults useful
  for local execution are used. `Estimator` makes config available to the model
  (for instance, to allow specialization based on the number of workers
  available), and also uses some of its fields to control internals, especially
  regarding checkpointing.

  The `params` argument contains hyperparameters. It is passed to the
  `model_fn`, if the `model_fn` has a parameter named "params", and to the input
  functions in the same manner. `Estimator` only passes params along, it does
  not inspect it. The structure of `params` is therefore entirely up to the
  developer.

  None of `Estimator`'s methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use `model_fn` to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Calling methods of `Estimator` will work while eager execution is enabled.
  However, the `model_fn` and `input_fn` are not executed eagerly: `Estimator`
  will switch to graph mode before calling all user-provided functions (incl.
  hooks), so their code has to be compatible with graph mode execution. Note
  that `input_fn` code using `tf.data` generally works in both graph and eager
  modes.
  @end_compatibility
  """

  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See [estimators](https://tensorflow.org/guide/estimators) for more
    information.

    To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    `tf.estimator.WarmStartSettings`.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same (for multi-head models).
                 If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
                 be passed. If the `model_fn`'s signature does not accept
                 `mode`, the `model_fn` must still be able to handle
                 `labels=None`.
          * `mode`: Optional. Specifies whether this is training, evaluation
                 or prediction. See `tf.estimator.ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
                 is passed to Estimator in its `params` parameter. This allows
                 Estimators to be configured from hyperparameter tuning.
          * `config`: Optional `estimator.RunConfig` object. Will receive what
                 is passed to Estimator as its `config` parameter, or a default
                 value. Allows setting up things in your `model_fn` based on
                 configuration such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `tf.estimator.EstimatorSpec`

      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be the same. If both are `None`, a
        temporary directory will be used.
      config: `estimator.RunConfig` configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                       warm-start from, or a `tf.estimator.WarmStartSettings`
                       object to fully configure warm-starting.  If the string
                       filepath is provided instead of a
                       `tf.estimator.WarmStartSettings`, then all variables are
                       warm-started, and it is assumed that vocabularies
                       and `tf.Tensor` names are unchanged.

    Raises:
      ValueError: if `model_fn` is `None`.
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
    Estimator._assert_members_are_not_overridden(self)

    # Reconcile the `model_dir` argument with the one carried by `config`.
    cfg = maybe_overwrite_model_dir_and_session_config(config, model_dir)
    self._config = cfg

    # Distribution strategies for training and evaluation; either may be unset
    # (callers test them for truthiness before use).
    self._train_distribution = cfg.train_distribute
    self._eval_distribution = cfg.eval_distribute
    self._model_dir = cfg.model_dir
    self._session_config = cfg.session_config
    logging.info('Using config: %s', str(vars(cfg)))

    # A device_fn set on the config wins; otherwise derive a replica setter.
    self._device_fn = cfg.device_fn or _get_replica_device_setter(cfg)

    if model_fn is None:
      raise ValueError('model_fn must be provided to Estimator.')
    _verify_model_fn_args(model_fn, params)
    self._model_fn = model_fn
    # Private deep copy so later edits to the caller's dict do not leak in.
    self._params = copy.deepcopy(params or {})

    self._warm_start_settings = _get_default_warm_start_settings(
        warm_start_from)

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def config(self):
    return copy.deepcopy(self._config)

  @property
  def params(self):
    return copy.deepcopy(self._params)

  @property
  def model_fn(self):
    """Returns the `model_fn` which is bound to `self.params`.

    Returns:
      The `model_fn` with following signature:
        `def model_fn(features, labels, mode, config)`
    """

    def public_model_fn(features, labels, mode, config):
      return self._call_model_fn(features, labels, mode, config)

    return public_model_fn

  # TODO(ispir): support a list of names
  def get_variable_value(self, name):
    """Returns the value of variable `name` stored in the model's checkpoint.

    The value is read via `training.load_variable` from `model_dir`.

    Args:
      name: string, the name of a single variable. Lists of names are not
        supported yet (see the TODO above).

    Returns:
      Numpy array - value of the variable.

    Raises:
      ValueError: If the `Estimator` has not produced a checkpoint yet.
    """
    _check_checkpoint_available(self.model_dir)
    # Checkpoint reading is done with eager execution switched off.
    with context.graph_mode():
      return training.load_variable(self.model_dir, name)

  def get_variable_names(self):
    """Lists the names of all variables saved in this model's checkpoint.

    Returns:
      A list of variable-name strings.

    Raises:
      ValueError: If the `Estimator` has not produced a checkpoint yet.
    """
    _check_checkpoint_available(self.model_dir)
    with context.graph_mode():
      # `list_variables` yields (name, shape) pairs; only names are wanted.
      variable_entries = training.list_variables(self.model_dir)
      return [var_name for var_name, _ in variable_entries]

  def latest_checkpoint(self):
    """Locates the most recently saved checkpoint in `model_dir`.

    Returns:
      Full path of the newest checkpoint, or `None` when `model_dir` holds no
      checkpoint.
    """
    with context.graph_mode():
      newest = checkpoint_management.latest_checkpoint(self.model_dir)
    return newest

  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Trains a model given training data `input_fn`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:

          * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a
            tuple `(features, labels)` with same constraints as below.
          * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or
            a dictionary of string feature name to `Tensor` and `labels` is a
            `Tensor` or a dictionary of string label name to `Tensor`. Both
            `features` and `labels` are consumed by `model_fn`. They should
            satisfy the expectation of `model_fn` from inputs.

      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or train until `input_fn` generates the `tf.errors.OutOfRange`
        error or `StopIteration` exception. `steps` works incrementally. If you
        call `train(steps=10)` twice then training occurs in total 20 steps.
        If `OutOfRange` or `StopIteration` occurs in the middle, training stops
        before 20 steps. If you don't want to have incremental behavior please
        set `max_steps` instead. If set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train model. If `None`,
        train forever or train until `input_fn` generates the
        `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
        `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
        middle, training stops before `max_steps` steps. Two calls to
        `train(steps=100)` means 200 training iterations. On the other hand, two
        calls to `train(max_steps=100)` means that the second call will not do
        any iteration since first call did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If the configured task type is an evaluator or a parameter
        server.
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps <= 0`.
    """
    # Read the private config directly: the public `config` property returns a
    # deep copy of the whole RunConfig, which is wasted work for this check.
    config = self._config
    if config.task_type in (run_config.TaskType.EVALUATOR,
                            run_config.TaskType.PS):
      raise ValueError(
          'Train has been called wrong configuration. Please use '
          'tf.estimator.train_and_evaluate which calls proper API according '
          'to given configuration. Current configuration: {}.'.format(config))

    with context.graph_mode():
      if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
      if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
      if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

      if max_steps is not None:
        # `max_steps` is an absolute target: if the checkpointed global step
        # already reached it there is nothing left to do.
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already saved.')
          return self

      hooks = _check_hooks_type(hooks)
      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

      saving_listeners = _check_listeners_type(saving_listeners)
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self

  def _convert_train_steps_to_hooks(self, steps, max_steps):
    """Create hooks to run correct number of steps in training.

    Args:
      steps: number of steps to run during training.
      max_steps: maximum number of steps to be run during training. It'll be
        the maximum number of steps the model will train to after restoring
        from checkpoint even across multiple estimator.train calls.

    Returns:
      List of hooks to be passed to the estimator.
    """
    if steps is not None or max_steps is not None:
      if self._train_distribution:
        steps_per_run = getattr(self._train_distribution, 'steps_per_run', 1)
        if steps_per_run > 1:
          return [basic_session_run_hooks._MultiStepStopAtStepHook(  # pylint: disable=protected-access
              steps, max_steps, steps_per_run)]
      return [training.StopAtStepHook(steps, max_steps)]
    else:
      return []

  def eval_dir(self, name=None):
    """Shows the directory name where evaluation metrics are dumped.

    Args:
      name: Name of the evaluation if user needs to run multiple evaluations on
        different data sets, such as on training data vs test data. Metrics for
        different evaluations are saved in separate folders, and appear
        separately in tensorboard.

    Returns:
      A string which is the path of directory contains evaluation metrics.
    """
    return os.path.join(self._model_dir, 'eval' if not name else
                        'eval_' + name)

  def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
               name=None):
    """Evaluates the model given evaluation data `input_fn`.

    For each step, calls `input_fn`, which returns one batch of data.
    Evaluates until:
    - `steps` batches are processed, or
    - `input_fn` raises an end-of-input exception (`tf.errors.OutOfRangeError`
    or
    `StopIteration`).

    Args:
      input_fn: A function that constructs the input data for evaluation. See
        [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:

          * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a
            tuple `(features, labels)` with same constraints as below.
          * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or
            a dictionary of string feature name to `Tensor` and `labels` is a
            `Tensor` or a dictionary of string label name to `Tensor`. Both
            `features` and `labels` are consumed by `model_fn`. They should
            satisfy the expectation of `model_fn` from inputs.

      steps: Number of steps for which to evaluate model. If `None`, evaluates
        until `input_fn` raises an end-of-input exception.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the evaluation call.
      checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
        latest checkpoint in `model_dir` is used.  If there are no checkpoints
        in `model_dir`, evaluation is run with newly initialized `Variables`
        instead of ones restored from checkpoint.
      name: Name of the evaluation if user needs to run multiple evaluations on
        different data sets, such as on training data vs test data. Metrics for
        different evaluations are saved in separate folders, and appear
        separately in tensorboard.

    Returns:
      A dict containing the evaluation metrics specified in `model_fn` keyed by
      name, as well as an entry `global_step` which contains the value of the
      global step for which this evaluation was performed. For canned
      estimators, the dict contains the `loss` (mean loss per mini-batch) and
      the `average_loss` (mean loss per sample). Canned classifiers also return
      the `accuracy`. Canned regressors also return the `label/mean` and the
      `prediction/mean`.

    Raises:
      ValueError: If `steps <= 0`.
    """
    with context.graph_mode():
      hooks = _check_hooks_type(hooks)
      hooks.extend(self._convert_eval_steps_to_hooks(steps))

      # Fall back to the latest checkpoint when none was given explicitly.
      if not checkpoint_path:
        latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
        if not latest_path:
          # Not an error: evaluation proceeds without a checkpoint to restore.
          logging.info('Could not find trained model in model_dir: %s, '
                       'running initialization to evaluate.', self._model_dir)
        checkpoint_path = latest_path

      def _evaluate():
        (scaffold, update_op, eval_dict, all_hooks) = (
            self._evaluate_build_graph(input_fn, hooks, checkpoint_path))
        return self._evaluate_run(
            checkpoint_path=checkpoint_path,
            scaffold=scaffold,
            update_op=update_op,
            eval_dict=eval_dict,
            all_hooks=all_hooks,
            output_dir=self.eval_dir(name))

      with ops.Graph().as_default():
        if self._eval_distribution:
          # We want to create the iterations variable outside the distribution
          # scope as that is just stored on the host and mainly used to drive
          # the loop and doesn't need to be a Mirrored/Device variable.
          training.get_or_create_steps_per_run_variable()
          with self._eval_distribution.scope():
            return _evaluate()
        else:
          return _evaluate()

  def _convert_eval_steps_to_hooks(self, steps):
    """Create hooks to run correct number of steps in evaluation.

    Args:
      steps: number of steps to run during evaluation.

    Raises:
      ValueError: if steps is less than or equal to zero.

    Returns:
      List of hooks to be passed to the estimator.
    """
    if steps is None:
      return []

    if steps <= 0:
      raise ValueError('Must specify steps > 0, given: {}'.format(steps))

    # The hooks are declared as private in evaluation.py discourage the use
    # by other libraries or open source users. This should be the only usage
    # of the estimator evaluation hooks.
    if self._eval_distribution:
      steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1)
      if steps_per_run > 1:
        return [evaluation._MultiStepStopAfterNEvalsHook(  # pylint: disable=protected-access
            num_evals=steps, steps_per_run=steps_per_run)]
    return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access

  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    """Yields predictions for given features.

    Please note that interleaving two predict outputs does not work. See:
    [issue/20506](
    https://github.com/tensorflow/tensorflow/issues/20506#issuecomment-422208517)

    Args:
      input_fn: A function that constructs the features. Prediction continues
        until `input_fn` raises an end-of-input exception
        (`tf.errors.OutOfRangeError` or `StopIteration`).
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:

          * A `tf.data.Dataset` object: Outputs of `Dataset` object must have
            same constraints as below.
          * features: A `tf.Tensor` or a dictionary of string feature name to
            `Tensor`. features are consumed by `model_fn`. They should satisfy
            the expectation of `model_fn` from inputs.
          * A tuple, in which case the first item is extracted as features.

      predict_keys: list of `str`, name of the keys to predict. It is used if
        the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
        `predict_keys` is used then rest of the predictions will be filtered
        from the dictionary. If `None`, returns all.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the prediction call.
      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
        latest checkpoint in `model_dir` is used.  If there are no checkpoints
        in `model_dir`, prediction is run with newly initialized `Variables`
        instead of ones restored from checkpoint (an info message is logged;
        this is not an error).
      yield_single_examples: If `False`, yields the whole batch as returned by
        the `model_fn` instead of decomposing the batch into individual
        elements. This is useful if `model_fn` returns some tensors whose first
        dimension is not equal to the batch size.

    Yields:
      Evaluated values of `predictions` tensors. If `yield_single_examples` is
      `False`, one value is yielded per session run (a whole batch). Otherwise
      one value is yielded per element of the first dimension; for `dict`
      predictions, each yielded value is a `dict` holding that element of
      every entry.

    Raises:
      ValueError: If batch length of predictions is not the same and
        `yield_single_examples` is `True`.
      ValueError: If there is a conflict between `predict_keys` and
        `predictions`. For example if `predict_keys` is not `None` but
        `tf.estimator.EstimatorSpec.predictions` is not a `dict`, or if none
        of `predict_keys` is a key of `predictions`.
      ValueError: If `model_fn` does not return an `EstimatorSpec`.
    """
    with context.graph_mode():
      hooks = _check_hooks_type(hooks)
      # Fall back to the latest checkpoint in model_dir. Having no checkpoint
      # at all is not an error: prediction then runs with freshly initialized
      # variables, and only an info message is logged.
      if not checkpoint_path:
        checkpoint_path = checkpoint_management.latest_checkpoint(
            self._model_dir)
      if not checkpoint_path:
        logging.info('Could not find trained model in model_dir: {}, running '
                     'initialization to predict.'.format(self._model_dir))
      # Each predict() call builds its own fresh graph.
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        self._create_and_assert_global_step(g)
        features, input_hooks = self._get_features_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.PREDICT)
        estimator_spec = self._call_model_fn(
            features, None, model_fn_lib.ModeKeys.PREDICT, self.config)

        # Call to warm_start has to be after model_fn is called.
        self._maybe_warm_start(checkpoint_path)

        # Optionally restrict dict predictions to `predict_keys`.
        predictions = self._extract_keys(
            estimator_spec.predictions, predict_keys)
        # Hook order: input-pipeline hooks, then user hooks, then any
        # prediction hooks returned by model_fn.
        all_hooks = list(input_hooks)
        all_hooks.extend(hooks)
        all_hooks.extend(list(estimator_spec.prediction_hooks or []))
        with training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                master=self._config.master,
                scaffold=estimator_spec.scaffold,
                config=self._session_config),
            hooks=all_hooks) as mon_sess:
          # Run until the session reports it should stop (e.g. the input is
          # exhausted); each run evaluates one batch of predictions.
          while not mon_sess.should_stop():
            preds_evaluated = mon_sess.run(predictions)
            if not yield_single_examples:
              # Whole batch at once.
              yield preds_evaluated
            elif not isinstance(predictions, dict):
              # Single tensor: iterate its first dimension.
              for pred in preds_evaluated:
                yield pred
            else:
              # Dict of tensors: slice every entry at the same index; all
              # entries must share a batch length.
              for i in range(self._extract_batch_length(preds_evaluated)):
                yield {
                    key: value[i]
                    for key, value in six.iteritems(preds_evaluated)
                }

  def _assert_members_are_not_overridden(self):
    """Raises if a subclass redefines a member of `Estimator`.

    Subclasses named `TPUEstimator` are exempt, as are a small fixed
    allow-list of members. Dunder members of `Estimator` are never checked.

    Raises:
      ValueError: if the subclass defines, with a different value, any other
        member that `Estimator` itself defines.
    """
    subclass = self.__class__
    # TPUEstimator is special cased (owned by TF).
    if subclass.__name__ == 'TPUEstimator':
      return

    overridable = set([
        '_create_and_assert_global_step',
        '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
        '_estimator_api_names_v1', '_estimator_api_constants',
        '_estimator_api_constants_v1',
    ])
    base_members = set(name for name in Estimator.__dict__.keys()
                       if not name.startswith('__'))
    # Members defined on both classes, minus the ones that may be overridden.
    candidates = base_members & (set(subclass.__dict__.keys()) - overridable)
    overridden_members = [
        name for name in candidates
        if Estimator.__dict__[name] != subclass.__dict__[name]]
    if not overridden_members:
      return
    raise ValueError(
        'Subclasses of Estimator cannot override members of Estimator. '
        '{} does override {}'.format(subclass, overridden_members))

  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      strip_default_attrs=False):
    # pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
    """Exports inference graph as a `SavedModel` into the given dir.

    Note that `export_savedmodel` will be renamed to `export_saved_model`
    in TensorFlow 2.0. At that time, `export_savedmodel` without the
    additional underscore will be available only through tf.compat.v1.

    Please see `tf.estimator.Estimator.export_saved_model` for more information.

    This exports the `tf.estimator.ModeKeys.PREDICT` graph only.

    There is one additional arg versus the new method:
      strip_default_attrs: This parameter is going away in TF 2.0, and
        the new behavior will automatically strip all default attributes.
        Boolean. If `True`, default-valued attributes will be
        removed from the `NodeDef`s. For a detailed guide, see [Stripping
        Default-Valued Attributes](
        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    """
    # pylint: enable=line-too-long,g-doc-args,g-doc-return-or-yield
    return self._export_saved_model_for_mode(
        export_dir_base,
        serving_input_receiver_fn,
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path,
        strip_default_attrs=strip_default_attrs,
        mode=model_fn_lib.ModeKeys.PREDICT)

  def export_saved_model(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None):
    # pylint: disable=line-too-long
    """Exports inference graph as a `SavedModel` into the given dir.

    For a detailed guide, see
    [Using SavedModel with Estimators](https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).

    This method builds a new graph by first calling the
    `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling
    this `Estimator`'s `model_fn` to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given `export_dir_base`, and writes
    a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the `export_outputs` dict returned from the `model_fn`, named
    using the same keys.  One of these keys is always
    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
    indicating which signature will be served when a serving request does not
    specify one. For each signature, the outputs are provided by the
    corresponding `tf.estimator.export.ExportOutput`s, and the inputs are always
    the input receivers provided by the `serving_input_receiver_fn`.

    Extra assets may be written into the `SavedModel` via the `assets_extra`
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    This is equivalent to calling `export_savedmodel` with
    `strip_default_attrs=True`: default-valued attributes are always removed
    from the exported `NodeDef`s.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported `SavedModel`s.
      serving_input_receiver_fn: A function that takes no argument and returns a
        `tf.estimator.export.ServingInputReceiver` or
        `tf.estimator.export.TensorServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported `SavedModel`, or `None` if no extra assets are
        needed.
      as_text: whether to write the `SavedModel` proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no `serving_input_receiver_fn` is provided, no
        `export_outputs` are provided, or no checkpoint can be found.
    """
    # pylint: enable=line-too-long
    # TODO(b/111442174): `export_savedmodel` will be renamed to
    # `export_saved_model` in TensorFlow 2.0. This function is a wrapper
    # while staging the new version; do not add any logic here.
    return self.export_savedmodel(
        export_dir_base,
        serving_input_receiver_fn,
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path,
        strip_default_attrs=True)

  def _export_saved_model_for_mode(
      self, export_dir_base, input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      strip_default_attrs=False,
      mode=model_fn_lib.ModeKeys.PREDICT):
    # pylint: disable=line-too-long
    """Exports a single train/eval/predict graph as a `SavedModel`.

    This method is a wrapper for `_export_all_saved_models`. It wraps a raw
    `input_receiver_fn` in a single-entry `{mode: input_receiver_fn}`
    dictionary and passes that to `_export_all_saved_models`.
    See `_export_all_saved_models` for full docs.

    See `tf.contrib.estimator.export_saved_model_for_mode` for the currently
    exposed version of this function.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported `SavedModel`s.
      input_receiver_fn: a function that takes no argument and returns the
        appropriate subclass of `InputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported `SavedModel`, or `None` if no extra assets are
        needed.
      as_text: whether to write the `SavedModel` proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the `NodeDef`s. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      mode: `tf.estimator.ModeKeys` value indicating which mode will be
        exported.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if `input_receiver_fn` is `None`, no `export_outputs`
        are provided, or no checkpoint can be found.
    """
    # pylint: enable=line-too-long
    if not input_receiver_fn:
      raise ValueError('An input_receiver_fn must be defined.')

    input_receiver_fn_map = {mode: input_receiver_fn}

    return self._export_all_saved_models(
        export_dir_base,
        input_receiver_fn_map,
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path,
        strip_default_attrs=strip_default_attrs)

  def _export_all_saved_models(
      self, export_dir_base, input_receiver_fn_map,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Exports a `SavedModel` containing `tf.MetaGraphDefs` for each requested mode.

    See `tf.contrib.estimator.export_all_saved_models` for the currently
    exposed version of this function.

    For each mode passed in via the `input_receiver_fn_map`,
    this method builds a new graph by calling the `input_receiver_fn` to obtain
    feature and label `Tensor`s. Next, this method calls the `Estimator`'s
    `model_fn` in the passed mode to generate the model graph based on
    those features and labels, and restores the given checkpoint
    (or, lacking that, the most recent checkpoint) into the graph.
    Only one of the modes is used for saving variables to the `SavedModel`
    (order of preference: `tf.estimator.ModeKeys.TRAIN`,
    `tf.estimator.ModeKeys.EVAL`, then `tf.estimator.ModeKeys.PREDICT`), such
    that up to three `tf.MetaGraphDefs` are saved with a single set of
    variables in a single `SavedModel` directory.

    The `SavedModel` is first written to a temporary directory and then
    renamed to a timestamped export directory below `export_dir_base`. It
    contains the variables plus the `tf.MetaGraphDef` for each exported mode
    and its associated signatures.

    For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
    for each element of the `export_outputs` dict returned from the `model_fn`,
    named using the same keys.  One of these keys is always
    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
    indicating which signature will be served when a serving request does not
    specify one. For each signature, the outputs are provided by the
    corresponding `tf.estimator.export.ExportOutput`s, and the inputs are always
    the input receivers provided by the `serving_input_receiver_fn`.

    For training and evaluation, the `train_op` is stored in an extra
    collection, and loss, metrics, and predictions are included in a
    `SignatureDef` for the mode in question.

    Extra assets may be written into the `SavedModel` via the `assets_extra`
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported `SavedModel`s.
      input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
        `input_receiver_fn` mappings, where the `input_receiver_fn` is a
        function that takes no arguments and returns the appropriate subclass of
        `InputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported `SavedModel`, or `None` if no extra assets are
        needed.
      as_text: whether to write the `SavedModel` proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the `NodeDef`s. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      The path to the timestamped export directory (a single path for all
      exported modes).

    Raises:
      ValueError: if no checkpoint can be found, if `input_receiver_fn_map`
        has no non-empty `input_receiver_fn` for TRAIN, EVAL or PREDICT, or if
        exporting one of the modes fails (e.g. no `export_outputs` are
        provided, or the checkpoint lacks variables the `model_fn` expects).
    """
    # pylint: enable=line-too-long
    # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode.
    with context.graph_mode():
      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = checkpoint_management.latest_checkpoint(
            self._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s." % self._model_dir)

      # Note that the order in which we run here matters, as the first
      # mode we pass through will be used to save the variables. We run TRAIN
      # first, as that is also the mode used for checkpoints, and therefore
      # we are not likely to have vars in PREDICT that are not in the checkpoint
      # created by TRAIN.
      modes_to_export = [
          mode for mode in (model_fn_lib.ModeKeys.TRAIN,
                            model_fn_lib.ModeKeys.EVAL,
                            model_fn_lib.ModeKeys.PREDICT)
          if input_receiver_fn_map.get(mode)]

      # Validate before allocating export directories or constructing the
      # builder, so that a map with no exportable mode fails without doing
      # any of that work (and without leaving a temporary export dir behind).
      if not modes_to_export:
        raise ValueError('No valid modes for exporting found. Got {}.'.format(
            input_receiver_fn_map.keys()))

      export_dir = export_helpers.get_timestamped_export_dir(export_dir_base)
      temp_export_dir = export_helpers.get_temp_export_dir(export_dir)

      builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

      for index, mode in enumerate(modes_to_export):
        # Only the first exported mode saves variables; later modes only add
        # their MetaGraphDef on top of that single set of variables.
        save_variables = (index == 0)
        self._add_meta_graph_for_mode(
            builder, input_receiver_fn_map, checkpoint_path,
            strip_default_attrs, save_variables,
            mode=mode)

      builder.save(as_text)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

      # Publish the finished export by renaming the temp dir into place.
      gfile.Rename(temp_export_dir, export_dir)
      return export_dir

  def _add_meta_graph_for_mode(self,
                               builder,
                               input_receiver_fn_map,
                               checkpoint_path,
                               strip_default_attrs,
                               save_variables=True,
                               mode=model_fn_lib.ModeKeys.PREDICT,
                               export_tags=None,
                               check_variables=True):
    # pylint: disable=line-too-long
    """Loads variables and adds them along with a `tf.MetaGraphDef` for saving.

    Args:
      builder: instance of `tf.saved_model.builder.SavedModelBuilder` that will
        be used for saving.
      input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
        `input_receiver_fn` mappings, where the `input_receiver_fn` is a
        function that takes no argument and returns the appropriate subclass of
        `InputReceiver`.
      checkpoint_path: The checkpoint path to restore variables from when
        `check_variables` is `True`. It is used as given; callers are
        responsible for resolving a default checkpoint.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the `NodeDef`s. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_variables: bool, whether variables should be saved. If `False`, just
        the `tf.MetaGraphDef` will be saved. Note that `save_variables` should
        only be `True` for the first call to this function, and the
        `SavedModelBuilder` will raise an error if that is not the case.
      mode: `tf.estimator.ModeKeys` value indicating which mode will be
        exported.
      export_tags: The set of tags with which to save `tf.MetaGraphDef`. If
        `None`, a default set will be selected to match the passed mode.
      check_variables: bool, whether to restore the variables from
        `checkpoint_path`, which checks that the checkpoint has all of them.

    Raises:
      ValueError: if `save_variables` is `True` and `check_variables` is
        `False`, or if variables cannot be restored from `checkpoint_path`.
    """
    # pylint: enable=line-too-long
    # Variables are only restored from the checkpoint when `check_variables`
    # is set, so saving them without that restore would write variables that
    # were never loaded. Reject the combination before building any graph.
    if save_variables and not check_variables:
      raise ValueError('If `save_variables` is `True`, `check_variables` '
                       'must not be `False`.')

    if export_tags is None:
      export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
    input_receiver_fn = input_receiver_fn_map[mode]

    with ops.Graph().as_default() as g:
      self._create_and_assert_global_step(g)
      random_seed.set_random_seed(self._config.tf_random_seed)

      input_receiver = input_receiver_fn()

      # Call the model_fn and collect the export_outputs.
      estimator_spec = self._call_model_fn(
          features=input_receiver.features,
          labels=getattr(input_receiver, 'labels', None),
          mode=mode,
          config=self.config)

      export_outputs = model_fn_lib.export_outputs_for_mode(
          mode=estimator_spec.mode,
          serving_export_outputs=estimator_spec.export_outputs,
          predictions=estimator_spec.predictions,
          loss=estimator_spec.loss,
          metrics=estimator_spec.eval_metric_ops)

      # Build the SignatureDefs from receivers and all outputs
      signature_def_map = export_helpers.build_all_signature_defs(
          input_receiver.receiver_tensors,
          export_outputs,
          getattr(input_receiver, 'receiver_tensors_alternatives', None),
          serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))

      with tf_session.Session(config=self._session_config) as session:

        if estimator_spec.scaffold.local_init_op is not None:
          local_init_op = estimator_spec.scaffold.local_init_op
        else:
          local_init_op = monitored_session.Scaffold.default_local_init_op()

        # This saver will be used both for restoring variables now,
        # and in saving out the metagraph below. This ensures that any
        # Custom Savers stored with the Scaffold are passed through to the
        # SavedModel for restore later.
        graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True)

        if check_variables:
          try:
            graph_saver.restore(session, checkpoint_path)
          except errors.NotFoundError as e:
            msg = ('Could not load all requested variables from checkpoint. '
                   'Please make sure your model_fn does not expect variables '
                   'that were not saved in the checkpoint.\n\n'
                   'Encountered error with mode `{}` while restoring '
                   'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
                       mode, checkpoint_path, e)
            raise ValueError(msg)

        # We add the train op explicitly for now, so that we don't have to
        # change the Builder public interface. Note that this is a no-op
        # for prediction, where train_op is None.
        builder._add_train_op(estimator_spec.train_op)  # pylint: disable=protected-access

        meta_graph_kwargs = dict(
            tags=export_tags,
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            strip_default_attrs=strip_default_attrs,
            legacy_init_op=local_init_op,
            saver=graph_saver)

        if save_variables:
          builder.add_meta_graph_and_variables(
              session, **meta_graph_kwargs)
        else:
          builder.add_meta_graph(**meta_graph_kwargs)

  def _get_features_from_input_fn(self, input_fn, mode):
    """Calls `input_fn` and returns its features plus any input hooks.

    Labels produced by `input_fn` are discarded. The features are passed to
    `_validate_features_in_predict_input` before being returned.

    Returns:
      A `(features, hooks)` tuple.
    """
    raw_result = self._call_input_fn(input_fn, mode)
    features, _, input_hooks = estimator_util.parse_input_fn_result(raw_result)
    self._validate_features_in_predict_input(features)
    return features, input_hooks

  def _validate_features_in_predict_input(self, result):
    """Logs a warning if the predict input is neither a Dataset nor a QueueRunner.

    Without one of those, nothing signals end of input, so `predict` would
    keep yielding indefinitely. This only warns; it never raises.
    """
    if _has_dataset_or_queue_runner(result):
      return
    logging.warning('Input graph does not use tf.data.Dataset or contain a '
                    'QueueRunner. That means predict yields forever. '
                    'This is probably a mistake.')

  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
    """Builds an initializable iterator over the dataset from `input_fn`.

    If a `distribution` is given, `input_fn` is invoked lazily through its
    `distribute_dataset`; otherwise it is called directly.

    Returns:
      A `(iterator, input_hooks)` tuple. `input_hooks` holds a single
      `_DatasetInitializerHook` built for the iterator.
    """
    if distribution is None:
      dataset = self._call_input_fn(input_fn, mode)
    else:
      dataset = distribution.distribute_dataset(
          lambda: self._call_input_fn(input_fn, mode))

    iterator = dataset.make_initializable_iterator()
    return iterator, [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access

  def _get_features_and_labels_from_input_fn(self, input_fn, mode):
    """Calls `input_fn` and parses its result.

    Returns:
      Whatever `estimator_util.parse_input_fn_result` returns for the
      `input_fn` output; callers in this class unpack it as
      `(features, labels, hooks)`.
    """
    input_fn_result = self._call_input_fn(input_fn, mode)
    return estimator_util.parse_input_fn_result(input_fn_result)

  def _extract_batch_length(self, preds_evaluated):
    """Extracts batch length of predictions."""
    batch_length = None
    for key, value in six.iteritems(preds_evaluated):
      batch_length = batch_length or value.shape[0]
      if value.shape[0] != batch_length:
        raise ValueError('Batch length of predictions should be same. %s has '
                         'different batch length than others.' % key)
    return batch_length

  def _extract_keys(self, predictions, predict_keys):
    """Extracts `predict_keys` from `predictions`."""
    if not predict_keys:
      return predictions
    if not isinstance(predictions, dict):
      raise ValueError(
          'predict_keys argument is not valid in case of non-dict predictions.')
    existing_keys = predictions.keys()
    predictions = {
        key: value
        for key, value in six.iteritems(predictions) if key in predict_keys
    }
    if not predictions:
      raise ValueError('Expected to run at least one output from %s, '
                       'provided %s.' % (existing_keys, predict_keys))
    return predictions

  def _create_global_step(self, graph):
    """Creates and returns the global step tensor for `graph`.

    Contract for the returned tensor: it has an integer dtype, is named
    'global_step', and is registered in the `tf.GraphKeys.GLOBAL_STEP`
    collection.

    Args:
      graph: The graph in which to create the global step tensor.

    Returns:
      The global step `tf.Tensor`.
    """
    global_step = training.create_global_step(graph)
    return global_step

  def _create_and_assert_global_step(self, graph):
    """Creates the global step in `graph` and sanity-checks it.

    The step is obtained from `_create_global_step`. It must be the same
    object that `training.get_global_step()` reports and must have an integer
    dtype.

    Args:
      graph: The graph in which to create the global step tensor.

    Returns:
      The global step `tf.Tensor`.
    """
    global_step = self._create_global_step(graph)
    # Internal invariant checks (`assert` statements are skipped under -O).
    assert global_step == training.get_global_step()
    assert global_step.dtype.is_integer
    return global_step

  def _call_input_fn(self, input_fn, mode):
    """Invokes `input_fn`, passing only the optional arguments it declares.

    `mode`, `params` and `config` are each forwarded as a keyword argument
    only if that name appears in `input_fn`'s argument list. The call itself
    runs under a `/cpu:0` device scope.

    Args:
      input_fn: The input function.
      mode: `tf.estimator.ModeKeys`

    Returns:
      The return value of the passed `input_fn`, which should be one of:

        * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
            tuple `(features, labels)` with same constraints as below.
        * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
          dictionary of string feature name to `Tensor` and `labels` is a
          `Tensor` or a dictionary of string label name to `Tensor`. Both
          `features` and `labels` are consumed by `model_fn`. They should
          satisfy the expectation of `model_fn` from inputs.

    Raises:
      ValueError: if `input_fn` takes invalid arguments (presumably surfaced
        by `function_utils.fn_args`; not raised directly here).
    """
    declared = function_utils.fn_args(input_fn)
    # Produce each optional argument lazily so `self.params` / `self.config`
    # are only read when `input_fn` actually asks for them.
    optional_args = (
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: self.config),
    )
    kwargs = dict((name, produce()) for name, produce in optional_args
                  if name in declared)
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)

  def _call_model_fn(self, features, labels, mode, config):
    """Invokes the user's `model_fn`, passing only the arguments it declares.

    `features` is always passed. `labels`, `mode`, `params` and `config` are
    forwarded only if `model_fn` declares them.

    Args:
      features: features dict.
      labels: labels dict.
      mode: `tf.estimator.ModeKeys`
      config: `tf.estimator.RunConfig`

    Returns:
      An `tf.estimator.EstimatorSpec` object.

    Raises:
      ValueError: if `model_fn` does not accept `labels` but `labels` is not
        `None`, or if `model_fn` does not return an `EstimatorSpec`.
    """
    declared = function_utils.fn_args(self._model_fn)
    takes_labels = 'labels' in declared
    if not takes_labels and labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')

    kwargs = {}
    if takes_labels:
      kwargs['labels'] = labels
    if 'mode' in declared:
      kwargs['mode'] = mode
    if 'params' in declared:
      kwargs['params'] = self.params
    if 'config' in declared:
      kwargs['config'] = config

    logging.info('Calling model_fn.')
    spec = self._model_fn(features=features, **kwargs)
    logging.info('Done calling model_fn.')

    if isinstance(spec, model_fn_lib.EstimatorSpec):
      return spec
    raise ValueError('model_fn should return an EstimatorSpec.')

  def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)

  def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Trains on `input_fn` in a fresh graph, without `DistributionStrategies`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      The result of `_train_with_estimator_spec` (the loss from training).
    """
    with ops.Graph().as_default() as graph, graph.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)

      # `_create_and_assert_global_step` may be overridden to return None
      # (e.g. tf.contrib.estimator.SavedModelEstimator); only create the
      # global-step read when a step tensor was actually created.
      if self._create_and_assert_global_step(graph) is not None:
        training_util._get_or_create_global_step_read(graph)  # pylint: disable=protected-access

      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.TRAIN))
      # Input-pipeline hooks seed the worker hook list.
      worker_hooks = list(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Look the global step up again after model_fn has run, rather than
      # reusing the value returned at creation time.
      return self._train_with_estimator_spec(
          estimator_spec, worker_hooks, hooks,
          training_util.get_global_step(graph), saving_listeners)

  def _train_model_distributed(self, input_fn, hooks, saving_listeners):
    """Initiate training with `input_fn`, using `DistributionStrategies`.

    A new graph is built inside `self._train_distribution.scope()`. `model_fn`
    is invoked once per tower, and the per-tower results (loss, train op,
    scaffold and hooks) are merged into a single `EstimatorSpec`. That spec is
    passed to `_train_with_estimator_spec`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      Loss from training
    """
    self._train_distribution.configure(self._session_config)

    # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies
    # to use the new API
    is_tpu_strategy = (
        self._train_distribution.__class__.__name__ == 'TPUStrategy')

    worker_hooks = []
    with ops.Graph().as_default() as g:
      # We want to create the iterations variable outside the distribution scope
      # as that is just stored on the host and mainly used to drive the loop
      # and doesn't need to be a Mirrored/Device variable.
      if is_tpu_strategy:
        steps_per_run_variable = training.get_or_create_steps_per_run_variable()
      with self._train_distribution.scope():
        random_seed.set_random_seed(self._config.tf_random_seed)
        iterator, input_hooks = self._get_iterator_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
        worker_hooks.extend(input_hooks)
        global_step_tensor = self._create_and_assert_global_step(g)
        # we want to add to the global collection in the main thread not the
        # tower threads.
        ops.add_to_collection(
            training_util.GLOBAL_STEP_READ_KEY,
            self._train_distribution.read_var(global_step_tensor))

        if is_tpu_strategy:
          # Create a step_fn from the train_op of grouped_estimator_spec
          def step_fn(ctx, features, labels=None):
            """A single step that is passed to run_on_dataset."""
            estimator_spec = self._train_distribution.call_for_each_tower(
                self._call_model_fn,
                features,
                labels,
                model_fn_lib.ModeKeys.TRAIN,
                self.config)
            ctx.set_last_step_output(
                name='loss',
                output=estimator_spec.loss,
                aggregation=distribute_lib.get_loss_reduction())
            ctx.set_non_tensor_output(
                name='estimator_spec', output=estimator_spec)
            return estimator_spec.train_op

          # Create new train_op post graph rewrites
          initial_training_loss = constant_op.constant(1e7)
          ctx = self._train_distribution.run_steps_on_dataset(
              step_fn, iterator, iterations=steps_per_run_variable,
              initial_loop_values={'loss': initial_training_loss})
          distributed_train_op = ctx.run_op
          loss = ctx.last_step_outputs['loss']
          grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
        else:
          features, labels = estimator_util.parse_iterator_result(
              iterator.get_next())
          grouped_estimator_spec = self._train_distribution.call_for_each_tower(
              self._call_model_fn,
              features,
              labels,  # although this will be None it seems
              model_fn_lib.ModeKeys.TRAIN,
              self.config)
          loss = self._train_distribution.unwrap(
              self._train_distribution.reduce(
                  distribute_lib.get_loss_reduction(),
                  grouped_estimator_spec.loss,
                  destinations='/device:CPU:0'))[0]
          distributed_train_op = grouped_estimator_spec.train_op

        scaffold = _combine_distributed_scaffold(
            grouped_estimator_spec.scaffold, self._train_distribution)

        # TODO(yuefengz): add a test for unwrapping per_device_hooks.
        def get_hooks_from_the_first_device(per_device_hooks):
          """Returns the first tower's copy of each per-tower hook."""
          # Unwrap with the training strategy: it is the strategy that
          # produced `grouped_estimator_spec` via `call_for_each_tower` above.
          return [
              self._train_distribution.unwrap(per_device_hook)[0]
              for per_device_hook in per_device_hooks
          ]

        training_hooks = get_hooks_from_the_first_device(
            grouped_estimator_spec.training_hooks)
        training_chief_hooks = get_hooks_from_the_first_device(
            grouped_estimator_spec.training_chief_hooks)
        worker_hooks.append(
            estimator_util.StrategyInitFinalizeHook(
                self._train_distribution.initialize,
                self._train_distribution.finalize))

        estimator_spec = model_fn_lib.EstimatorSpec(
            mode=grouped_estimator_spec.mode,
            loss=loss,
            train_op=self._train_distribution.group(distributed_train_op),
            training_hooks=training_hooks,
            training_chief_hooks=training_chief_hooks,
            scaffold=scaffold)
        return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                               hooks, global_step_tensor,
                                               saving_listeners)

  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec.

    This adds the default training hooks: a NaN check on the loss, loss/step
    logging, checkpoint saving and, in clusters that have a master job,
    summary and step-counter hooks on worker 0. It then runs
    `estimator_spec.train_op` inside a `MonitoredTrainingSession` until the
    session asks to stop.

    Args:
      estimator_spec: The `EstimatorSpec` to train with.
      worker_hooks: List of hooks run on every worker. NOTE: this list is
        extended in place with `hooks` and with the default hooks created here.
      hooks: User-supplied `SessionRunHook`s appended to `worker_hooks`.
      global_step_tensor: Global step tensor, logged alongside the loss when
        `log_step_count_steps` is set.
      saving_listeners: `CheckpointSaverListener`s attached to the first
        `CheckpointSaverHook` found among the hooks (or created here).

    Returns:
      The loss value fetched on the last training step, or `None` if the
      session stopped before any step ran.

    Raises:
      ValueError: if `saving_listeners` are given but no
        `CheckpointSaverHook` is present or created.
    """
    # Warm-start variables before the training session is created.
    if self._warm_start_settings:
      logging.info('Warm-starting with WarmStartSettings: %s' %
                   (self._warm_start_settings,))
      warm_starting_util.warm_start(*self._warm_start_settings)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.append(
        training.NanTensorHook(estimator_spec.loss)
    )
    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps)
      )
    worker_hooks.extend(estimator_spec.training_hooks)

    # Register a default sharded Saver unless the scaffold or the SAVERS
    # collection already provides one.
    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    # Create a chief-only CheckpointSaverHook when checkpointing is configured
    # and none of the supplied hooks already saves checkpoints.
    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              training.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              training.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))

    # Run train_op until a hook or the session requests a stop; the loss of
    # the final step is returned.
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        log_step_count_steps=log_step_count_steps) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
    return loss

  def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
    """Builds the graph and related hooks to run evaluation.

    Args:
      input_fn: A function that provides evaluation input data.
      hooks: Optional iterable of `tf.train.SessionRunHook`s. They are added
        after the input hooks and before the `model_fn` evaluation hooks.
        `None` means no extra hooks.
      checkpoint_path: Path of the checkpoint being evaluated. When falsy,
        warm-starting is applied if `WarmStartSettings` were configured.

    Returns:
      A tuple `(scaffold, update_op, eval_dict, all_hooks)`. `eval_dict` also
      maps `ops.GraphKeys.GLOBAL_STEP` to the global step tensor.

    Raises:
      ValueError: if `model_fn` already defines a `global_step` metric.
    """
    random_seed.set_random_seed(self._config.tf_random_seed)
    self._create_and_assert_global_step(ops.get_default_graph())

    if self._eval_distribution:
      (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
          self._call_model_fn_eval_distributed(input_fn, self.config))
    else:
      (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
          self._call_model_fn_eval(input_fn, self.config))

    global_step_tensor = training_util.get_global_step(ops.get_default_graph())
    # Call to warm_start has to be after model_fn is called.
    self._maybe_warm_start(checkpoint_path)

    if ops.GraphKeys.GLOBAL_STEP in eval_dict:
      raise ValueError(
          'Metric with name `global_step` is not allowed, because Estimator '
          'already defines a default metric with the same name.')
    eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor

    all_hooks = list(input_hooks)
    # `hooks` defaults to None; treat that as "no extra hooks" rather than
    # letting `list.extend(None)` raise a TypeError.
    all_hooks.extend(hooks or [])
    all_hooks.extend(list(evaluation_hooks or []))
    # New local variables have been added, so update the estimator spec's
    # local init op if it was defined.
    if scaffold and scaffold.local_init_op:
      # Ensure that eval step has been created before updating local init op.
      evaluation._get_or_create_eval_step()  # pylint: disable=protected-access

      scaffold = monitored_session.Scaffold(
          local_init_op=control_flow_ops.group(
              scaffold.local_init_op,
              monitored_session.Scaffold.default_local_init_op()),
          copy_from_scaffold=scaffold
      )

    return scaffold, update_op, eval_dict, all_hooks

  def _call_model_fn_eval(self, input_fn, config):
    """Runs `model_fn` in EVAL mode and splits out its metric ops.

    Args:
      input_fn: A function that provides evaluation input data.
      config: The `tf.estimator.RunConfig` forwarded to `model_fn`.

    Returns:
      A tuple `(scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)`.
    """
    eval_mode = model_fn_lib.ModeKeys.EVAL
    features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
        input_fn, eval_mode)

    spec = self._call_model_fn(features, labels, eval_mode, config)
    metric_ops = _verify_and_create_loss_metric(spec.eval_metric_ops, spec.loss)
    update_op, eval_dict = _extract_metric_update_ops(metric_ops)
    return (spec.scaffold, spec.evaluation_hooks, input_hooks, update_op,
            eval_dict)

  def _call_model_fn_eval_distributed(self, input_fn, config):
    """Call model_fn in distribution mode and handle return values.

    Args:
      input_fn: A function that provides evaluation input data.
      config: The `tf.estimator.RunConfig` forwarded to `model_fn`.

    Returns:
      A tuple `(scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)`.
      `evaluation_hooks` is a tuple of the first tower's evaluation hooks,
      followed by a `StrategyInitFinalizeHook` for the eval strategy.
    """

    iterator, input_hooks = self._get_iterator_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution)

    # Same class-name check the training path uses to detect TPU strategies.
    is_tpu_strategy = (
        self._eval_distribution.__class__.__name__ == 'TPUStrategy')

    if is_tpu_strategy:
      steps_per_run_variable = training.get_or_create_steps_per_run_variable()
      def step_fn(ctx, features, labels=None):
        """Runs one step of the eval computation and captures outputs."""
        estimator_spec = self._eval_distribution.call_for_each_tower(
            self._call_model_fn, features, labels, model_fn_lib.ModeKeys.EVAL,
            config)
        eval_metric_ops = _verify_and_create_loss_metric(
            estimator_spec.eval_metric_ops, estimator_spec.loss,
            self._eval_distribution)
        update_op, eval_dict = _extract_metric_update_ops(
            eval_metric_ops, self._eval_distribution)
        ctx.set_non_tensor_output(name='estimator_spec', output=estimator_spec)
        ctx.set_non_tensor_output(name='eval_dict', output=eval_dict)
        return update_op

      # TODO(priyag): Fix eval step hook to account for steps_per_run.
      ctx = self._eval_distribution.run_steps_on_dataset(
          step_fn, iterator, iterations=steps_per_run_variable)
      update_op = ctx.run_op
      eval_dict = ctx.non_tensor_outputs['eval_dict']
      grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
    else:
      features, labels = estimator_util.parse_iterator_result(
          iterator.get_next())
      grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
          self._call_model_fn, features, labels,
          model_fn_lib.ModeKeys.EVAL, config)
      eval_metric_ops = _verify_and_create_loss_metric(
          grouped_estimator_spec.eval_metric_ops, grouped_estimator_spec.loss,
          self._eval_distribution)
      update_op, eval_dict = _extract_metric_update_ops(
          eval_metric_ops, self._eval_distribution)

    scaffold = _combine_distributed_scaffold(
        grouped_estimator_spec.scaffold, self._eval_distribution)
    # Unwrap each hook individually and keep the first tower's copy, exactly as
    # the training path's `get_hooks_from_the_first_device` does. The grouped
    # `evaluation_hooks` is expected to be a sequence of per-tower values,
    # which is the same assumption the training path makes. Unwrapping the
    # sequence as a whole does not reach the individual hooks.
    # NOTE(review): confirm against the strategy's `unwrap` semantics.
    evaluation_hooks = tuple(
        self._eval_distribution.unwrap(per_device_hook)[0]
        for per_device_hook in grouped_estimator_spec.evaluation_hooks)
    evaluation_hooks += (
        estimator_util.StrategyInitFinalizeHook(
            self._eval_distribution.initialize,
            self._eval_distribution.finalize),)

    return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)

  def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
                    all_hooks, output_dir):
    """Runs one evaluation pass and records its results as summaries.

    Args:
      checkpoint_path: Checkpoint to evaluate. When truthy, it is also written
        as a `checkpoint_path` summary.
      scaffold: `tf.train.Scaffold` used to build the evaluation session.
      update_op: Op that updates all metrics for one batch.
      eval_dict: Dict of metric name to the tensor read after evaluation.
      all_hooks: `SessionRunHook`s to run during evaluation.
      output_dir: Directory that receives the summary files.

    Returns:
      The dict of evaluation results returned by `evaluation._evaluate_once`.
    """
    results = evaluation._evaluate_once(  # pylint: disable=protected-access
        checkpoint_path=checkpoint_path,
        master=self._config.evaluation_master,
        scaffold=scaffold,
        eval_ops=update_op,
        final_ops=eval_dict,
        hooks=all_hooks,
        config=self._session_config)

    step = results[ops.GraphKeys.GLOBAL_STEP]
    summary_kwargs = dict(output_dir=output_dir, current_global_step=step)
    _write_dict_to_summary(dictionary=results, **summary_kwargs)
    if checkpoint_path:
      _write_checkpoint_path_to_summary(
          checkpoint_path=checkpoint_path, **summary_kwargs)
    return results

  def _maybe_warm_start(self, checkpoint_path):
    if not checkpoint_path and self._warm_start_settings:
      logging.info('Warm-starting with WarmStartSettings: %s' %
                   (self._warm_start_settings,))
      warm_starting_util.warm_start(*self._warm_start_settings)


def _verify_and_create_loss_metric(eval_metric_ops, loss, distribution=None):
  """Creates a metric for loss and throws an error if one already exists.

  Args:
    eval_metric_ops: Dict of metric name to `(value_op, update_op)` pair, or
      `None` (treated as empty). It is not modified.
    loss: The loss tensor that is averaged with `metrics_lib.mean`.
    distribution: Optional distribution strategy. When given, the mean metric
      is created per tower via `distribution.call_for_each_tower`.

  Returns:
    A new dict with the entries of `eval_metric_ops` plus the mean-loss metric
    under `model_fn_lib.LOSS_METRIC_KEY`.

  Raises:
    ValueError: if `eval_metric_ops` already has a
      `model_fn_lib.LOSS_METRIC_KEY` entry.
  """
  # Work on a copy so the caller's dict (typically the one stored in the
  # `EstimatorSpec`) does not silently gain a 'loss' entry.
  eval_metric_ops = dict(eval_metric_ops or {})
  if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
    raise ValueError(
        'Metric with name "%s" is not allowed, because Estimator ' %
        (model_fn_lib.LOSS_METRIC_KEY) +
        'already defines a default metric with the same name.')

  if distribution is None:
    loss_metric = metrics_lib.mean(loss)
  else:
    loss_metric = distribution.call_for_each_tower(
        metrics_lib.mean, loss)
  eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
  return eval_metric_ops


def maybe_overwrite_model_dir_and_session_config(config, model_dir):
  """Overwrite estimator config by `model_dir` and `session_config` if needed.

  A default `RunConfig` and a default session config are filled in when they
  are missing. `model_dir` is then reconciled with `config.model_dir`, and a
  new temporary directory is used when neither one is set.

  Args:
    config: Original estimator config, or `None` for a default `RunConfig`.
    model_dir: Estimator model checkpoint directory, or `None`.

  Returns:
    Overwritten estimator config.

  Raises:
    ValueError: if `config` is not a `RunConfig`, or if `model_dir` and
      `config.model_dir` are both set but differ.
  """
  if config is None:
    config = run_config.RunConfig()
    logging.info('Using default config.')
  if not isinstance(config, run_config.RunConfig):
    raise ValueError(
        'config must be an instance of `RunConfig`, but provided %s.' % config)

  if config.session_config is None:
    config = run_config.RunConfig.replace(
        config, session_config=run_config.get_default_session_config())

  model_dir = compat_internal.path_to_str(model_dir)
  configured_dir = getattr(config, 'model_dir', None)
  if (model_dir is not None and configured_dir is not None and
      configured_dir != model_dir):
    raise ValueError(
        "`model_dir` are set both in constructor and `RunConfig`, but with "
        "different values. In constructor: '{}', in `RunConfig`: "
        "'{}' ".format(model_dir, configured_dir))

  if model_dir:
    return run_config.RunConfig.replace(config, model_dir=model_dir)
  if configured_dir is None:
    model_dir = tempfile.mkdtemp()
    logging.warning('Using temporary folder as model directory: %s', model_dir)
    config = run_config.RunConfig.replace(config, model_dir=model_dir)
  return config


def create_per_tower_ready_for_local_init_op(scaffold):
  """Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower.

  Args:
    scaffold: The tower's `tf.train.Scaffold`.

  Returns:
    The scaffold's own `ready_for_local_init_op` if it set one. Otherwise, the
    op stored in the `READY_FOR_LOCAL_INIT_OP` collection, falling back to a
    report of uninitialized global variables.
  """
  existing_op = scaffold.ready_for_local_init_op
  if existing_op:
    return existing_op
  return monitored_session.Scaffold.get_or_default(
      'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
      lambda: variables.report_uninitialized_variables(
          variables.global_variables()))


def _combine_distributed_scaffold(grouped_scaffold, distribution):
  """Combines scaffold(s) returned from `distribution.call_for_each_tower`.

  For `init_op`, `local_init_op`, `summary_op`, `init_feed_dict` and
  `init_fn`, the values from towers that set them are merged with
  `distribution.group`. `ready_op` and `ready_for_local_init_op` are obtained
  per tower via `call_for_each_tower` and concatenated along axis 0 when more
  than one value comes back.

  Args:
    grouped_scaffold: The grouped `tf.train.Scaffold` returned by
      `call_for_each_tower`.
    distribution: The distribution strategy that produced `grouped_scaffold`.

  Returns:
    A single `tf.train.Scaffold`.
  """
  # TODO(anjalisridhar): Figure out how to resolve the following scaffold
  # parameters: init_feed_dict, init_fn.
  per_tower_scaffolds = distribution.unwrap(grouped_scaffold)

  def _group_set_fields(field_name):
    """Groups `field_name` over the towers that set it; None if none did."""
    collected = [
        getattr(tower_scaffold, field_name)
        for tower_scaffold in per_tower_scaffolds
        if getattr(tower_scaffold, field_name) is not None
    ]
    return distribution.group(collected) if collected else None

  def _unwrap_and_concat(value):
    """Concatenates per-tower tensors, or returns a lone tensor unchanged."""
    flat = nest.flatten(distribution.unwrap(value))
    return flat[0] if len(flat) == 1 else array_ops.concat(flat, 0)

  init_feed_dict = _group_set_fields('init_feed_dict')
  init_fn = _group_set_fields('init_fn')
  init_op = _group_set_fields('init_op')

  ready_op = distribution.call_for_each_tower(
      lambda scaffold: scaffold.ready_op, grouped_scaffold)
  if ready_op is not None:
    ready_op = _unwrap_and_concat(ready_op)

  ready_for_local_init_op = distribution.call_for_each_tower(
      create_per_tower_ready_for_local_init_op, grouped_scaffold)
  if ready_for_local_init_op is not None:
    ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op)

  local_init_op = _group_set_fields('local_init_op')
  summary_op = _group_set_fields('summary_op')

  return monitored_session.Scaffold(
      init_op=init_op,
      ready_op=ready_op,
      ready_for_local_init_op=ready_for_local_init_op,
      local_init_op=local_init_op,
      summary_op=summary_op,
      init_feed_dict=init_feed_dict,
      init_fn=init_fn)


def _check_checkpoint_available(model_dir):
  """Raises `ValueError` unless `model_dir` holds at least one checkpoint."""
  if checkpoint_management.latest_checkpoint(model_dir):
    return
  raise ValueError(
      'Could not find trained model in model_dir: {}.'.format(model_dir))


def _check_hooks_type(hooks):
  """Returns hooks if all are `SessionRunHook`, raises TypeError otherwise."""
  hooks = list(hooks or [])
  for h in hooks:
    if not isinstance(h, training.SessionRunHook):
      raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
  return hooks


def _check_listeners_type(saving_listeners):
  """Check listeners type."""
  listeners = list(saving_listeners or [])
  for l in listeners:
    if not isinstance(l, training.CheckpointSaverListener):
      raise TypeError(
          'saving_listeners must be a list of CheckpointSaverListener, '
          'given: {}'.format(l))
  return listeners


def _get_replica_device_setter(config):
  """Creates a replica device setter if required as a default `device_fn`.

  `Estimator` uses `tf.train.ReplicaDeviceSetter` as a default device placer. It
  sets the
  distributed related arguments such as number of `ps_replicas` based on given
  `config`.

  Args:
    config: A `tf.estimator.RunConfig` instance.

  Returns:
    A replica device setter, or `None`.
  """
  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return training.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=list(device_setter.STANDARD_PS_OPS),
        cluster=config.cluster_spec)
  else:
    return None


def _verify_model_fn_args(model_fn, params):
  """Verifies `model_fn` arguments.

  Args:
    model_fn: The user-supplied model function.
    params: The `params` passed to the Estimator, or `None`.

  Raises:
    ValueError: if `model_fn` lacks a `features` argument, lacks `params`
      while params were supplied, or has arguments outside
      `_VALID_MODEL_FN_ARGS`.
  """
  accepted = set(function_utils.fn_args(model_fn))
  if 'features' not in accepted:
    raise ValueError('model_fn (%s) must include features argument.' % model_fn)

  takes_params = 'params' in accepted
  if params is not None and not takes_params:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' % (model_fn,
                                                                  params))
  if params is None and takes_params:
    logging.warning("Estimator's model_fn (%s) includes params "
                    'argument, but params are not passed to Estimator.',
                    model_fn)

  unexpected = list(accepted - _VALID_MODEL_FN_ARGS)
  if unexpected:
    raise ValueError('model_fn (%s) has following not expected args: %s' %
                     (model_fn, unexpected))


def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  """Reads the global step stored in the latest checkpoint of a directory.

  This is best-effort: any failure to locate or read the checkpoint yields 0.

  Args:
    checkpoint_dir: Directory to search for the latest checkpoint.

  Returns:
    The saved global step value, or 0 if no checkpoint exists or it cannot be
    read.
  """
  try:
    checkpoint_path = training.latest_checkpoint(checkpoint_dir)
    if not checkpoint_path:
      # No checkpoint yet: training has not started.
      return 0
    checkpoint_reader = training.NewCheckpointReader(checkpoint_path)
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  # Deliberately broad (best-effort read), but unlike a bare `except:` this
  # lets KeyboardInterrupt/SystemExit propagate.
  except Exception:  # pylint: disable=broad-except
    return 0


def _extract_metric_update_ops(eval_dict, distribution=None):
  """Separate update operations from metric value operations.

  Args:
    eval_dict: Dict mapping metric name to a `(value_op, update_op)` pair.
    distribution: Optional strategy. When truthy, each update op is grouped
      with `distribution.group`.

  Returns:
    `(update_op, value_ops)`: a single op grouping every update op (`None`
    when `eval_dict` is empty) and a dict of metric name to value op.
  """
  value_ops = {}
  update_ops = []
  # Sort metrics lexicographically so graph is identical every time.
  for metric_name, metric in sorted(six.iteritems(eval_dict)):
    value_ops[metric_name] = metric[0]
    raw_update = metric[1]
    update_ops.append(
        distribution.group(raw_update) if distribution else raw_update)

  if not update_ops:
    return None, value_ops
  return control_flow_ops.group(*update_ops), value_ops


def _dict_to_str(dictionary):
  """Get a `str` representation of a `dict`.

  Entries are rendered as `key = value` in sorted key order, and entries with
  binary-string values are omitted.

  Args:
    dictionary: The `dict` to be represented as `str`.

  Returns:
    A `str` representing the `dictionary`.
  """
  parts = []
  for key, value in sorted(six.iteritems(dictionary)):
    if isinstance(value, six.binary_type):
      continue
    parts.append('%s = %s' % (key, value))
  return ', '.join(parts)


def _write_dict_to_summary(output_dir,
                           dictionary,
                           current_global_step):
  """Writes a `dict` into summary file in given output directory.

  Floats and ints (including `np.float32`, `np.int32` and `np.int64`) become
  `simple_value` summaries. Binary strings are parsed as serialized `Summary`
  protos, and `np.ndarray` values are written as tensor summaries. `None`
  values and the `global_step` entry are skipped; any other type is skipped
  with a warning.

  Args:
    output_dir: `str`, directory to write the summary file in.
    dictionary: the `dict` to be written to summary file.
    current_global_step: `int`, the current global step.
  """
  logging.info('Saving dict for global step %d: %s', current_global_step,
               _dict_to_str(dictionary))
  summary_writer = writer_cache.FileWriterCache.get(output_dir)
  summary_proto = summary_pb2.Summary()
  for key in dictionary:
    value = dictionary[key]
    if value is None or key == 'global_step':
      continue
    if isinstance(value, (np.float32, float)):
      summary_proto.value.add(tag=key, simple_value=float(value))
    elif isinstance(value, (np.int64, np.int32, int)):
      summary_proto.value.add(tag=key, simple_value=int(value))
    elif isinstance(value, six.binary_type):
      try:
        summ = summary_pb2.Summary.FromString(value)
        # Namespace each embedded summary value under this key.
        for i, _ in enumerate(summ.value):
          summ.value[i].tag = '%s/%d' % (key, i)
        summary_proto.value.extend(summ.value)
      except message.DecodeError:
        logging.warning(
            'Skipping summary for %s, cannot parse string to Summary.', key)
        continue
    elif isinstance(value, np.ndarray):
      tensor_value = summary_proto.value.add()
      tensor_value.tag = key
      tensor_value.node_name = key
      tensor_proto = tensor_util.make_tensor_proto(value)
      tensor_value.tensor.CopyFrom(tensor_proto)
      # pylint: disable=line-too-long
      logging.info(
          'Summary for np.ndarray is not visible in Tensorboard by default. '
          'Consider using a Tensorboard plugin for visualization (see '
          'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
          ' for more information).')
      # pylint: enable=line-too-long
    else:
      logging.warning(
          'Skipping summary for %s, must be a float, np.float32, np.int64, '
          'np.int32 or int or np.ndarray or a serialized string of Summary.',
          key)
  summary_writer.add_summary(summary_proto, current_global_step)
  summary_writer.flush()


def _write_checkpoint_path_to_summary(output_dir, checkpoint_path,
                                      current_global_step):
  """Writes `checkpoint_path` into summary file in the given output directory.

  The path is stored as a string tensor summary tagged `checkpoint_path`.

  Args:
    output_dir: `str`, directory to write the summary file in.
    checkpoint_path: `str`, checkpoint file path to be written to summary file.
    current_global_step: `int`, the current global step.
  """
  tag = 'checkpoint_path'
  logging.info("Saving '%s' summary for global step %d: %s",
               tag, current_global_step, checkpoint_path)

  path_tensor = tensor_util.make_tensor_proto(
      checkpoint_path, dtype=dtypes.string)
  proto = summary_pb2.Summary()
  proto.value.add(tag=tag, tensor=path_tensor)

  writer = writer_cache.FileWriterCache.get(output_dir)
  writer.add_summary(proto, current_global_step)
  writer.flush()


def _has_dataset_or_queue_runner(maybe_tensor):
  """Returns `True` if `Dataset` or `QueueRunner` has been used.

  Args:
    maybe_tensor: A possibly nested structure that may contain `Tensor`s.

  Returns:
    `True` if any flattened `Tensor` in `maybe_tensor` is produced by an
    `IteratorGetNext` op, or if the default graph has any queue runners
    registered. `False` otherwise.
  """
  # Check TF dataset first. Here, we use a simple algorithm to check the top
  # level Tensors only, which should be sufficient for most users.
  tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)]
  if any(t.op.type == 'IteratorGetNext' for t in tensors):
    return True

  # Now, check queue. `get_collection` returns a list, so coerce it to a bool
  # to honour the documented return type.
  return bool(
      ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS))


# Re-export `warm_starting_util.VocabInfo` as `tf.estimator.VocabInfo` so that
# warm-start vocab remapping can be configured through the estimator API.
VocabInfo = warm_starting_util.VocabInfo  # pylint: disable=invalid-name
estimator_export('estimator.VocabInfo')(VocabInfo)


@estimator_export('estimator.WarmStartSettings')
class WarmStartSettings(
    collections.namedtuple('WarmStartSettings', [
        'ckpt_to_initialize_from',
        'vars_to_warm_start',
        'var_name_to_vocab_info',
        'var_name_to_prev_var_name',
    ])):
  """Settings for warm-starting in `tf.estimator.Estimators`.

  Example use with the canned `tf.estimator.DNNClassifier`:

  ```
  sc_vocab_file = tf.feature_column.categorical_column_with_vocabulary_file(
      "sc_vocab_file", "new_vocab.txt", vocab_size=100)
  sc_vocab_list = tf.feature_column.categorical_column_with_vocabulary_list(
      "sc_vocab_list", vocabulary_list=["a", "b"])
  emb_vocab_file = tf.feature_column.embedding_column(
      sc_vocab_file, dimension=8)
  emb_vocab_list = tf.feature_column.embedding_column(
      sc_vocab_list, dimension=8)
  estimator = tf.estimator.DNNClassifier(
    hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
    warm_start_from=ws)
  ```

  where `ws` could be defined as:

  Warm-start all weights in the model (input layer and hidden weights).
  Either the directory or a specific checkpoint can be provided (in the case
  of the former, the latest checkpoint will be used):

  ```
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
  ```

  Warm-start only the embeddings (input layer):

  ```
  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
                         vars_to_warm_start=".*input_layer.*")
  ```

  Warm-start all weights, where the embedding parameters corresponding to
  `sc_vocab_file` have a different vocab from the one used in the current
  model:

  ```
  vocab_info = tf.estimator.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt"
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start only the `sc_vocab_file` embeddings (and no other variables),
  which have a different vocab from the one used in the current model:

  ```
  vocab_info = tf.estimator.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt"
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      vars_to_warm_start=None,
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start all weights, where the parameters corresponding to
  `sc_vocab_file` have a different vocab from the one used in the checkpoint,
  and only 100 entries of that old vocab were used:

  ```
  vocab_info = tf.estimator.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt",
      old_vocab_size=100
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      })
  ```

  Warm-start all weights, where the parameters corresponding to
  `sc_vocab_file` have a different vocab from the one used in the checkpoint,
  and the parameters corresponding to `sc_vocab_list` have a different name
  in the checkpoint than in the current model:

  ```
  vocab_info = tf.estimator.VocabInfo(
      new_vocab=sc_vocab_file.vocabulary_file,
      new_vocab_size=sc_vocab_file.vocabulary_size,
      num_oov_buckets=sc_vocab_file.num_oov_buckets,
      old_vocab="old_vocab.txt",
      old_vocab_size=100
  )
  ws = WarmStartSettings(
      ckpt_to_initialize_from="/tmp",
      var_name_to_vocab_info={
          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
      },
      var_name_to_prev_var_name={
          "input_layer/sc_vocab_list_embedding/embedding_weights":
              "old_tensor_name"
      })
  ```

  Attributes:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see `tf.get_collection`). This expression will only
        consider variables in the `TRAINABLE_VARIABLES` collection.
      - A list of Variables to warm-start.
      - A list of strings, each representing a full variable name to
        warm-start.
      - `None`, in which case only variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      `TRAINABLE_VARIABLES` collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions. If not explicitly provided, the
      variable is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be the
      same between the previous checkpoint and the current model.
  """

  def __new__(cls,
              ckpt_to_initialize_from,
              vars_to_warm_start='.*',
              var_name_to_vocab_info=None,
              var_name_to_prev_var_name=None):
    # Any falsy value (None, '', b'') is rejected: there is nothing to
    # warm-start from.
    if not ckpt_to_initialize_from:
      raise ValueError(
          '`ckpt_to_initialize_from` MUST be set in WarmStartSettings')
    # Missing (or empty) mappings are normalized to fresh empty dicts so
    # consumers can always treat these fields as dicts.
    return super(WarmStartSettings, cls).__new__(
        cls,
        ckpt_to_initialize_from,
        vars_to_warm_start,
        var_name_to_vocab_info or {},
        var_name_to_prev_var_name or {},
    )


def _get_saved_model_ckpt(saved_model_dir):
  """Return path to variables checkpoint in a `SavedModel` directory.

  Args:
    saved_model_dir: path to a `SavedModel` export directory.

  Returns:
    The value of `saved_model_utils.get_variables_path(saved_model_dir)`.

  Raises:
    ValueError: if no `variables.index` file exists in the directory returned
      by `saved_model_utils.get_variables_dir(saved_model_dir)`.
  """
  variables_dir = saved_model_utils.get_variables_dir(saved_model_dir)
  index_file = os.path.join(variables_dir, compat.as_text('variables.index'))
  if gfile.Exists(index_file):
    return saved_model_utils.get_variables_path(saved_model_dir)
  raise ValueError('Directory provided has an invalid SavedModel format: %s'
                   % saved_model_dir)


def _get_default_warm_start_settings(warm_start_from):
  """Returns default `tf.estimator.WarmStartSettings`.

  Args:
    warm_start_from: Either a string (`str` or `bytes`) representing the
      filepath of a checkpoint or `SavedModel` to initialize from, or an
      instance of `tf.estimator.WarmStartSettings`.

  Returns:
    Either None or an instance of `WarmStartSettings`. A `WarmStartSettings`
    instance is returned unchanged. A path is wrapped in a new
    `WarmStartSettings`; if it looks like a `SavedModel`, the settings point
    at its variables checkpoint instead.

  Raises:
    ValueError: If `warm_start_from` is not `None` but is neither a string nor
      an instance of `WarmStartSettings`.
  """
  if warm_start_from is None:
    return None
  if isinstance(warm_start_from, WarmStartSettings):
    return warm_start_from
  if not isinstance(warm_start_from, (six.string_types, six.binary_type)):
    raise ValueError('warm_start_from must be a string or a WarmStartSettings, '
                     'instead got {}'.format(type(warm_start_from)))

  # Infer that this is a SavedModel if export_path +
  # 'variables/variables.index' exists, and if so, construct the
  # WarmStartSettings pointing to the variables path
  # (export_path + 'variables/variables').
  if gfile.Exists(os.path.join(
      saved_model_utils.get_variables_dir(warm_start_from),
      compat.as_text('variables.index'))):
    logging.info('Warm-starting from a SavedModel')
    return WarmStartSettings(
        ckpt_to_initialize_from=saved_model_utils.get_variables_path(
            warm_start_from))
  # Otherwise treat the path as a checkpoint file or checkpoint directory.
  return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)