aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
blob: 7c7c97638e4fbeb777059ca5eac6cd093f785d78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084
3085
3086
3087
3088
3089
3090
3091
3092
3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPUEstimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import signal
import sys
import threading
import time

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import error_handling
from tensorflow.contrib.tpu.python.tpu import session_support
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_context
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.contrib.training.python.training import hparam
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


# Module-level constants. Most of them are consumed by code further down in
# this file; only the uses visible near their definition are described here.

# Loss placeholder values (presumably used as initial/neutral loss values by
# the training and eval loops -- confirm at the use sites).
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
# `_TPU_ESTIMATOR` is the variable scope and `_ITERATIONS_PER_LOOP_VAR` is the
# variable name used by `_create_or_get_iterations_per_loop`. Joined with '_'
# they also form the name of the graph collection that holds that variable.
_TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
# String keys; `_BATCH_SIZE_KEY` and `_CTX_KEY` are the reserved ones listed in
# `_RESERVED_PARAMS_KEYS` below (presumably keys of the `params` dict handed to
# user functions -- confirm at the use sites).
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
# Number of bytes in one gibibyte (1024 ** 3).
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'

# Ideally _USE_TPU_KEY should be reserved as well. However, there are already
# models that make use of this key, so it cannot be reserved now without
# breaking them. In the long run, we would like to mitigate this by migrating
# models off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]


# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
# only used for per-core based deployments. For per-host based pipelines, if a
# user returns a Dataset instance it will be automatically wrapped in a
# tf.while_loop (this can be disabled by returning features and labels
# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False


# Register the to/from-proto functions for the collection named
# 'tpu_estimator_iterations_per_loop', i.e. the collection into which
# `_create_or_get_iterations_per_loop` places its resource variable. The
# functions are the ones `resource_variable_ops` provides for `VariableDef`.
ops.register_proto_function(
    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    proto_type=variable_pb2.VariableDef,
    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access


def _create_global_step(graph):
  """Creates the global step resource variable in `graph`.

  Args:
    graph: The graph to create the variable in. When falsy (e.g. `None`), the
      current default graph is used instead.

  Returns:
    A scalar, non-trainable `int64` resource variable initialized to zero and
    added to the `GLOBAL_VARIABLES` and `GLOBAL_STEP` collections.

  Raises:
    ValueError: If the graph already has a global step.
  """
  target_graph = graph if graph else ops.get_default_graph()
  existing_step = training.get_global_step(target_graph)
  if existing_step is not None:
    raise ValueError('"global_step" already exists.')

  step_collections = [
      ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
  ]
  # Build the variable inside the requested graph, with the name scope reset
  # to the base (root) scope.
  with target_graph.as_default() as g:
    with g.name_scope(None):
      global_step = variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          use_resource=True,
          collections=step_collections)
  return global_step


def _create_or_get_iterations_per_loop():
  """Returns the iterations_per_loop variable, creating it if necessary.

  TPUEstimator wraps the user provided computation (the model_fn) inside a
  tf.while_loop for peak performance. This variable holds the number of
  iterations of that loop; its value is adjusted on the CPU after each TPU
  program execution and before the next one.

  A variable is used, rather than a constant, so that TPUEstimator can adapt
  the number of TPU training iterations to the final number of steps requested
  by the user. For example, with iterations_per_loop set to 4 in TPUConfig and
  steps set to 10 in TPUEstimator.train(), the variable takes the following
  values before each TPU training execution:

      - 1st TPU execution: iterations_per_loop = 4
      - 2nd TPU execution: iterations_per_loop = 4
      - 3rd TPU execution: iterations_per_loop = 2

  Because model_fn increases the global step once per train_op invocation, the
  global step is 10 after all TPU executions, matching the steps=10 passed in
  by the user.

  The variable is looked up in (and, when created, added to) a dedicated graph
  collection of the default graph; it is also added to the local variables
  collection and colocated with the global step.

  Returns:
    A TF non-trainable resource variable (scalar `int32`).

  Raises:
    RuntimeError: If more than one iterations_per_loop variable is found in
      the collection.
  """
  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
  existing = ops.get_default_graph().get_collection(collection_name)
  if len(existing) > 1:
    raise RuntimeError('Multiple iterations_per_loop_var in collection.')
  if existing:
    # Exactly one variable was registered earlier; reuse it.
    return existing[0]

  global_step = training_util.get_global_step()
  with ops.colocate_with(global_step), variable_scope.variable_scope(
      _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
    iterations_var = variable_scope.get_variable(
        _ITERATIONS_PER_LOOP_VAR,
        initializer=init_ops.zeros_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)
  return iterations_var


def _sync_variables_ops():
  """Builds ops that read back every trainable variable with a numerics check.

  Reading the variables gets them back from the TPU nodes, which means the
  variables updated by the TPU are *synced* to host memory.

  Returns:
    A list with one `check_numerics` operation per trainable variable, each
    wrapping that variable's `read_value()`.
  """
  sync_ops = []
  for var in variables.trainable_variables():
    checked = array_ops.check_numerics(
        var.read_value(), 'Gradient for %s is NaN' % var.name)
    sync_ops.append(checked.op)
  return sync_ops


def _increase_eval_step_op(iterations_per_loop):
  """Returns an op that advances the eval step for TPU evaluation.

  Args:
    iterations_per_loop: Tensor. The number of eval steps that run in the TPU
        system before control returns to the CPU host for each `Session.run`.

  Returns:
    An operation that adds `iterations_per_loop - 1` to the eval step.
  """
  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
  # Estimator evaluate already adds 1 per run by default, so only the
  # remaining difference is added here.
  extra_steps = math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype)
  return state_ops.assign_add(eval_step, extra_steps, use_locking=True)


class _SIGNAL(object):
  """Signal used to control the thread of infeed/outfeed.

  All preserved signals must be negative numbers. Positive numbers are used to
  indicate the number of iterations for next training/evaluation loop.
  """
  NEXT_BATCH = -1
  STOP = -2


class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
  """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.

  See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
  `export_outputs`.

  For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
  `metric_fn` runs on CPU to generate metrics and `tensors` represents the
  `Tensor`s transferred from the TPU system to the CPU host and passed to
  `metric_fn`. To be precise, TPU evaluation expects a slightly different
  signature from the @{tf.estimator.Estimator}: while
  `EstimatorSpec.eval_metric_ops` expects a dict,
  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
  The `tensors` can be a list of `Tensor`s or a dict of names to `Tensor`s. The
  `tensors` usually specify the model logits, which are transferred back from
  the TPU system to the CPU host. All tensors must be batch-major, i.e., the
  batch size is the first dimension. Once all tensors are available at the CPU
  host from all shards, they are concatenated (on CPU) and passed as positional
  arguments to the `metric_fn` if `tensors` is a list, or as keyword arguments
  if `tensors` is a dict. `metric_fn` takes the `tensors` and returns a dict
  from metric string name to the result of calling a metric function, namely a
  `(metric_tensor, update_op)` tuple. See `TPUEstimator` for an MNIST example
  of how to specify the `eval_metrics`.

  `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
  function should not capture any Tensors in `model_fn`.

  `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
  to pass to that function, which returns a list of Tensors. `host_call`
  currently works for train() and evaluate(). The Tensors returned by the
  function are executed on the CPU on every step, so there is communication
  overhead when sending tensors from TPU to CPU. To reduce the overhead, try
  reducing the size of the tensors. The `tensors` are concatenated along their
  major (batch) dimension, and so must be >= rank 1. The `host_call` is useful
  for writing summaries with @{tf.contrib.summary.create_file_writer}.
  """

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    # Only the host calls that were actually supplied are validated.
    candidates = (('eval_metrics', eval_metrics), ('host_call', host_call))
    host_calls = {
        key: value for key, value in candidates if value is not None
    }
    _OutfeedHostCall.validate(host_calls)

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call)

  def as_estimator_spec(self):
    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    has_eval_metrics = self.eval_metrics is not None
    has_host_call = self.host_call is not None

    host_calls = {}
    if has_eval_metrics:
      host_calls['eval_metrics'] = self.eval_metrics
    if has_host_call:
      host_calls['host_call'] = self.host_call
    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)

    eval_metric_ops = host_call_ret['eval_metrics'] if has_eval_metrics else None
    # The same hook list (or None) is attached to training, evaluation and
    # prediction.
    hooks = (
        [_OutfeedHostCallHook(host_call_ret['host_call'])]
        if has_host_call else None)
    scaffold = self.scaffold_fn() if self.scaffold_fn else None

    return model_fn_lib.EstimatorSpec(
        mode=self.mode,
        predictions=self.predictions,
        loss=self.loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=self.export_outputs,
        scaffold=scaffold,
        training_hooks=hooks,
        evaluation_hooks=hooks,
        prediction_hooks=hooks)


class _OpQueueContext(object):
  """Owns the work queue and the daemon thread of an infeed/outfeed worker.

  The thread is started on construction and runs `target(self, *args)`. The
  owner pushes iteration counts into the queue; the thread consumes them via
  `read_iteration_counts` until the stop signal arrives.
  """

  def __init__(self, name, target, args):
    """Creates the queue and immediately starts the worker thread.

    Args:
      name: Name given to the thread and used in log messages.
      target: Callable run by the thread; it receives this context as its
        first positional argument, followed by `args`.
      args: Tuple of additional positional arguments for `target`.
    """
    self._name = name
    self._queue = Queue.Queue()
    worker = threading.Thread(name=name, target=target, args=(self,) + args)
    worker.daemon = True
    self._thread = worker
    worker.start()

  def stop(self):
    """Enqueues the stop signal for the worker thread."""
    self._queue.put(_SIGNAL.STOP)

  def send_next_batch_signal(self, iterations):
    """Enqueues the number of iterations the worker should process next."""
    self._queue.put(iterations)

  def read_iteration_counts(self):
    """Yields queued iteration counts until the stop signal is read.

    Blocks on the queue while it is empty.
    """
    while True:
      count = self._queue.get(block=True)
      logging.debug('%s read iterations %s', self._name, count)
      if count == _SIGNAL.STOP:
        break
      yield count
    logging.info('%s received shutdown signal, stopping.', self._name)

  def join(self):
    """Sends the stop signal and waits for the worker thread to exit."""
    logging.info('Shutting down %s thread.', self._name)
    self.stop()
    self._thread.join()


class _OpSignalOnceQueueContext(_OpQueueContext):
  """Work queue and thread manager for an infeed/outfeed thread.

  Unlike the base class, this variant forwards only the first next-batch
  signal to the worker; every later signal is dropped.
  """

  def __init__(self, name, target, args):
    """Starts the worker thread (via the base class) with no signal sent yet."""
    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    self._has_signaled = False

  def send_next_batch_signal(self, iterations):
    """Enqueues `iterations` on the first call only; later calls are no-ops."""
    if self._has_signaled:
      return
    self._queue.put(iterations)
    self._has_signaled = True


class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization, infeed, and outfeed.

  This hook does two major things:
  1. initialize and shutdown TPU system.
  2. launch and join the threads for infeed enqueue and (optional) outfeed
     dequeue.
  """

  def __init__(self,
               ctx,
               enqueue_ops,
               dequeue_ops,
               run_infeed_loop_on_coordinator=True,
               rendezvous=None):
    """Stores the ops and configuration used by the feed threads.

    Args:
      ctx: Context object; `ctx.master_job` and
        `ctx.config.tpu_config.initial_infeed_sleep_secs` are read from it.
      enqueue_ops: Ops run by the infeed thread.
      dequeue_ops: Ops run by the outfeed thread.
      run_infeed_loop_on_coordinator: If True, the infeed thread runs
        `enqueue_ops` once per iteration of every signaled count. Otherwise it
        runs them once per signal.
      rendezvous: Object providing `catch_errors(source=..., session=...)` and
        `record_done(name)`. NOTE(review): the feed threads and `end` call it
        unconditionally, so the default of `None` is not usable as-is.
    """
    self._master_job = ctx.master_job
    self._enqueue_ops = enqueue_ops
    self._dequeue_ops = dequeue_ops
    self._rendezvous = rendezvous

    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    self._initial_infeed_sleep_secs = (
        ctx.config.tpu_config.initial_infeed_sleep_secs)

    self._feed_error = None
    self._finished = False

  def begin(self):
    """Builds the TPU init/shutdown ops and the summary-writer init/flush ops."""
    logging.info('TPU job name %s', self._master_job)
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    self._init_ops = [tpu.initialize_system(job=self._master_job)]
    self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]

    summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    self._init_ops.extend(summary_writer_init_ops)
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    for op in summary_writer_init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def _run_infeed(self, queue_ctx, session):
    """Body of the infeed thread: runs the enqueue ops as signaled.

    Args:
      queue_ctx: The queue context that yields the iteration counts.
      session: The session used to run the enqueue ops.
    """
    logging.info('Starting infeed thread controller.')
    if self._initial_infeed_sleep_secs:
      # This hook has no `_name` attribute, so the thread is identified with a
      # literal label. Reading `self._name` here would raise AttributeError
      # outside of `catch_errors` and end the infeed thread before any data is
      # enqueued.
      logging.info('Infeed thread sleeping for %d seconds.',
                   self._initial_infeed_sleep_secs)
      time.sleep(self._initial_infeed_sleep_secs)
      logging.info('Infeed thread starting after sleep')

    with self._rendezvous.catch_errors(source='infeed', session=session):
      if self._run_infeed_loop_on_coordinator:
        # One enqueue run per iteration of every signaled count.
        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
          for i in xrange(steps):
            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
            session.run(self._enqueue_ops)
      else:
        # One enqueue run per signal, regardless of the signaled count.
        for _ in queue_ctx.read_iteration_counts():
          session.run(self._enqueue_ops)
      logging.info('Infeed thread finished, shutting down.')

  def _run_outfeed(self, queue_ctx, session):
    """Body of the outfeed thread: runs the dequeue ops once per iteration.

    Args:
      queue_ctx: The queue context that yields the iteration counts.
      session: The session used to run the dequeue ops.
    """
    logging.info('Starting outfeed thread controller.')
    with self._rendezvous.catch_errors(source='outfeed', session=session):
      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
        for i in xrange(steps):
          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
          session.run(self._dequeue_ops)
      logging.info('Outfeed thread finished, shutting down.')

  def _create_infeed_controller(self, name, target, args):
    """Creates the infeed queue context; subclasses may override the type."""
    return _OpQueueContext(name=name, target=target, args=args)

  def after_create_session(self, session, coord):
    """Initializes the TPU system and starts the infeed/outfeed threads."""
    logging.info('Init TPU system')
    # The init ops are run with a 5 minute timeout.
    session.run(self._init_ops,
                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))

    self._infeed_controller = self._create_infeed_controller(
        name='InfeedController', target=self._run_infeed, args=(session,))

    self._outfeed_controller = _OpQueueContext(
        name='OutfeedController', target=self._run_outfeed, args=(session,))

  def before_run(self, run_context):
    """Signals both feed threads with the current iterations_per_loop value."""
    self._feed_error = None

    iterations = run_context.session.run(self._iterations_per_loop_var)

    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    self._infeed_controller.send_next_batch_signal(iterations)

    logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
                 iterations)
    self._outfeed_controller.send_next_batch_signal(iterations)

  def end(self, session):
    """Stops both feed threads, then runs the finalize (shutdown/flush) ops."""
    self._finished = True
    logging.info('Stop infeed thread controller')
    self._infeed_controller.join()
    self._rendezvous.record_done('infeed')

    logging.info('Stop output thread controller')
    self._outfeed_controller.join()
    self._rendezvous.record_done('outfeed')

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_ops)


class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):
  """Infeed/outfeed session hook variant used for prediction.

  Differs from the base hook in two ways: it always passes
  `run_infeed_loop_on_coordinator=False` to the base constructor, and it
  creates an `_OpSignalOnceQueueContext` as the infeed controller.
  """

  def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None):
    """Initializes the hook.

    Args:
      ctx: Forwarded unchanged to the base class.
      enqueue_ops: Forwarded unchanged to the base class.
      dequeue_ops: Forwarded unchanged to the base class.
      rendezvous: Forwarded unchanged to the base class.
    """
    base = super(TPUInfeedOutfeedSessionHookForPrediction, self)
    base.__init__(
        ctx,
        enqueue_ops,
        dequeue_ops,
        run_infeed_loop_on_coordinator=False,
        rendezvous=rendezvous)

  def _create_infeed_controller(self, name, target, args):
    """Returns an `_OpSignalOnceQueueContext` as the infeed controller."""
    controller = _OpSignalOnceQueueContext(
        name=name, target=target, args=args)
    return controller


class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop once a target global step is reached.

  This hook is similar to the `session_run_hook._StopAfterNEvalsHook`, with
  the following differences for TPU training:

  1. It also maintains the iterations_per_loop variable, which
     `TPUInfeedOutfeedSessionHook` reads in `before_run` to decide how many
     infeed/outfeed iterations to perform. Because hook execution order is not
     guaranteed, the variable is written in `after_create_session` and
     `after_run` instead of `before_run`.

  2. A single training loop (session.run) may advance the global step several
     times on TPU. The global step tensor is therefore read again explicitly
     in `after_run`, so the latest value is used and a race condition is
     avoided.
  """

  def __init__(self, iterations, num_steps=None, last_step=None):
    """Initializes a `_TPUStopAtStepHook`.

    Exactly one of `num_steps` and `last_step` must be given.

    Args:
      iterations: The number of iterations to run optimizer per training loop.
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If neither or both of `num_steps` and `last_step` are given.
    """
    if num_steps is None:
      if last_step is None:
        raise ValueError('One of num_steps or last_step must be specified.')
    elif last_step is not None:
      raise ValueError('Only one of num_steps or last_step can be specified.')

    self._iterations = iterations
    self._num_steps = num_steps
    self._last_step = last_step

  def _next_iterations(self, global_step, last_step):
    """Returns the iteration count for the next loop.

    The result is the configured iterations per loop, capped by the number of
    steps remaining until `last_step`.
    """
    remaining_steps = last_step - global_step
    return min(remaining_steps, self._iterations)

  def begin(self):
    """Looks up the global step tensor and the iterations-per-loop variable.

    Raises:
      RuntimeError: If no global step tensor exists.
    """
    global_step_tensor = training_util.get_global_step()
    self._global_step_tensor = global_step_tensor
    if global_step_tensor is None:
      raise RuntimeError('Global step should be created.')

    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    """Resolves the last step and loads the first iteration count."""
    current_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      # Only `num_steps` was given; convert it to an absolute step.
      self._last_step = current_step + self._num_steps

    next_iterations = self._next_iterations(current_step, self._last_step)
    self._iterations_per_loop_var.load(next_iterations, session=session)

  def after_run(self, run_context, run_values):
    """Requests stop at the last step, else loads the next iteration count."""
    session = run_context.session
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # race condition.
    current_step = session.run(self._global_step_tensor)
    if current_step >= self._last_step:
      run_context.request_stop()
      return

    next_iterations = self._next_iterations(current_step, self._last_step)
    self._iterations_per_loop_var.load(next_iterations, session=session)


class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
  """Hook that loads a fixed step count into the iterations-per-loop variable.

  The value is written once, right after the session is created. This hook
  does not request stop itself.
  """

  def __init__(self, num_steps):
    """Initializes a `_SetEvalIterationsHook`.

    Args:
      num_steps: Number of steps to execute.
    """
    self._num_steps = num_steps

  def begin(self):
    """Creates or fetches the iterations-per-loop variable."""
    iterations_var = _create_or_get_iterations_per_loop()
    self._iterations_per_loop_var = iterations_var

  def after_create_session(self, session, coord):
    """Loads `num_steps` into the iterations-per-loop variable."""
    num_steps = self._num_steps
    self._iterations_per_loop_var.load(num_steps, session=session)


class _StoppingPredictHook(session_run_hook.SessionRunHook):
  """Hook that ends prediction when the stopping signal says so."""

  def __init__(self, scalar_stopping_signal):
    """Initializes the hook.

    Args:
      scalar_stopping_signal: Value fetched on every run and passed to
        `_StopSignals.should_stop` to decide whether to stop.
    """
    self._scalar_stopping_signal = scalar_stopping_signal

  def begin(self):
    """Creates or fetches the iterations-per-loop variable."""
    iterations_var = _create_or_get_iterations_per_loop()
    self._iterations_per_loop_var = iterations_var

  def after_create_session(self, session, coord):
    """Loads 1 into the iterations-per-loop variable."""
    # This is not necessary as we do not run infeed enqueue and outfeed dequeue
    # in side threads for prediction model. But it makes the
    # TPUInfeedOutfeedSessionHook prints nice message.
    single_iteration = 1
    self._iterations_per_loop_var.load(single_iteration, session=session)

  def before_run(self, run_context):
    """Asks the session to also fetch the stopping signal."""
    fetches = self._scalar_stopping_signal
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Raises `OutOfRangeError` when the fetched signal requests a stop."""
    _ = run_context
    if not _StopSignals.should_stop(run_values.results):
      return

    # NOTE(xiejw): In prediction, stopping signals are inserted for each
    # batch. And we append one more batch to signal the system it should stop.
    # The data flow might look like
    #
    #  batch   0: images, labels, stop = 0  (user provided)
    #  batch   1: images, labels, stop = 0  (user provided)
    #  ...
    #  batch  99: images, labels, stop = 0  (user provided)
    #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # where the final batch (id = 100) is appended by TPUEstimator, so we
    # should drop it before returning the predictions to user.
    # To achieve that, we throw the OutOfRangeError in after_run. Once
    # Monitored Session sees this error in SessionRunHook.after_run, the
    # "current" prediction, i.e., batch with id=100, will be discarded
    # immediately
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')


def generate_per_core_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, host_device, host_id):
  """Generates infeed enqueue ops for per-core input_fn on a single host.

  `input_fn` is invoked once per core of the host, each time with a fresh
  `TPUContext`. Every invocation must return `features` and `labels` rather
  than a `Dataset`.

  Args:
    ctx: Internal TPU context; supplies the cores-per-host count and the TPU
      ordinal function for `host_id`.
    input_fn: User input function, called with a `TPUContext`.
    inputs_structure_recorder: Recorder used to validate and flatten the
      features and labels.
    host_device: Device passed to `TPUContext` as `input_device`.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue)`. The captured object is
    only populated once `enqueue_ops_fn` has been called.
  """
  queue_holder = _CapturedObject()
  ordinal_fn = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn returns enqueue_ops."""
    cores_per_host = ctx.num_of_cores_per_host
    sharded_inputs = []
    for core_ordinal in range(cores_per_host):
      with ops.name_scope('ordinal_%d' % (core_ordinal)):
        # The invocation index is unique across all cores of all hosts.
        invocation_index = (
            host_id * ctx.num_of_cores_per_host + core_ordinal)
        core_context = tpu_context.TPUContext(
            internal_ctx=ctx,
            input_device=host_device,
            invocation_index=invocation_index)
        core_inputs = _Inputs.from_input_fn(input_fn(core_context))
        if core_inputs.is_dataset:
          raise TypeError(
              '`input_fn` returning `Dataset`  is not yet supported in '
              'per-Core input pipeline deployment yet. Please set '
              'TPUConfig.per_host_input_for_training to True or return '
              '`features` and `labels` from `input_fn`')
        features, labels = core_inputs.features_and_labels()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        sharded_inputs.append(
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))

    # All cores share one structure, so the first shard fixes the tuple size.
    num_tuple_elements = len(sharded_inputs[0])
    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=num_tuple_elements)
    queue_holder.capture(infeed_queue)

    return infeed_queue.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=ordinal_fn)

  return enqueue_ops_fn, queue_holder


def generate_per_host_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host.

  `input_fn` is invoked once for the host, under `device`. In PREDICT mode the
  returned dataset is wrapped in `_InputsWithStoppingSignals`.

  Args:
    ctx: Internal TPU context.
    input_fn: User input function, called with a `TPUContext`.
    inputs_structure_recorder: Recorder used to validate and flatten the
      features, labels and signals.
    batch_axis: Passed to the infeed queue as `shard_dimensions`. Must be
      `None` in PREDICT mode.
    device: Host device on which the input pipeline is placed.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.

  Raises:
    TypeError: In PREDICT mode, if `input_fn` does not return a `Dataset` or
      if `batch_axis` is not `None`.
  """
  queue_holder = _CapturedObject()

  session_hooks = []

  with ops.device(device):
    host_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    host_inputs = _Inputs.from_input_fn(input_fn(host_context))

    is_dataset = host_inputs.is_dataset
    predicting = ctx.mode == model_fn_lib.ModeKeys.PREDICT
    if predicting and not is_dataset:
      raise TypeError(
          'For mode PREDICT, `input_fn` must return `Dataset` instead of '
          '`features` and `labels`.')
    if predicting and batch_axis is not None:
      raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
    if predicting:
      host_inputs = _InputsWithStoppingSignals(
          dataset=host_inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      session_hooks.append(host_inputs.dataset_initializer_hook())

    ordinal_fn = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A Fn returning the TPU infeed enqueue ops.

    By providing as a Fn, it can be invoked inside the tf.while_loop such that
    the input pipeline for multiple iterations can be executed by one
    Session.run call.

    Returns:
      The enqueue ops when there are no signals; otherwise a dict with keys
      'ops' and 'signals'.
    """
    with ops.device(device):
      replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels.  If the user returns a
      # dataset, it is initialized and the features and labels extracted via
      # `dataset.iterator.get_next()`
      features, labels = host_inputs.features_and_labels()
      signals = host_inputs.signals()

      inputs_structure_recorder.validate_and_record_structure(
          features, labels, signals)
      tensors = inputs_structure_recorder.flatten_features_and_labels(
          features, labels, signals)

      infeed_queue = tpu_feed.InfeedQueue(
          tuple_types=[tensor.dtype for tensor in tensors],
          tuple_shapes=[tensor.shape for tensor in tensors],
          shard_dimensions=batch_axis)
      queue_holder.capture(infeed_queue)
      infeed_queue.set_number_of_shards(replicas_per_host)
      enqueue_ops = infeed_queue.split_inputs_and_generate_enqueue_ops(
          tensors,
          placement_function=lambda x: device,
          tpu_ordinal_function=ordinal_fn)

      if signals is None:
        return enqueue_ops
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, queue_holder, session_hooks, is_dataset


def generate_per_host_v2_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, device, host_id):
  """Generates infeed enqueue ops for the PER_HOST_V2 input pipeline.

  `input_fn` is invoked once for the host and must return a `Dataset`. Unlike
  the plain per-host pipeline, the batch is not split: the generated
  `enqueue_ops_fn` fetches the next element once per replica on the host and
  enqueues one tuple per replica.

  Args:
    ctx: Internal TPU context.
    input_fn: User input function, called with a `TPUContext`.
    inputs_structure_recorder: Recorder used to validate and flatten the
      features and labels.
    device: Host device on which the input pipeline is placed.
    host_id: Index of the host the ops are generated for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.

  Raises:
    TypeError: If `input_fn` does not return a `Dataset`, or if the mode is
      PREDICT.
  """
  captured_infeed_queue = _CapturedObject()
  hooks = []

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx,
        input_device=device,
        invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      # TODO(b/XXX): Add predict support for PER_HOST_V2
      raise TypeError('Mode PREDICT not yet supported in PER_HOST_V2 mode.')

    hooks.append(inputs.dataset_initializer_hook())
    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    num_replicas_per_host = ctx.num_of_replicas_per_host
    with ops.device(device):
      # Defensive re-check; the same condition is already rejected when the
      # enclosing function runs.
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          features, labels = inputs.features_and_labels()  # Calls get_next()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))

        # Each later fetch depends on all tensors fetched before it.
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)

    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
        per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
    return per_host_enqueue_ops

  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset


def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts.

  `input_fn` is invoked a single time, on the device of host 0. The features
  and labels it yields are enqueued for every replica on every host.

  Args:
    ctx: Internal TPU context.
    input_fn: User input function, called with a `TPUContext`.
    inputs_structure_recorder: Recorder used to validate and flatten the
      features and labels.
    num_hosts: Number of hosts to generate enqueue ops for.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset)`.

  Raises:
    TypeError: If the mode is PREDICT.
  """
  queue_holder = _CapturedObject()
  session_hooks = []
  first_host_device = ctx.tpu_host_placement_function(host_id=0)
  with ops.device(first_host_device):
    broadcast_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=first_host_device, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(broadcast_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.')

    if is_dataset:
      session_hooks.append(inputs.dataset_initializer_hook())
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def _ordinal_for_replica(replica_id):
    # Prefer the explicit device assignment when one is configured.
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    return replica_id % num_replicas_per_host

  def _device_for_replica(replica_id):
    return ctx.tpu_host_placement_function(replica_id=replica_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    all_replica_inputs = []
    cached_inputs = None  # Cache result from input_fn.
    for host_id in xrange(num_hosts):
      host_device = ctx.tpu_host_placement_function(host_id=host_id)
      with ops.device(host_device):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
          # broadcasted to other replicas(including the replicas on other
          # hosts).
          if cached_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            cached_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels))
          all_replica_inputs.append(cached_inputs)

    num_tuple_elements = len(all_replica_inputs[0])
    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=num_tuple_elements)
    queue_holder.capture(infeed_queue)
    return infeed_queue.generate_enqueue_ops(
        all_replica_inputs,
        tpu_ordinal_function=_ordinal_for_replica,
        placement_function=_device_for_replica)

  return enqueue_ops_fn, queue_holder, session_hooks, is_dataset


class _InputPipeline(object):
  """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.

  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
  call site.  To be precise, based on the configuration in
  `_InternalTPUContext`,  it invokes `input_fn` for all cores (usually
  multi-host TPU training) or for one host (usually for single-host TPU
  evaluation), and sends all `features` and `labels` returned by `input_fn` to
  TPU infeed. For per-core invocation, `features` and `labels` are piped to
  infeed directly, one tuple for each core. For per-host invocation,  `features`
  and `labels` are split at host (with respect to `batch_axis`) and piped to all
  cores accordingly.

  In addition, flatten/unflatten are handled by `_InputPipeline` also.  Model
  inputs returned by the `input_fn` can have one of the following forms:
  1. features
  2. (features, labels)

  Internally, form 1 is reformed to `(features, None)` as features and labels
  are passed separately to underlying methods. For TPU training, TPUEstimator
  may expect multiple `features` and `labels` tuples one for each core.

  TPUEstimator allows various different structures for inputs (namely `features`
  and `labels`).  `features` can be `Tensor` or dict of string name to `Tensor`,
  and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`.
  TPU infeed/outfeed library expects flattened tensor list. So, `features` and
  `labels` need to be flattened, before infeed enqueue, and the structure of
  them needs to be recorded, in order to restore them after infeed dequeue.
  """

  class InputsStructureRecorder(object):
    """The recorder to record inputs structure.

    The structure seen on the first call to `validate_and_record_structure` is
    stored; later calls are checked against it. The stored structure is then
    used to flatten inputs before enqueue and to restore them after dequeue.
    """

    def __init__(self):
      """Creates an empty recorder with no structure recorded yet."""
      # Holds the structure of inputs
      self._feature_names = []
      self._label_names = []
      self._has_labels = False
      self._signals_helper = None

      # Internal state.
      self._initialized = False

    def has_labels(self):
      """Returns whether the recorded structure includes labels."""
      return self._has_labels

    def validate_and_record_structure(self, features, labels, signals=None):
      """Validates and records the structure of `features` and `labels`.

      On the first call the structure is recorded. On later calls the
      structure is compared with the recorded one via `assert`.

      Args:
        features: A single value or a dict; for a dict, its sorted keys are
          recorded.
        labels: `None`, a single value or a dict; for a dict, its sorted keys
          are recorded.
        signals: Optional signals. The first non-`None` value seen is wrapped
          in a `_SignalsHelper`.
      """

      def _extract_key_names(tensor_or_dict):
        # Only dicts contribute key names; anything else yields an empty list.
        if tensor_or_dict is None:
          return []
        return sorted(tensor_or_dict.keys()) if isinstance(
            tensor_or_dict, dict) else []

      # Extract structure.
      has_labels = labels is not None
      feature_names = _extract_key_names(features)
      label_names = _extract_key_names(labels)

      if signals is not None and self._signals_helper is None:
        # Record signals helper.
        self._signals_helper = _SignalsHelper(signals)

      if self._initialized:
        # Verify the structure is same. The following should never happen.
        assert feature_names == self._feature_names, 'feature keys mismatched'
        assert label_names == self._label_names, 'label keys mismatched'
        assert has_labels == self._has_labels, 'label presence mismatched'
      else:
        # Record structure.
        self._initialized = True
        self._feature_names = feature_names
        self._label_names = label_names
        self._has_labels = has_labels

    def flatten_features_and_labels(self, features, labels, signals=None):
      """Flattens the `features` and `labels` to a single tensor list.

      The order is: features (by recorded key order, or the single value),
      then labels if not `None` (same rule), then the signals tensor list if
      `signals` is not `None`.

      Args:
        features: A single value, or a dict when feature names were recorded.
        labels: `None`, a single value, or a dict when label names were
          recorded.
        signals: Optional signals, flattened via
          `_SignalsHelper.as_tensor_list`.

      Returns:
        A list holding the flattened inputs.
      """
      flattened_inputs = []
      if self._feature_names:
        # We need a fixed ordering for enqueueing and dequeueing.
        flattened_inputs.extend(
            [features[name] for name in self._feature_names])
      else:
        flattened_inputs.append(features)

      if labels is not None:
        if self._label_names:
          # We need a fixed ordering for enqueueing and dequeueing.
          flattened_inputs.extend([labels[name] for name in self._label_names])
        else:
          flattened_inputs.append(labels)

      if signals is not None:
        flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals))
      return flattened_inputs

    def unflatten_features_and_labels(self, flattened_inputs):
      """Restores the flattened inputs to original features and labels form.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` instance built from the restored `features`, `labels`
        (which could be None) and `signals` (None when no signals were
        recorded). Features and labels, if present, have the same form (single
        tensor vs dict) as was recorded.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      # With no recorded names, features (and labels, when present) occupy
      # exactly one slot each.
      expected_num_features = (
          len(self._feature_names) if self._feature_names else 1)
      if self._has_labels:
        expected_num_labels = (
            len(self._label_names) if self._label_names else 1)
      else:
        expected_num_labels = 0

      expected_num_signals = (
          self._signals_helper.num_signals if self._signals_helper else 0)

      expected_num_tensors = (
          expected_num_features + expected_num_labels + expected_num_signals)

      if expected_num_tensors != len(flattened_inputs):
        raise ValueError(
            'The number of flattened tensors mismatches expected num. '
            'Expected {}, got {}'.format(expected_num_tensors,
                                         len(flattened_inputs)))
      if self._feature_names:
        unflattened_features = dict(
            zip(self._feature_names, flattened_inputs[:expected_num_features]))
      else:
        # Single tensor case
        unflattened_features = flattened_inputs[0]

      if expected_num_labels == 0:
        unflattened_label = None
      elif self._label_names:
        label_list = flattened_inputs[
            expected_num_features:expected_num_features + expected_num_labels]
        unflattened_label = dict(zip(self._label_names, label_list))
      else:
        # Single tensor case.
        unflattened_label = flattened_inputs[expected_num_features]

      # Whatever follows the features and labels belongs to the signals.
      signals = None
      if expected_num_signals != 0:
        tensor_list_for_signals = flattened_inputs[
            expected_num_features + expected_num_labels:]
        signals = self._signals_helper.unflatten(tensor_list_for_signals)

      return _Inputs(unflattened_features, unflattened_label, signals=signals)

  def __init__(self, input_fn, batch_axis, ctx):
    """Constructor.

    Args:
      input_fn: input fn for train or eval.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards.
      ctx: A `_InternalTPUContext` instance with mode.
    """
    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder()

    self._sharded_per_core = ctx.is_input_sharded_per_core()
    self._input_fn = input_fn
    # Set by `_invoke_input_fn_and_record_structure`.
    self._infeed_queue = None
    self._ctx = ctx
    self._batch_axis = batch_axis

  def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Generates infeed enqueue ops and dequeue_fn.

    Returns:
      A tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`.

    Raises:
      RuntimeError: If input pipeline validation fails (see
        `_validate_input_pipeline`).
    """
    # While tf.while_loop is called, the body function, which invokes
    # `enqueue_fn` passed in, is called to construct the graph. So, input_fn
    # structure is recorded.
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
        self._invoke_input_fn_and_record_structure())

    self._validate_input_pipeline()

    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
      # In the model-parallel case, both the host-side and device-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      # The unflatten process uses the structure information recorded above.
      return self._inputs_structure_recorder.unflatten_features_and_labels(
          values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)

  def _invoke_input_fn_and_record_structure(self):
    """Deploys the input pipeline and record input structure.

    Three deployments are handled: per-core, broadcast, and per-host (either
    the V2 iterator-based variant or the batch-splitting variant). As a side
    effect, `self._infeed_queue` is set to the first infeed queue created.

    Returns:
      A tuple `(enqueue_ops, all_hooks, run_infeed_loop_on_coordinator)`.
      The last element starts as True and becomes False whenever the enqueue
      computation is wrapped in a while loop.
    """
    enqueue_ops = []
    infeed_queues = []
    all_hooks = []
    num_hosts = self._ctx.num_hosts
    tpu_host_placement_fn = self._ctx.tpu_host_placement_function

    run_infeed_loop_on_coordinator = True

    if self._sharded_per_core:
      # Per-Core input pipeline deployment.
      # Invoke input pipeline for each core and placed on the corresponding
      # host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn, self._inputs_structure_recorder,
                    host_device, host_id))

            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
              run_infeed_loop_on_coordinator = False
              enqueue_ops.append(
                  _wrap_computation_in_while_loop(
                      device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
            infeed_queues.append(captured_infeed_queue.get())

    elif self._ctx.is_input_broadcast_with_iterators():
      # Only calls input_fn in host 0.
      host_device = tpu_host_placement_fn(host_id=0)
      enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                            self._inputs_structure_recorder,
                                            num_hosts))
      all_hooks.extend(hooks)
      if is_dataset:
        run_infeed_loop_on_coordinator = False
        enqueue_ops.append(
            _wrap_computation_in_while_loop(
                device=host_device, op_fn=enqueue_ops_fn))
      else:
        enqueue_ops.append(enqueue_ops_fn())
      infeed_queues.append(captured_infeed_queue.get())
    else:
      # Per-host deployment: one input_fn invocation per host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with ops.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_v2_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, host_device, host_id))
            else:
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
                  generate_per_host_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, self._batch_axis,
                      host_device, host_id))
            all_hooks.extend(hooks)

            # NOTE(xiejw): We dispatch here based on the return type of the
            # users `input_fn`.
            #
            # 1. If input_fn returns a Dataset instance, we initialize the
            # iterator outside of tf.while_loop, and call the iterator.get_next
            # inside tf.while_loop.  This should be always safe.
            #
            # 2. If input_fn returns (features, labels), it is too late to wrap
            # them inside tf.while_loop, as resource initialization cannot be
            # handled in TF control flow properly. In this case, we will use
            # python loop to enqueue the data into TPU system.  This may be
            # slow compared to the previous case.
            if is_dataset:
              run_infeed_loop_on_coordinator = False
              # PREDICT mode uses the while-loop wrapper that understands
              # stopping signals; every other mode uses the plain wrapper.
              wrap_fn = (
                  _wrap_computation_in_while_loop
                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
                  _wrap_computation_in_while_loop_with_stopping_signals)
              enqueue_ops.append(
                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            infeed_queues.append(captured_infeed_queue.get())
    # infeed_queue is used to generate dequeue ops. The only thing it uses for
    # dequeue is dtypes and types. So, any one can be used. Here, grab the
    # first one.
    self._infeed_queue = infeed_queues[0]
    return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator

  def _validate_input_pipeline(self):
    """Validates the input pipeline.

    Perform some sanity checks to log user friendly information. We should
    error out to give users better error message. But, if
    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
    user code, so, log a warning.

    Raises:
      RuntimeError: If the validation failed.
    """
    # A non-empty QUEUE_RUNNERS collection in the default graph means the
    # input pipeline registered at least one QueueRunner.
    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      err_msg = ('Input pipeline contains one or more QueueRunners. '
                 'It could be slow and not scalable. Please consider '
                 'converting your input pipeline to use `tf.data` instead (see '
                 'https://www.tensorflow.org/guide/datasets for '
                 'instructions.')
      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
        raise RuntimeError(err_msg)
      else:
        logging.warn(err_msg)


class _ModelFnWrapper(object):
  """A `model_fn` wrapper.

  This makes calling model_fn on CPU and TPU easier and more consistent and
  performs necessary check and mutation required by TPU training and evaluation.

  In addition, this wrapper manages converting the `model_fn` to a single TPU
  train and eval step.
  """

  def __init__(self, model_fn, config, params, ctx):
    self._model_fn = model_fn
    self._config = config
    self._params = params
    self._ctx = ctx

  def call_without_tpu(self, features, labels, is_export_mode):
    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)

  def convert_to_single_tpu_train_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single train step on TPU.

    The user provided `model_fn` takes input tuple (features, labels) and
    produces the EstimatorSpec with train_op and loss for train `mode`. This
    usually represents a single train computation on CPU.

    For TPU training, a train (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides, the input should be taken from TPU infeed rather
    than input pipeline (input_fn) directly. To fit the TPU loop and replicate
    pattern, the original train computation is reformed into the returned
    `train_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of train_fn, host_call, and captured scaffold_fn. The train_fn
      represents the train step for TPU.
    """
    outfeed_host_call = _OutfeedHostCall(self._ctx)
    scaffold_fn_holder = _CapturedObject()

    def train_step(loss):
      """Training step function for use inside a while loop."""
      del loss  # unused; required in function signature.
      features, labels = dequeue_fn().features_and_labels()

      spec = self._verify_estimator_spec(
          self._call_model_fn(features, labels))
      step_loss = spec.loss
      train_op = spec.train_op

      # Only TPUEstimatorSpec carries scaffold_fn and host_call.
      is_tpu_spec = isinstance(spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
      scaffold_fn_holder.capture(spec.scaffold_fn if is_tpu_spec else None)

      # We must run train_op to update the variables prior to running the
      # outfeed.
      with ops.control_dependencies([train_op]):
        if is_tpu_spec and spec.host_call is not None:
          outfeed_host_call.record({'host_call': spec.host_call})
          outfeed_ops = outfeed_host_call.create_enqueue_op()
        else:
          outfeed_ops = []
        # The returned loss depends on the outfeed enqueue (if any).
        with ops.control_dependencies(outfeed_ops):
          return array_ops.identity(step_loss)

    return train_step, outfeed_host_call, scaffold_fn_holder

  def convert_to_single_tpu_eval_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single eval step on TPU.

    Similar to training, the user provided `model_fn` takes input tuple
    (features, labels) and produces the TPUEstimatorSpec with eval_metrics for
    eval `mode`. This usually represents a single evaluation computation on CPU.

    For TPU evaluation, an eval (computation) step is first wrapped in a
    tf.while_loop control flow to repeat for many times and then replicated to
    all TPU shards. Besides, the input and output are slightly different. Input,
    features and labels, should be taken from TPU infeed rather than input
    pipeline (input_fn) directly. Output is managed in two stages.  First, the
    model outputs as the result of evaluation computation, usually model logits,
    should be transferred from TPU system to CPU. Then, all model outputs are
    concatenated first on CPU and sent to the metric_fn for metrics computation.
    To fit TPU evaluation pattern, the original eval computation should be
    reformed, which is the returned `eval_step`.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn
      representing the eval step for TPU. Calling eval_fn raises RuntimeError
      if `model_fn` does not return a `TPUEstimatorSpec`.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()

    def eval_step(total_loss):
      """Evaluation step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()

      tpu_estimator_spec = self._call_model_fn(features, labels)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU evaluation must have type '
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      loss = tpu_estimator_spec.loss
      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      to_record = {}
      if tpu_estimator_spec.eval_metrics:
        to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
      if tpu_estimator_spec.host_call is not None:
        # We assume that evaluate won't update global step, so we don't wrap
        # this host_call.
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      # The accumulated loss depends on the outfeed enqueue, so the enqueue
      # runs on every step.
      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return math_ops.add(total_loss, loss)

    return eval_step, host_calls, captured_scaffold_fn

  def convert_to_single_tpu_predict_step(self, dequeue_fn):
    """Converts user provided `model_fn` as a single predict step on TPU.

    Args:
      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
        infeed dequeue channel.

    Returns:
      A tuple of predict_fn, host_calls, and captured scaffold_fn. The
      predict_fn representing the predict step for TPU. Calling predict_fn
      raises RuntimeError if `model_fn` does not return a `TPUEstimatorSpec`.
    """
    host_calls = _OutfeedHostCall(self._ctx)
    captured_scaffold_fn = _CapturedObject()

    def predict_step(unused_scalar_stopping_signal):
      """Prediction step function for use inside a while loop."""
      inputs = dequeue_fn()
      features, labels = inputs.features_and_labels()
      stopping_signals = inputs.signals()

      assert stopping_signals is not None, (
          'Internal Error: `signals` is missing.')

      tpu_estimator_spec = self._call_model_fn(
          features, labels, is_export_mode=False)
      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
        raise RuntimeError(
            'estimator_spec used by TPU prediction must have type '
            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))

      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)

      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
      to_record = {}
      # Predictions and signals are sent through outfeed as dicts; the host
      # function simply hands the received keyword arguments back as a dict.
      identity_fn = lambda **kwargs: kwargs
      to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
      to_record['signals'] = [identity_fn, stopping_signals]
      if tpu_estimator_spec.host_call is not None:
        to_record['host_call'] = tpu_estimator_spec.host_call
      host_calls.record(to_record)

      # The returned stopping signal depends on the outfeed enqueue, so the
      # enqueue runs on every step.
      with ops.control_dependencies(host_calls.create_enqueue_op()):
        return _StopSignals.as_scalar_stopping_signal(stopping_signals)

    return predict_step, host_calls, captured_scaffold_fn

  def _verify_tpu_spec_predictions(self, predictions):
    """Validates TPUEstimatorSpec.predictions dict.

    Args:
      predictions: The `predictions` of a TPUEstimatorSpec. Must be a dict
        whose values expose a `shape` with `ndims` and indexable dimensions.

    Returns:
      The `predictions` dict, unchanged.

    Raises:
      TypeError: If `predictions` is not a dict.
      ValueError: If a tensor is a scalar (it has no leading dimension), or if
        its leading dimension is not statically known.
    """
    # TODO(xiejw): Adds validation for prediction dictionary.
    # TODO(xiejw): Adds support for single tensor as predictions.
    if not isinstance(predictions, dict):
      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')

    for (key, tensor) in predictions.items():
      shape = tensor.shape
      # Indexing dimension 0 of a rank-0 shape would fail with an unhelpful
      # IndexError, so report scalars explicitly.
      if shape.ndims == 0:
        raise ValueError(
            'The tensor with key ({}) in TPUEstimatorSpec.predictions is a '
            'scalar (should have a leading batch dimension). '
            'Tensor: {}'.format(key, tensor))
      if shape[0].value is None:
        raise ValueError(
            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
            'dynamic shape (should be static). Tensor: {}'.format(
                key, tensor))
    return predictions

  def _validate_model_features_and_labels(self,
                                          features,
                                          labels,
                                          is_export_mode):
    """Checks the type and static shape of the model function inputs.

    Features and labels are accepted when each is:
    - a Tensor or a dictionary of Tensors, and
    - fully statically shaped, unless `is_export_mode` is True or the context
      reports that execution happens on CPU.

    `labels` is not checked at all when it is None.

    Args:
      features: the features that would be input to the model function.
      labels: the labels that would be input to the model function.
      is_export_mode: boolean value specifying if in export mode.

    Raises:
      TypeError: If features/labels are not of the correct type.
      ValueError: If features/labels have dynamic shape.
    """

    def check(value, label):
      """Validates one input; `label` names it in error messages."""
      if not isinstance(value, (ops.Tensor, dict)):
        raise TypeError(
            'The {} to the model returned by input_fn must be either a Tensor '
            'or a dictionary of Tensors. {}: {}'.format(label, label, value))

      # Static shapes are only enforced when the model really runs on TPU.
      if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
        return

      if isinstance(value, dict):
        for name, entry in value.items():
          if entry.get_shape().is_fully_defined():
            continue
          raise ValueError(
              'The {} to the model returned by input_fn must have static '
              'shape. Key: \'{}\', Tensor: {}'.format(label, name, entry))
      elif not value.get_shape().is_fully_defined():
        raise ValueError(
            'The {} to the model returned by input_fn must have static shape.'
            ' Tensor: {}'.format(label, value))

    check(features, 'features')
    if labels is None:
      return
    check(labels, 'labels')

  def _call_model_fn(self, features, labels, is_export_mode=False):
    """Invokes the user `model_fn` with the arguments its signature accepts.

    Args:
      features: The features passed to `model_fn`.
      labels: The labels passed to `model_fn`, if it accepts them.
      is_export_mode: Boolean; when True no batch size is added to params.

    Returns:
      The spec returned by `model_fn`. When running on CPU a TPUEstimatorSpec
      is converted with `as_estimator_spec()` first.

    Raises:
      ValueError: If `model_fn` takes no `labels` although labels are given,
        or if `model_fn` takes no `params` argument.
    """
    self._validate_model_features_and_labels(features, labels, is_export_mode)
    accepted_args = function_utils.fn_args(self._model_fn)

    # The user gets private copies of `config` and `params`, so mutations in
    # `model_fn` do not leak back into this wrapper.
    config = copy.deepcopy(self._config)
    params = copy.deepcopy(self._params)

    call_kwargs = {}
    if 'labels' in accepted_args:
      call_kwargs['labels'] = labels
    elif labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in accepted_args:
      call_kwargs['mode'] = self._ctx.mode
    if 'config' in accepted_args:
      call_kwargs['config'] = config
    if 'params' not in accepted_args:
      raise ValueError('model_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params[\'batch_size\']'.format(self._model_fn))
    call_kwargs['params'] = params

    # `params` is the same object stored in call_kwargs, so the items added
    # below are visible to `model_fn`.
    batch_size_for_model_fn = (
        None if is_export_mode else self._ctx.batch_size_for_model_fn)
    if batch_size_for_model_fn is not None:
      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)

    on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
    _add_item_to_params(params, _USE_TPU_KEY, not on_cpu)
    if not on_cpu:
      _add_item_to_params(
          params, _CTX_KEY,
          tpu_context.TPUContext(
              internal_ctx=self._ctx, call_from_input_fn=False))

    spec = self._model_fn(features=features, **call_kwargs)
    if on_cpu and isinstance(spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
      # The estimator_spec will be passed to `Estimator` directly, which expects
      # type `EstimatorSpec`.
      return spec.as_estimator_spec()
    return spec

  def _verify_estimator_spec(self, estimator_spec):
    """Rejects plain EstimatorSpec fields that are unsupported here.

    A TPUEstimatorSpec is returned untouched. For any other spec, non-empty
    hook fields are an error and a non-empty scaffold only triggers a warning.

    Args:
      estimator_spec: The spec returned by the model function.

    Returns:
      The same `estimator_spec` object.

    Raises:
      ValueError: If training_chief_hooks, training_hooks or evaluation_hooks
        is set on a spec that is not a TPUEstimatorSpec.
    """
    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
      return estimator_spec

    # Checked in this order; the first non-empty field is reported.
    unsupported_fields = (
        'training_chief_hooks', 'training_hooks', 'evaluation_hooks')
    for field in unsupported_fields:
      if getattr(estimator_spec, field):
        raise ValueError(
            '{} returned by EstimatorSpec is not supported in '
            'TPUEstimator.'.format(field))

    if estimator_spec.scaffold:
      logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
                      'Please use TPUEstimatorSpec.')
    return estimator_spec


class _OutfeedHostCall(object):
  """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec."""

  def __init__(self, ctx):
    self._ctx = ctx
    self._names = []
    # All of these are dictionaries of lists keyed on the name.
    self._host_fns = {}
    self._tensor_keys = collections.defaultdict(list)
    self._tensors = collections.defaultdict(list)
    self._tensor_dtypes = collections.defaultdict(list)
    self._tensor_shapes = collections.defaultdict(list)

  @staticmethod
  def validate(host_calls):
    """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`."""

    for name, host_call in host_calls.items():
      if not isinstance(host_call, (tuple, list)):
        raise ValueError('{} should be tuple or list'.format(name))
      if len(host_call) != 2:
        raise ValueError('{} should have two elements.'.format(name))
      if not callable(host_call[0]):
        raise TypeError('{}[0] should be callable.'.format(name))
      if not isinstance(host_call[1], (tuple, list, dict)):
        raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))

      if isinstance(host_call[1], (tuple, list)):
        fullargspec = tf_inspect.getfullargspec(host_call[0])
        fn_args = function_utils.fn_args(host_call[0])
        # wrapped_hostcall_with_global_step uses varargs, so we allow that.
        if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
          raise RuntimeError(
              'In TPUEstimatorSpec.{}, length of tensors {} does not match '
              'method args of the function, which takes {}.'.format(
                  name, len(host_call[1]), len(fn_args)))

  @staticmethod
  def create_cpu_hostcall(host_calls):
    """Runs on the host_call on CPU instead of TPU when use_tpu=False."""

    _OutfeedHostCall.validate(host_calls)
    ret = {}
    for name, host_call in host_calls.items():
      host_fn, tensors = host_call
      if isinstance(tensors, (tuple, list)):
        ret[name] = host_fn(*tensors)
      else:
        # Must be dict.
        try:
          ret[name] = host_fn(**tensors)
        except TypeError as e:
          logging.warning(
              'Exception while calling %s: %s. It is likely the tensors '
              '(%s[1]) do not match the '
              'function\'s arguments', name, e, name)
          raise e
    return ret

  def record(self, host_calls):
    """Records the host_call structure."""

    for name, host_call in host_calls.items():
      host_fn, tensor_list_or_dict = host_call
      self._names.append(name)
      self._host_fns[name] = host_fn

      if isinstance(tensor_list_or_dict, dict):
        for (key, tensor) in six.iteritems(tensor_list_or_dict):
          self._tensor_keys[name].append(key)
          self._tensors[name].append(tensor)
          self._tensor_dtypes[name].append(tensor.dtype)
          self._tensor_shapes[name].append(tensor.shape)
      else:
        # List or tuple.
        self._tensor_keys[name] = None
        for tensor in tensor_list_or_dict:
          self._tensors[name].append(tensor)
          self._tensor_dtypes[name].append(tensor.dtype)
          self._tensor_shapes[name].append(tensor.shape)

  def create_enqueue_op(self):
    """Create the op to enqueue the recorded host_calls.

    Returns:
      A list of enqueue ops, which is empty if there are no host calls.
    """
    if not self._names:
      return []

    tensors = []
    # TODO(jhseu): Consider deduping tensors.
    for name in self._names:
      tensors.extend(self._tensors[name])

    with ops.device(tpu.core(0)):
      return [tpu_ops.outfeed_enqueue_tuple(tensors)]

  def create_tpu_hostcall(self):
    """Sends the tensors through outfeed and runs the host_fn on CPU.

    The tensors are concatenated along dimension 0 to form a global tensor
    across all shards. The concatenated function is passed to the host_fn and
    executed on the first host.

    Returns:
      A dictionary mapping name to the return type of the host_call by that
      name.

    Raises:
      RuntimeError: If outfeed tensor is scalar.
    """
    if not self._names:
      return {}

    ret = {}
    # For each i, dequeue_ops[i] is a list containing the tensors from all
    # shards. This list is concatenated later.
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for name in self._names:
      for _ in self._tensors[name]:
        dequeue_ops.append([])
      for dtype in self._tensor_dtypes[name]:
        tensor_dtypes.append(dtype)
      for shape in self._tensor_shapes[name]:
        tensor_shapes.append(shape)

    # Outfeed ops execute on each replica's first logical core. Note: we must
    # constraint it such that we have at most one outfeed dequeue and enqueue
    # per replica.
    for i in xrange(self._ctx.num_replicas):
      host_device, ordinal_id = self._ctx.device_for_replica(i)
      with ops.device(host_device):
        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes,
            shapes=tensor_shapes,
            device_ordinal=ordinal_id)
        for j, item in enumerate(outfeed_tensors):
          dequeue_ops[j].append(item)

    # Deconstruct dequeue ops.
    dequeue_ops_by_name = {}
    pos = 0
    for name in self._names:
      dequeue_ops_by_name[name] = dequeue_ops[pos:pos+len(self._tensors[name])]
      pos += len(self._tensors[name])

    # It is assumed evaluation always happens on single host TPU system. So,
    # place all ops on tpu host if possible.
    #
    # TODO(jhseu): Evaluate whether this is right for summaries.
    with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
      for name in self._names:
        dequeue_ops = dequeue_ops_by_name[name]
        for i, item in enumerate(dequeue_ops):
          if dequeue_ops[i][0].shape.ndims == 0:
            raise RuntimeError(
                'All tensors outfed from TPU should preserve batch size '
                'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
          # TODO(xiejw): Allow users to specify the axis for batch size
          # dimension.
          dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)

        if self._tensor_keys[name] is not None:
          # The user-provided eval_metrics[1] is a dict.
          dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))
          try:
            ret[name] = self._host_fns[name](**dequeue_ops)
          except TypeError as e:
            logging.warning(
                'Exception while calling %s: %s. It is likely the tensors '
                '(%s[1]) do not match the '
                'function\'s arguments', name, e, name)
            raise e
        else:
          ret[name] = self._host_fns[name](*dequeue_ops)

    return ret


class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
  """Hook to run host calls when use_tpu=False.

  The given tensors are requested on every `session.run` call. In addition,
  the summary writers are initialized once the session is created and flushed
  when the session ends.
  """

  def __init__(self, tensors):
    """Stores the tensors to fetch on every run call."""
    self._tensors = tensors

  def begin(self):
    """Builds the summary writer init ops and the matching flush ops."""
    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
    # create a separate hook to guarantee execution order, because summaries
    # need to be initialized before the outfeed thread starts.
    # TODO(jhseu): Make a wrapper hook instead?
    self._init_ops = contrib_summary.summary_writer_initializer_op()
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    self._finalize_ops = [
        contrib_summary.flush(writer=init_op.inputs[0])
        for init_op in self._init_ops
    ]

  def after_create_session(self, session, coord):
    """Runs the summary writer init ops in the new session."""
    session.run(self._init_ops)

  def before_run(self, run_context):
    """Requests the stored tensors as part of the upcoming run call."""
    run_args = basic_session_run_hooks.SessionRunArgs(self._tensors)
    return run_args

  def end(self, session):
    """Runs the flush ops built in `begin`."""
    session.run(self._finalize_ops)


class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
  """Calculate and report global_step/sec and examples/sec during runtime."""

  def __init__(self,
               batch_size,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):
    """Creates the hook.

    Args:
      batch_size: Number of examples per global step; used to scale the step
        rate into an example rate.
      every_n_steps: Forwarded to the base class.
      every_n_secs: Forwarded to the base class.
      output_dir: Forwarded to the base class.
      summary_writer: Forwarded to the base class.
    """
    self._batch_size = batch_size
    base_kwargs = dict(
        every_n_steps=every_n_steps,
        every_n_secs=every_n_secs,
        output_dir=output_dir,
        summary_writer=summary_writer)
    super(ExamplesPerSecondHook, self).__init__(**base_kwargs)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    """Logs both rates and writes them as summaries if a writer is set."""
    global_step_per_sec = elapsed_steps / elapsed_time
    examples_per_sec = self._batch_size * global_step_per_sec

    writer = self._summary_writer
    if writer is not None:
      rates = (('global_step/sec', global_step_per_sec),
               ('examples/sec', examples_per_sec))
      # Both summaries are built first, then written in the same order.
      summaries = [
          Summary(value=[Summary.Value(tag=tag, simple_value=rate)])
          for tag, rate in rates
      ]
      for summary in summaries:
        writer.add_summary(summary, global_step)

    logging.info('global_step/sec: %g', global_step_per_sec)
    logging.info('examples/sec: %g', examples_per_sec)


class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
  """Change SIGINT (CTRL^C) handler to force quit the process.

  The default behavior often results in hanging processes.
  The original handler is restored after training/evaluation, unless it was
  not installed from Python, in which case it cannot be restored and the
  default handler stays in place.
  """

  def __init__(self):
    # The handler active at construction time. `signal.getsignal` returns None
    # when the handler was not installed from Python.
    self._signal_fn = signal.getsignal(signal.SIGINT)

  def before_run(self, run_context):
    """Installs the default SIGINT handler before each run call."""
    signal.signal(signal.SIGINT, signal.SIG_DFL)

  def end(self, session):
    """Restores the SIGINT handler captured at construction time."""
    if self._signal_fn is None:
      # `signal.signal` rejects None with a TypeError, so a handler that was
      # not installed from Python cannot be put back.
      logging.warning('The original SIGINT handler was not installed from '
                      'Python and cannot be restored.')
      return
    signal.signal(signal.SIGINT, self._signal_fn)


class TPUEstimator(estimator_lib.Estimator):
  """Estimator with TPU support.

  TPUEstimator also supports training on CPU and GPU. You don't need to define
  a separate `tf.estimator.Estimator`.

  TPUEstimator handles many of the details of running on TPU devices, such as
  replicating inputs and models for each core, and returning to host
  periodically to run hooks.

  TPUEstimator transforms a global batch size in params to a per-shard batch
  size when calling the `input_fn` and `model_fn`. Users should specify
  global batch size in constructor, and then get the batch size for each shard
  in `input_fn` and `model_fn` by `params['batch_size']`.

  - For training, `model_fn` gets per-core batch size; `input_fn` may get
    per-core or per-host batch size depending on `per_host_input_for_training`
    in `TPUConfig` (See docstring for TPUConfig for details).

  - For evaluation and prediction, `model_fn` gets per-core batch size and
    `input_fn` gets per-host batch size.

  Evaluation
  ==========

  `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
  for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return
  `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case
  the following discussion on TPU evaluation does not apply.

  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
  `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
  `TPUEstimatorSpec` for details).  `metric_fn` takes the `tensors` and returns
  a dict from metric string name to the result of calling a metric function,
  namely a `(metric_tensor, update_op)` tuple.

  One can set `use_tpu` to `False` for testing. All training, evaluation, and
  predict will be executed on CPU. `input_fn` and `model_fn` will receive
  `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.

  Current limitations:
  --------------------

  1. TPU evaluation only works on a single host (one TPU worker) except
     BROADCAST mode.

  2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
     (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
     batches should have the same size.

  Example (MNIST):
  ----------------

  ```
  # The metric Fn which runs on CPU.
  def metric_fn(labels, logits):
    predictions = tf.argmax(logits, 1)
    return {
      'accuracy': tf.metrics.accuracy(
          labels=labels, predictions=predictions),
    }

  # Your model Fn which runs on TPU (eval_metrics is list in this example)
  def model_fn(features, labels, mode, config, params):
    ...
    logits = ...

    if mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, [labels, logits]))

  # or specify the eval_metrics tensors as dict.
  def model_fn(features, labels, mode, config, params):
    ...
    final_layer_output = ...

    if mode == tf.estimator.ModeKeys.EVAL:
      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, {
              'labels': labels,
              'logits': final_layer_output,
          }))
  ```

  Prediction
  ==========

  Prediction on TPU is an experimental feature to support large batch inference.
  It is not designed for latency-critical systems. In addition, due to some
  usability issues, for prediction with small dataset, CPU `.predict`, i.e.,
  creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
  convenient.

  Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
  *should* raise an end-of-input exception (`OutOfRangeError` or
  `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
  precise, the ops created by `input_fn` produce one batch of the data.
  The `predict()` API processes one batch at a time. When reaching the end of
  the data source, an end-of-input exception should be raised by one of these
  operations. The user usually does not need to do this manually. As long as the
  dataset is not repeated forever, the `tf.data` API will raise an end-of-input
  exception automatically after the last batch has been produced.

  Note: Estimator.predict returns a Python generator. Please consume all the
  data from the generator so that TPUEstimator can shutdown the TPU system
  properly for user.

  Current limitations:
  --------------------
  1. TPU prediction only works on a single host (one TPU worker).

  2. `input_fn` must return a `Dataset` instance rather than `features`. In
  fact, .train() and .evaluate() also support Dataset as return value.

  Example (MNIST):
  ----------------
  ```
  height = 32
  width = 32
  total_examples = 100

  def predict_input_fn(params):
    batch_size = params['batch_size']

    images = tf.random_uniform(
        [total_examples, height, width, 3], minval=-1, maxval=1)

    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.map(lambda images: {'image': images})

    dataset = dataset.batch(batch_size)
    return dataset

  def model_fn(features, labels, params, mode):
     # Generate predictions, called 'output', from features['image']

    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions={
              'predictions': output,
              'is_padding': features['is_padding']
          })

  tpu_est = TPUEstimator(
      model_fn=model_fn,
      ...,
      predict_batch_size=16)

  # Fully consume the generator so that TPUEstimator can shutdown the TPU
  # system.
  for item in tpu_est.predict(input_fn=predict_input_fn):
    # Filter out item if the `is_padding` is 1.
    # Process the 'predictions'
  ```

  Exporting
  =========

  `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
  and another with `tag_constants.SERVING` and `tag_constants.TPU`.
  At serving time, these tags are used to select metagraph to load.

  Before running the graph on TPU, TPU system needs to be initialized. If
  TensorFlow Serving model-server is used, this is done automatically. If
  not, please call `session.run(tpu.initialize_system())`.

  `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
  `model_fn`.

  Example:
  ----------------

  ```
  def model_fn(features, labels, mode, config, params):
    ...
    logits = ...
    export_outputs = {
      'logits': export_output_lib.PredictOutput(
        {'logits': logits})
    }

    def host_call(logits):
      class_ids = math_ops.argmax(logits)
      classes = string_ops.as_string(class_ids)
      export_outputs['classes'] = (
          export_output_lib.ClassificationOutput(classes=classes))

    tpu.outside_compilation(host_call, logits)

    ...
  ```

  """

  def __init__(self,
               model_fn=None,
               model_dir=None,
               config=None,
               params=None,
               use_tpu=True,
               train_batch_size=None,
               eval_batch_size=None,
               predict_batch_size=None,
               batch_axis=None,
               eval_on_tpu=True,
               export_to_tpu=True,
               warm_start_from=None):
    """Constructs a `TPUEstimator` instance.

    Args:
      model_fn: Model function as required by `Estimator`. For training, the
        returned `EstimatorSpec` cannot have hooks as it is not supported in
        `TPUEstimator`. Instead, the user can pass the training hooks as
        an argument to `TPUEstimator.train()`.
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If `None`, the
        model_dir in `config` will be used if set. If both are set, they must
        be same. If both are `None`, a temporary directory will be used.
      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
      params: An optional `dict` of hyper parameters that will be passed into
        `input_fn` and `model_fn`.  Keys are names of parameters, values are
        basic python types. There are reserved keys for `TPUEstimator`,
        including 'batch_size'.
      use_tpu: A bool indicating whether TPU support is enabled. Currently,
        - TPU training and evaluation respect this bit, but eval_on_tpu can
          override execution of eval. See below.
        - Predict still happens on CPU.
      train_batch_size: An int representing the global training batch size.
        TPUEstimator transforms this global batch size to a per-shard batch
        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
        Cannot be `None` if `use_tpu` is `True`.
        Must be divisible by total number of replicas.
      eval_batch_size: An int representing evaluation batch size.
        Must be divisible by total number of replicas.
      predict_batch_size: An int representing the prediction batch size.
        Must be divisible by total number of replicas.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards. For example, if your input_fn produced (images, labels)
        where the images tensor is in `HWCN` format, your shard dimensions would
        be [3, 0], where 3 corresponds to the `N` dimension of your images
        Tensor, and 0 corresponds to the dimension along which to split the
        labels to match up with the corresponding images. If None is supplied,
        and per_host_input_for_training is True, batches will be sharded based
        on the major dimension. If tpu_config.per_host_input_for_training is
        False or `PER_HOST_V2`, batch_axis is ignored.
      eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
        model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
      export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
        serving on TPU besides the one on CPU.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                       warm-start from, or a `tf.estimator.WarmStartSettings`
                       object to fully configure warm-starting.  If the string
                       filepath is provided instead of a `WarmStartSettings`,
                       then all variables are warm-started, and it is assumed
                       that vocabularies and Tensor names are unchanged.

    Raises:
      ValueError: If `config` is `None` or not a `tpu_config.RunConfig`; if
        `params` already contains a reserved key; or, when `use_tpu` is True,
        if `train_batch_size` is `None` or `num_cores_per_replica` is set
        together with `PER_SHARD_V1` input. When `use_tpu` is True the batch
        sizes are additionally checked by `util_lib.check_positive_integer`.
    """
    if config is None or not isinstance(config, tpu_config.RunConfig):
      raise ValueError(
          '`config` must be provided with type `tpu_config.RunConfig`')

    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
      raise ValueError('{} are reserved keys but existed in params {}.'.format(
          _RESERVED_PARAMS_KEYS, params))

    if use_tpu:
      # Perform some very basic validations. More validations will be found in
      # _InternalTPUContext.
      if train_batch_size is None:
        raise ValueError('`train_batch_size` cannot be `None`')
      util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

      if (config.tpu_config.per_host_input_for_training is
          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
          config.tpu_config.num_cores_per_replica):
        raise ValueError(
            'Model parallelism only supports per host input for training. '
            'Please adjust TPURunconfig.per_host_input_for_training.')

      if eval_batch_size is not None:
        util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')

      if predict_batch_size is not None:
        util_lib.check_positive_integer(predict_batch_size,
                                        'predict_batch_size')

    # Verifies the model_fn signature according to Estimator framework.
    estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
    # We cannot store config and params in this constructor as parent
    # constructor might change them, such as assigning a temp dir for
    # config.model_dir.
    model_function = self._augment_model_fn(model_fn, batch_axis)

    # Overwrite log_step_count_steps to disable TensorLoggingHook and
    # StepCounterHook from being created in Estimator. TPUEstimator already
    # added equivalent hooks in _augment_model_fn above.
    self._log_every_n_steps = config.log_step_count_steps
    config = config.replace(log_step_count_steps=None)

    # Passing non-None params as wrapped model_fn has it.
    params = params or {}
    super(TPUEstimator, self).__init__(
        model_fn=model_function,
        model_dir=model_dir,
        config=config,
        params=params,
        warm_start_from=warm_start_from)
    # Read from `self._config` (set by the parent constructor) rather than
    # from the local `config` argument.
    self._iterations_per_training_loop = (
        self._config.tpu_config.iterations_per_loop)

    # All properties passed to _InternalTPUContext are immutable.
    # pylint: disable=protected-access
    self._ctx = tpu_context._get_tpu_context(
        self._config, train_batch_size,
        eval_batch_size, predict_batch_size,
        use_tpu,
        eval_on_tpu)

    self._export_to_tpu = export_to_tpu

    # Set to True by `_call_input_fn` and cleared by the wrapped model_fn,
    # which uses it to tell `.predict` apart from `export_savedmodel`.
    self._is_input_fn_invoked = None
    # Maps a mode key to the `ErrorRendezvous` created by the most recent
    # `train`/`evaluate`/`predict` call for that mode.
    self._rendezvous = {}

  def _add_meta_graph_for_mode(self,
                               builder,
                               input_receiver_fn_map,
                               checkpoint_path,
                               strip_default_attrs,
                               save_variables=True,
                               mode=model_fn_lib.ModeKeys.PREDICT,
                               export_tags=None,
                               check_variables=True):
    """Adds the PREDICT metagraph and, optionally, a TPU inference metagraph.

    The parent implementation is always invoked once with the caller's
    arguments. When `self._export_to_tpu` is set it is invoked a second time
    with the PREDICT receiver registered under `_REWRITE_FOR_INFERENCE_MODE`,
    the `SERVING` and `TPU` tags, and with variable saving and variable
    checking both turned off.

    Args:
      builder: Forwarded to the parent implementation.
      input_receiver_fn_map: Mapping from mode to input receiver fn; the entry
        for `mode` is reused for the TPU metagraph.
      checkpoint_path: Forwarded to the parent implementation.
      strip_default_attrs: Forwarded to the parent implementation.
      save_variables: Forwarded to the first parent call only.
      mode: Must be `ModeKeys.PREDICT`.
      export_tags: Forwarded to the first parent call only.
      check_variables: Forwarded to the first parent call only.

    Raises:
      NotImplementedError: If `mode` is not `ModeKeys.PREDICT`.
    """
    if mode != model_fn_lib.ModeKeys.PREDICT:
      raise NotImplementedError(
          'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
          'got {}.'.format(mode))

    add_meta_graph = super(TPUEstimator, self)._add_meta_graph_for_mode
    add_meta_graph(
        builder,
        input_receiver_fn_map,
        checkpoint_path,
        strip_default_attrs,
        save_variables,
        mode=mode,
        export_tags=export_tags,
        check_variables=check_variables)

    if not self._export_to_tpu:
      return

    tpu_receiver_fn_map = {
        _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode],
    }
    tpu_export_tags = [tag_constants.SERVING, tag_constants.TPU]
    # See b/110052256 for why `check_variables` is `False`.
    add_meta_graph(
        builder,
        tpu_receiver_fn_map,
        checkpoint_path,
        strip_default_attrs,
        save_variables=False,
        mode=_REWRITE_FOR_INFERENCE_MODE,
        export_tags=tpu_export_tags,
        check_variables=False)

  def _call_model_fn(self, features, labels, mode, config):
    """Routes the call to the inference-rewrite path or to the parent.

    `_REWRITE_FOR_INFERENCE_MODE` is handled by
    `_call_model_fn_for_inference`; every other mode is delegated to the
    parent class implementation.
    """
    if mode != _REWRITE_FOR_INFERENCE_MODE:
      parent_call = super(TPUEstimator, self)._call_model_fn
      return parent_call(features, labels, mode, config)
    return self._call_model_fn_for_inference(features, labels, mode, config)

  def _call_model_fn_for_inference(self, features, labels, mode, config):
    """Wraps `_call_model_fn` for `export_savedmodel`.

    The model function is invoked once, in PREDICT mode, from inside a
    computation handed to `tpu.rewrite_for_inference`. Tensors in the resulting
    `export_outputs` that `_is_tpu_tensor` identifies are swapped for the
    tensors returned by the rewrite; every other non-`None` tensor is passed
    through an identity op carrying control dependencies on all of the
    rewritten outputs.

    Args:
      features: Forwarded unchanged to `_call_model_fn`.
      labels: Forwarded unchanged to `_call_model_fn`.
      mode: Must equal `_REWRITE_FOR_INFERENCE_MODE`.
      config: Forwarded unchanged to `_call_model_fn`.

    Returns:
      The spec produced by the model function, with `export_outputs` replaced
      by the reconstructed outputs.

    Raises:
      ValueError: If `mode` is not `_REWRITE_FOR_INFERENCE_MODE`.
    """
    if mode != _REWRITE_FOR_INFERENCE_MODE:
      raise ValueError('mode must be {}; '
                       'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))

    capture = _CapturedObject()

    def computation():
      """Compute tpu tensors used in export_outputs.

      Passed to rewrite_for_inference so that model_fn will be called under
      the rewriting contexts. Only tpu tensors are returned, but export_outputs
      and scaffold are captured.

      Returns:
         A list of Tensors used in export_outputs and not marked for
         outside_compilation.
      """
      # We should only call model fn once and it should be inside `computation`
      # so that building the graph will happen under `rewrite_for_inference`.
      mode = model_fn_lib.ModeKeys.PREDICT
      estimator_spec = self._call_model_fn(features, labels, mode, config)

      # We pick the TPU tensors out from `export_output` and later return them
      # from `computation` for rewriting.
      tensors_dict = collections.OrderedDict(
          (k, _export_output_to_tensors(v))
          for k, v in six.iteritems(estimator_spec.export_outputs)
      )
      tensors = nest.flatten(tensors_dict)
      tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]

      # We cannot return anything other than `tpu_tensors` here so we capture
      # the rest for later use.
      capture.capture((estimator_spec, tensors_dict, tensors))
      return tpu_tensors

    tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
    estimator_spec, tensors_dict, tensors = capture.get()

    # Snapshot every rewritten output before any of them is consumed below.
    # The control dependencies must not depend on how many TPU tensors happen
    # to precede a non-TPU tensor: consuming from the same list that is used
    # as control inputs would leave later non-TPU tensors with fewer (possibly
    # zero) dependencies.
    if isinstance(tpu_tensors_on_cpu, (list, tuple)):
      control_inputs = list(tpu_tensors_on_cpu)
    else:
      control_inputs = [tpu_tensors_on_cpu]
    pending_tpu_outputs = collections.deque(control_inputs)

    # Reconstruct `tensors`, but with `tpu_tensors` replaced with
    # `tpu_tensors_on_cpu`.
    new_tensors = []
    for t in tensors:
      if _is_tpu_tensor(t):
        new_tensors.append(pending_tpu_outputs.popleft())
      elif t is None:
        new_tensors.append(None)
      else:
        # Only fetching `tpu_tensors_on_cpu` does not trigger
        # TPU computation and blocks, so we add the control dependency here.
        with ops.control_dependencies(control_inputs):
          new_tensors.append(array_ops.identity(t))

    # Reconstruct `tensors_dict`.
    new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
    # Reconstruct `export_outputs`.
    export_outputs = estimator_spec.export_outputs
    new_export_outputs = collections.OrderedDict(
        (k, _clone_export_output_with_tensors(export_outputs[k], v))
        for k, v in six.iteritems(new_tensors_dict)
    )

    return estimator_spec._replace(export_outputs=new_export_outputs)

  def _create_global_step(self, graph):
    """Creates a global step suitable for TPUs.

    Delegates to the module-level `_create_global_step` helper.

    Args:
      graph: The graph in which to create the global step.

    Returns:
      A global step `Tensor`.

    Raises:
      ValueError: if the global step tensor is already defined.
    """
    global_step_tensor = _create_global_step(graph)
    return global_step_tensor

  def _convert_train_steps_to_hooks(self, steps, max_steps):
    """Builds the hooks that stop training after the requested steps.

    When the TRAIN context reports CPU execution the parent implementation is
    used. Otherwise at least one of `steps`/`max_steps` is required, each
    supplied value is validated, and a single `_TPUStopAtStepHook` is returned.

    Args:
      steps: Number of steps to train for, or `None`.
      max_steps: Step at which to stop training, or `None`.

    Returns:
      A list of hooks.

    Raises:
      ValueError: On TPU, if both `steps` and `max_steps` are `None`.
    """
    with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as train_ctx:
      if train_ctx.is_running_on_cpu():
        return super(TPUEstimator, self)._convert_train_steps_to_hooks(
            steps, max_steps)

    # On TPU.
    if steps is None and max_steps is None:
      raise ValueError(
          'For TPU training, one of `steps` or `max_steps` must be set. '
          'Cannot be both `None`.')

    # Estimator.train has explicit positiveness check.
    labelled_limits = ((steps, 'Train steps'), (max_steps, 'Train max_steps'))
    for limit, label in labelled_limits:
      if limit is not None:
        util_lib.check_positive_integer(limit, label)

    stop_hook = _TPUStopAtStepHook(
        self._iterations_per_training_loop, steps, max_steps)
    return [stop_hook]

  def _convert_eval_steps_to_hooks(self, steps):
    """Builds the hooks that bound an evaluation run.

    When the EVAL context reports CPU execution the parent implementation is
    used. Otherwise `steps` is mandatory and validated, and two hooks are
    returned: one stopping after `steps` evals and one setting the eval
    iterations.

    Args:
      steps: Number of evaluation steps; may be `None` only on CPU.

    Returns:
      A list of hooks.

    Raises:
      ValueError: On TPU, if `steps` is `None`.
    """
    with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as eval_ctx:
      if eval_ctx.is_running_on_cpu():
        return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)

    if steps is None:
      raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.')

    util_lib.check_positive_integer(steps, 'Eval steps')

    stop_hook = evaluation._StopAfterNEvalsHook(  # pylint: disable=protected-access
        num_evals=steps)
    iterations_hook = _SetEvalIterationsHook(steps)
    return [stop_hook, iterations_hook]

  def _call_input_fn(self, input_fn, mode):
    """Calls the input function, or defers the call for TPU execution.

    `input_fn` must accept `params`; `config` and `mode` are supplied only if
    it declares them. When the context provides a batch size for the input
    function, it is stored in `params` before anything is called. On CPU,
    `input_fn` is invoked immediately under the CPU device; otherwise a
    closure is returned that adds the TPU context to `params` and then invokes
    `input_fn`.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      On CPU, whatever `input_fn` returns: either features or
      (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.
      Otherwise, a callable taking the TPU context that invokes `input_fn`.

    Raises:
      ValueError: if input_fn takes invalid arguments or does not have `params`.
    """
    accepted_args = function_utils.fn_args(input_fn)
    config = self.config  # a deep copy.
    if 'params' not in accepted_args:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))

    kwargs = {'params': self.params}  # a deep copy.
    if 'config' in accepted_args:
      kwargs['config'] = config
    if 'mode' in accepted_args:
      kwargs['mode'] = mode

    # Records the fact input_fn has been invoked.
    self._is_input_fn_invoked = True

    with self._ctx.with_mode(mode) as ctx:
      # Setting the batch size in params first. This helps user to have same
      # input_fn for use_tpu=True/False.
      input_batch_size = ctx.batch_size_for_input_fn
      if input_batch_size is not None:
        _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
                            input_batch_size)

      # For export_savedmodel, input_fn is never passed to Estimator. So,
      # `is_export_mode` must be False.
      on_cpu = ctx.is_running_on_cpu(is_export_mode=False)
      if on_cpu:
        with ops.device('/device:CPU:0'):
          return input_fn(**kwargs)

      # For TPU computation, input_fn should be invoked in a tf.while_loop for
      # performance. While constructing the tf.while_loop, the structure of
      # inputs returned by the `input_fn` needs to be recorded. The structure
      # includes whether features or labels is dict or single Tensor, dict keys,
      # tensor shapes, and dtypes. The recorded structure is used to create the
      # infeed dequeue ops, which must be wrapped and passed as a Fn, called
      # inside the TPU computation, as the TPU computation is wrapped inside a
      # tf.while_loop also. So, we either pass input_fn to model_fn or pass
      # dequeue_fn to model_fn. Here, `input_fn` is passed directly as
      # `features` in `model_fn` signature.
      def _input_fn(ctx):
        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
        return input_fn(**kwargs)

      return _input_fn

  def _validate_features_in_predict_input(self, result):
    """Skip the validation.

    For TPUEstimator, we do not need to check the result type. `_InputPipeline`
    has stronger check. Parent class's check generates confusing warning msg.

    Args:
      result: `features` returned by input_fn.
    """
    pass

  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Runs the parent `train` under a fresh error rendezvous.

    A new `ErrorRendezvous` is registered for TRAIN mode before delegating.
    An exception raised by the parent call is recorded under the
    'training_loop' source; in all cases that source is then marked done and
    `raise_errors()` is invoked on the rendezvous, which is expected to
    surface any recorded error (see `error_handling.ErrorRendezvous`).

    Args:
      input_fn: Forwarded to the parent `train`.
      hooks: Forwarded to the parent `train`.
      steps: Forwarded to the parent `train`.
      max_steps: Forwarded to the parent `train`.
      saving_listeners: Forwarded to the parent `train`.

    Returns:
      Whatever the parent `train` returns.
    """
    train_rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = train_rendezvous
    train_kwargs = dict(
        input_fn=input_fn,
        hooks=hooks,
        steps=steps,
        max_steps=max_steps,
        saving_listeners=saving_listeners)
    try:
      return super(TPUEstimator, self).train(**train_kwargs)
    except Exception:  # pylint: disable=broad-except
      train_rendezvous.record_error('training_loop', sys.exc_info())
    finally:
      train_rendezvous.record_done('training_loop')
      train_rendezvous.raise_errors()

  def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
               name=None):
    """Runs the parent `evaluate` under a fresh error rendezvous.

    A new `ErrorRendezvous` is registered for EVAL mode before delegating.
    An exception raised by the parent call is recorded under the
    'evaluation_loop' source; in all cases that source is then marked done and
    `raise_errors()` is invoked on the rendezvous, which is expected to
    surface any recorded error (see `error_handling.ErrorRendezvous`).

    Args:
      input_fn: Forwarded to the parent `evaluate`.
      steps: Forwarded to the parent `evaluate`.
      hooks: Forwarded to the parent `evaluate`.
      checkpoint_path: Forwarded to the parent `evaluate`.
      name: Forwarded to the parent `evaluate`.

    Returns:
      Whatever the parent `evaluate` returns.
    """
    eval_rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.EVAL] = eval_rendezvous
    eval_kwargs = dict(
        steps=steps,
        hooks=hooks,
        checkpoint_path=checkpoint_path,
        name=name)
    try:
      return super(TPUEstimator, self).evaluate(input_fn, **eval_kwargs)
    except Exception:  # pylint: disable=broad-except
      eval_rendezvous.record_error('evaluation_loop', sys.exc_info())
    finally:
      eval_rendezvous.record_done('evaluation_loop')
      eval_rendezvous.raise_errors()

  def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    """Yields the parent `predict` results under a fresh error rendezvous.

    A new `ErrorRendezvous` is registered for PREDICT mode before delegating.
    An exception raised while iterating the parent generator is recorded under
    the 'prediction_loop' source. When iteration ends for any reason, that
    source is marked done and `raise_errors()` is invoked on the rendezvous
    exactly once, which is expected to surface any recorded error (see
    `error_handling.ErrorRendezvous`).

    Args:
      input_fn: Forwarded to the parent `predict`.
      predict_keys: Forwarded to the parent `predict`.
      hooks: Forwarded to the parent `predict`.
      checkpoint_path: Forwarded to the parent `predict`.
      yield_single_examples: Forwarded to the parent `predict`.

    Yields:
      Each item produced by the parent `predict` generator, unchanged.
    """
    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
    self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
    try:
      for result in super(TPUEstimator, self).predict(
          input_fn=input_fn,
          predict_keys=predict_keys,
          hooks=hooks,
          checkpoint_path=checkpoint_path,
          yield_single_examples=yield_single_examples):
        yield result
    except Exception:  # pylint: disable=broad-except
      rendezvous.record_error('prediction_loop', sys.exc_info())
    finally:
      # The `finally` clause runs on normal exhaustion, on error, and when the
      # generator is closed early, so no further bookkeeping is needed after
      # the `try` statement.
      rendezvous.record_done('prediction_loop')
      rendezvous.raise_errors()

  def _augment_model_fn(self, model_fn, batch_axis):
    """Returns a new model_fn, which wraps the TPU support."""

    def _model_fn(features, labels, mode, config, params):
      """A Estimator `model_fn` for TPUEstimator."""
      with self._ctx.with_mode(mode) as ctx:
        model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)

        if mode != model_fn_lib.ModeKeys.PREDICT:
          is_export_mode = False
        else:
          # For export_savedmodel, input_fn is never passed to Estimator. So, by
          # checking the self._is_input_fn_invoked bit, we can know, given the
          # mode == PREDICT, it is the .predict API, not export_savedmodel API.
          if self._is_input_fn_invoked:
            is_export_mode = False
          else:
            is_export_mode = True

        # Clear the bit.
        self._is_input_fn_invoked = None

        # examples_hook is added to training_hooks for both CPU and TPU
        # execution.
        examples_hook = ExamplesPerSecondHook(
            ctx.global_batch_size,
            output_dir=self.model_dir,
            every_n_steps=self._log_every_n_steps)

        if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
          logging.info('Running %s on CPU', mode)
          estimator_spec = model_fn_wrapper.call_without_tpu(
              features, labels, is_export_mode=is_export_mode)
          estimator_spec = estimator_spec._replace(
              training_hooks=estimator_spec.training_hooks + (examples_hook,))
          return estimator_spec

        assert labels is None, '`labels` passed to `model_fn` must be `None`.'
        # TPUEstimator._call_input_fn passes `input_fn` as features to here.
        assert callable(features), '`input_fn` is not callable.'
        input_fn = features

        input_holders = _InputPipeline(input_fn, batch_axis, ctx)
        enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())

        graph = ops.get_default_graph()
        for enqueue_op in enqueue_ops:
          if isinstance(enqueue_op, list):
            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
          else:
            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)

        if mode == model_fn_lib.ModeKeys.TRAIN:
          loss, host_call, scaffold = (
              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
          host_ops = host_call.create_tpu_hostcall()
          if host_ops is None:
            host_ops = []

          shutdown_hooks = []
          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
                                         'shutdown_worker')
          if shutdown_mode:
            if shutdown_mode == 'shutdown_worker':
              finalizer_hooks = [
                  session_support.ShutdownLameWorkers(timeout_ms=60*1000),
              ]
            elif shutdown_mode == 'shutdown_computation':
              finalizer_hooks = [
                  session_support.RestartComputation(timeout_ms=60*1000),
              ]
            else:
              raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' %
                               shutdown_mode)

            shutdown_hooks.append(session_support.GracefulShutdownHook(
                checkpoint_prefix=self.model_dir + '/model.ckpt',
                on_shutdown_hooks=finalizer_hooks
            ))

          with ops.control_dependencies([loss]):
            global_step = array_ops.identity(training.get_global_step())
          hooks = input_hooks + shutdown_hooks
          logging_hook_frequency = (    # Divide and round up
              (self._log_every_n_steps +
               self._config.tpu_config.iterations_per_loop - 1) //
              self._config.tpu_config.iterations_per_loop)
          hooks.extend([
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  host_ops,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode],
              ),
              InstallSignalHandlerHook(),
              training.LoggingTensorHook(
                  {
                      'loss': array_ops.identity(loss),
                      'step': global_step,
                  },
                  every_n_iter=logging_hook_frequency)
          ])
          examples_hook._set_steps_per_run(   # pylint: disable=protected-access
              self._config.tpu_config.iterations_per_loop)
          hooks.append(examples_hook)

          chief_hooks = []
          if (self._config.save_checkpoints_secs or
              self._config.save_checkpoints_steps):
            checkpoint_hook = training.CheckpointSaverHook(
                self.model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=scaffold)
            checkpoint_hook._set_steps_per_run(   # pylint: disable=protected-access
                self._config.tpu_config.iterations_per_loop)
            chief_hooks.append(checkpoint_hook)
          summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
          with ops.control_dependencies([loss]):
            update_ops = _sync_variables_ops()

          # Validate the TPU training graph to catch basic errors
          _validate_tpu_training_graph()

          train_op = control_flow_ops.group(*update_ops)
          graph.add_to_collection(_TPU_TRAIN_OP, train_op)

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=loss,
              training_chief_hooks=chief_hooks,
              training_hooks=hooks,
              train_op=train_op,
              scaffold=scaffold)

        if mode == model_fn_lib.ModeKeys.EVAL:
          total_loss, host_calls, scaffold = _eval_on_tpu_system(
              ctx, model_fn_wrapper, dequeue_fn)
          iterations_per_loop_var = _create_or_get_iterations_per_loop()
          mean_loss = math_ops.div(total_loss,
                                   math_ops.cast(
                                       iterations_per_loop_var,
                                       dtype=total_loss.dtype))

          # Creates a dummy metric update_op for all metrics. Estimator expects
          # all metrics in eval_metric_ops have update_op and calls them one by
          # one. The real metric update_ops are invoked in a separated thread.
          # So, here give Estimator the dummy op for all metrics.
          with ops.control_dependencies([mean_loss]):
            # After TPU evaluation computation is done (the mean_loss tensor),
            # reads all variables back from TPU and updates the eval step
            # counter properly
            internal_ops_to_run = _sync_variables_ops()
            internal_ops_to_run.append(
                _increase_eval_step_op(iterations_per_loop_var))
            with ops.control_dependencies(internal_ops_to_run):
              dummy_update_op = control_flow_ops.no_op()

          host_call_ret = host_calls.create_tpu_hostcall()
          eval_metric_ops = {}
          eval_update_ops = []

          for k, v in host_call_ret.get('eval_metrics', {}).items():
            eval_metric_ops[k] = (v[0], dummy_update_op)
            eval_update_ops.append(v[1])

          if 'host_call' not in host_call_ret:
            host_ops = []
          else:
            host_ops = host_call_ret['host_call']
          hooks = [
              TPUInfeedOutfeedSessionHook(
                  ctx,
                  enqueue_ops,
                  eval_update_ops + host_ops,
                  run_infeed_loop_on_coordinator=(
                      run_infeed_loop_on_coordinator),
                  rendezvous=self._rendezvous[mode]),
          ] + input_hooks

          return model_fn_lib.EstimatorSpec(
              mode,
              loss=mean_loss,
              evaluation_hooks=hooks,
              eval_metric_ops=eval_metric_ops,
              scaffold=scaffold)

        # Predict
        assert mode == model_fn_lib.ModeKeys.PREDICT

        dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system(
            ctx, model_fn_wrapper, dequeue_fn)
        with ops.control_dependencies([dummy_predict_op]):
          internal_ops_to_run = _sync_variables_ops()
          with ops.control_dependencies(internal_ops_to_run):
            dummy_predict_op = control_flow_ops.no_op()

        # In train and evaluation, the main TPU program is passed to monitored
        # training session to run. Infeed enqueue and outfeed dequeue are
        # executed in side threads. This is not the configuration for
        # prediction mode.
        #
        # For prediction, the Estimator executes the EstimatorSpec.predictions
        # directly and yield the element (via generator) to call site. So, the
        # outfeed based prediction must be passed to MonitoredSession directly.
        # Other parts of the TPU execution are organized as follows.
        #
        # 1. All outfeed based Tensors must be grouped with predictions Tensors
        #    to form a single invocation. This avoid the issue we might trigger
        #    multiple outfeeds incorrectly. To achieve this, `host_call` is
        #    placed in control_dependencies of `stopping_signals`, and
        #    `stopping_signals` is passed into _StoppingPredictHook, which sets
        #    the `stopping_signals` as SessionRunArgs. MonitoredSession merges
        #    all SessionRunArgs with the fetch in session.run together.
        #
        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue)
        #    are grouped together. They will be launched once and only once in
        #    side threads and they quit naturally according to the SAME stopping
        #    condition.
        enqueue_ops.append(dummy_predict_op)

        host_call_ret = host_calls.create_tpu_hostcall()
        if 'host_call' not in host_call_ret:
          host_ops = []
        else:
          host_ops = host_call_ret['host_call']

        predictions = host_call_ret['predictions']
        _verify_cross_hosts_transfer_size(
            predictions, message=(
                'The estimated size for TPUEstimatorSpec.predictions is too '
                'large.'))
        signals = host_call_ret['signals']

        with ops.control_dependencies(host_ops):
          host_ops = []  # Empty, we do do not need it anymore.
          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
              signals)
          predictions = _PaddingSignals.slice_tensor_or_dict(
              predictions, signals)

        hooks = [
            _StoppingPredictHook(scalar_stopping_signal),
            TPUInfeedOutfeedSessionHookForPrediction(
                ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]),
        ] + input_hooks

        return model_fn_lib.EstimatorSpec(
            mode,
            prediction_hooks=hooks,
            predictions=predictions,
            scaffold=scaffold)

    return _model_fn


def _is_tpu_tensor(tensor):
  """Tells whether `tensor` is a `Tensor` without the outside-compilation attr.

  Args:
    tensor: any object; anything that is not an `ops.Tensor` yields `False`.

  Returns:
    `True` when `tensor` is an `ops.Tensor` and looking up the outside
    compilation attribute on its producing op raises `ValueError`; `False`
    in every other case.
  """
  if isinstance(tensor, ops.Tensor):
    try:
      tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR)  # pylint: disable=protected-access
    except ValueError:
      # The attribute lookup failed, so the op is not tagged for outside
      # compilation.
      return True
  return False


def _export_output_to_tensors(export_output):
  """Get a list of `Tensors` used in `export_output`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
            `RegressionOutput`, or `PredictOutput`.
  Returns:
    a list of tensors used in export_output. For `PredictOutput` the list
    follows the iteration order of `export_output.outputs`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
        `RegressionOutput`, or `PredictOutput`.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    return [export_output.scores, export_output.classes]
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    return [export_output.value]
  elif isinstance(export_output, export_output_lib.PredictOutput):
    # On Python 3 `dict.values()` is a view object, not a list. Materialize it
    # so every branch returns a real list, as the docstring promises, and so
    # callers can index it or treat it as a sequence.
    return list(export_output.outputs.values())
  else:
    raise ValueError(
        '`export_output` must be have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))


def _clone_export_output_with_tensors(export_output, tensors):
  """Clones `export_output` but with new `tensors`.

  Args:
    export_output: an `ExportOutput` object such as `ClassificationOutput`,
            `RegressionOutput`, or `PredictOutput`.
    tensors: a list of `Tensors` used to construct a new `export_output`. For
      `PredictOutput` they are paired, in order, with the keys of
      `export_output.outputs`.

  Returns:
    A new `ExportOutput` of the same class as `export_output`, built from
    `tensors`.

  Raises:
    ValueError: if `export_output` is not one of `ClassificationOutput`,
        `RegressionOutput`, or `PredictOutput`, or if the number of `tensors`
        does not match what that output type needs.
  """
  if isinstance(export_output, export_output_lib.ClassificationOutput):
    if len(tensors) != 2:
      raise ValueError('tensors must be of length 2; '
                       'got {}.'.format(len(tensors)))
    return export_output_lib.ClassificationOutput(*tensors)
  elif isinstance(export_output, export_output_lib.RegressionOutput):
    if len(tensors) != 1:
      raise ValueError('tensors must be of length 1; '
                       'got {}'.format(len(tensors)))
    return export_output_lib.RegressionOutput(*tensors)
  elif isinstance(export_output, export_output_lib.PredictOutput):
    output_keys = list(export_output.outputs.keys())
    tensors = list(tensors)
    # `zip` stops at the shorter input, so a mismatch would silently drop
    # outputs or tensors. Validate the length like the other branches do.
    if len(tensors) != len(output_keys):
      raise ValueError('tensors must be of length {}; '
                       'got {}.'.format(len(output_keys), len(tensors)))
    return export_output_lib.PredictOutput(dict(zip(output_keys, tensors)))
  else:
    raise ValueError(
        '`export_output` must be have type `ClassificationOutput`, '
        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))


def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards.

  Args:
    ctx: context object supplying `num_replicas` and `device_assignment`.
    model_fn_wrapper: object whose `convert_to_single_tpu_eval_step` yields
      the per-step function, the host calls and the captured scaffold fn.
    dequeue_fn: forwarded to `convert_to_single_tpu_eval_step`.

  Returns:
    A `(loss, host_calls, scaffold)` tuple.
  """
  loop_count_var = _create_or_get_iterations_per_loop()

  eval_step, host_calls, captured_scaffold_fn = (
      model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))

  def multi_tpu_eval_steps_on_single_shard():
    # The loop is seeded with `_ZERO_LOSS` as its single input.
    return training_loop.repeat(loop_count_var, eval_step, [_ZERO_LOSS])

  shard_outputs = tpu.shard(
      multi_tpu_eval_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  # Exactly one output (the loss) is expected; anything else fails to unpack.
  (loss,) = shard_outputs

  return loss, host_calls, _get_scaffold(captured_scaffold_fn)


def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards.

  Args:
    ctx: context object supplying `num_replicas` and `device_assignment`.
    model_fn_wrapper: object whose `convert_to_single_tpu_train_step` yields
      the per-step function, the host call and the captured scaffold fn.
    dequeue_fn: forwarded to `convert_to_single_tpu_train_step`.

  Returns:
    A `(loss, host_call, scaffold)` tuple.
  """
  loop_count_var = _create_or_get_iterations_per_loop()

  train_step, host_call, captured_scaffold_fn = (
      model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))

  def multi_tpu_train_steps_on_single_shard():
    # The loop is seeded with `_INITIAL_LOSS` as its single input.
    return training_loop.repeat(loop_count_var, train_step, [_INITIAL_LOSS])

  shard_outputs = tpu.shard(
      multi_tpu_train_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  # Exactly one output (the loss) is expected; anything else fails to unpack.
  (loss,) = shard_outputs

  return loss, host_call, _get_scaffold(captured_scaffold_fn)


def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards.

  The per-shard loop keeps invoking the predict step until the scalar signal
  it carries reports that prediction should stop.

  Args:
    ctx: context object supplying `num_cores`.
    model_fn_wrapper: object whose `convert_to_single_tpu_predict_step` yields
      the per-step function, the host calls and the captured scaffold fn.
    dequeue_fn: forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A `(dummy_predict_op, host_calls, scaffold)` tuple.
  """
  shard_count = ctx.num_cores

  predict_step, host_calls, captured_scaffold_fn = (
      model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn))

  def multi_tpu_predict_steps_on_single_shard():

    def _keep_predicting(scalar_stopping_signal):
      stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
      return math_ops.logical_not(stop_requested)

    # The loop starts from the non-stopping signal.
    return training_loop.while_loop(
        _keep_predicting,
        predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')

  shard_outputs = tpu.shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=shard_count,
      outputs_from_all_shards=False)
  # Exactly one output is expected; anything else fails to unpack.
  (dummy_predict_op,) = shard_outputs

  return dummy_predict_op, host_calls, _get_scaffold(captured_scaffold_fn)


def _wrap_computation_in_while_loop(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  Args:
    device: device spec passed to `ops.device` for the loop.
    op_fn: callable returning the ops each iteration must run.

  Returns:
    The result of `control_flow_ops.while_loop`, whose single loop variable
    is a counter starting at 0.
  """

  def _run_once(step):
    # The counter only advances after the ops from `op_fn` have run.
    with ops.control_dependencies(op_fn()):
      return step + 1

  loop_count_var = _create_or_get_iterations_per_loop()
  with ops.device(device):
    loop_limit = array_ops.identity(loop_count_var)

    def _below_limit(step):
      return step < loop_limit

    # By setting parallel_iterations=1, the parallel execution in while_loop
    # is basically turned off.
    return control_flow_ops.while_loop(
        _below_limit,
        _run_once, [constant_op.constant(0)],
        parallel_iterations=1)


def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  Args:
    device: device spec passed to `ops.device` for the loop.
    op_fn: callable returning a dict with an 'ops' entry (ops to run in each
      iteration) and a 'signals' entry (source of the next stopping signal).

  Returns:
    The result of `control_flow_ops.while_loop`, whose single loop variable
    is the scalar stopping signal.
  """

  def _should_continue(scalar_stopping_signal):
    stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(stop_requested)

  def _run_once(unused_scalar_stopping_signal):
    step_outputs = op_fn()
    ops_to_run = step_outputs['ops']
    step_signals = step_outputs['signals']
    # The next signal is only produced after this iteration's ops have run.
    with ops.control_dependencies(ops_to_run):
      return _StopSignals.as_scalar_stopping_signal(step_signals)

  with ops.device(device):
    # By setting parallel_iterations=1, the parallel execution in while_loop
    # is basically turned off.
    return control_flow_ops.while_loop(
        _should_continue,
        _run_once, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)


def _validate_tpu_training_graph():
  """Validate graph before running distributed training.

  Scans the operations of the default graph for one whose type equals
  `_CROSS_REPLICA_SUM_OP`.

  Raises:
    ValueError: If the graph seems invalid for running on device, i.e. no
      such operation exists.
  """
  graph_operations = ops.get_default_graph().get_operations()

  # At least one CrossReplicaSum operation must be present in the graph.
  # Using the CrossShardOptimizer wrapper should introduce it.
  has_cross_replica_sum = any(
      operation.type == _CROSS_REPLICA_SUM_OP
      for operation in graph_operations)
  if not has_cross_replica_sum:
    raise ValueError(
        'CrossShardOptimizer must be used for model training on TPUs.')


class _CapturedObject(object):
  """A write-once holder for a Python object.

  This is useful when we need to capture a Python object in the Tensorflow
  control flow body function and use it outside the control flow.
  """

  def __init__(self):
    self._object = None
    self._captured = False

  def capture(self, o):
    """Stores `o`; raises `RuntimeError` on any call after the first."""
    if not self._captured:
      self._captured = True
      self._object = o
      return
    raise RuntimeError(
        'InternalError: Object can capture only once. Please file bug.')

  def get(self):
    """Returns the stored object; raises `RuntimeError` if none was stored."""
    if self._captured:
      return self._object
    raise RuntimeError(
        'InternalError: Object is not captured properly before `get`. '
        'Please file bug.')


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`.

  Both the call to the scaffold fn and every later call to the returned
  scaffold's `finalize` run inside a `_CapturingContext`.

  Args:
    captured_scaffold_fn: object whose `get()` returns the scaffold fn, or a
      falsy value when there is none.

  Returns:
    The scaffold produced by the scaffold fn, with its `finalize` attribute
    replaced by a wrapper, or `None` when there is no scaffold fn.

  Raises:
    ValueError: if the scaffold fn returns `None`.
  """
  with _CapturingContext(message='Inside scaffold_fn'):
    scaffold_fn = captured_scaffold_fn.get()
    if scaffold_fn:
      scaffold = scaffold_fn()
      if scaffold is None:
        raise ValueError(
            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
    else:
      scaffold = None

  if scaffold:
    wrapped_finalize = scaffold.finalize

    def _finalize():
      with _CapturingContext(message='Inside Scaffold.finalize'):
        # Hand back whatever the wrapped method returns so the wrapper stays
        # transparent to callers that use the result of `finalize()`.
        return wrapped_finalize()

    scaffold.finalize = _finalize
  return scaffold


class _CapturingContext(control_flow_ops.ControlFlowContext):
  """Tracks references to Tensors defined in TPU replication.

  Used as a context manager: on entry it installs itself as the default
  graph's control flow context and on exit it restores the previous one.
  """

  def __init__(self, message):
    super(_CapturingContext, self).__init__()
    # Prefix for the error raised by `AddOp`.
    self._message = message

  def AddOp(self, op):  # pylint: disable=invalid-name
    """Rejects `op` if any input comes from an op with the TPU replicate attr."""
    replicate_attr = tpu._TPU_REPLICATE_ATTR  # pylint: disable=protected-access
    for input_tensor in op.inputs:
      if replicate_attr not in input_tensor.op.node_def.attr:
        continue
      raise ValueError('{}: Op {} depends on TPU computation {}, '
                       'which is not allowed.'.format(
                           self._message, op, input_tensor))

  def __enter__(self):
    # pylint: disable=protected-access
    graph = ops.get_default_graph()
    self._g = graph
    # Remember the current context so `__exit__` can put it back.
    self._old = graph._get_control_flow_context()
    graph._set_control_flow_context(self)
    # pylint: enable=protected-access

  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
    previous_context = self._old
    self._g._set_control_flow_context(previous_context)  # pylint: disable=protected-access


class _Inputs(object):
  """Holds what an input_fn returned.

  The value is either a (features, labels) pair, optionally with signals, or
  a `Dataset`; the two forms are mutually exclusive.
  """

  def __init__(self, features=None, labels=None, dataset=None, signals=None):
    has_tensor_inputs = any(
        item is not None for item in (features, labels, signals))
    if dataset is not None and has_tensor_inputs:
      raise RuntimeError('Internal Error: Either (features and labels) or '
                         'dataset should be provided, not both. Please file '
                         'bug')

    self._dataset = dataset
    # Set by `dataset_initializer_hook`.
    self._iterator = None

    self._features = features
    self._labels = labels
    self._signals = signals

  @staticmethod
  def from_input_fn(return_values):
    """Returns an `_Inputs` instance according to `input_fn` return value."""
    if isinstance(return_values, dataset_ops.Dataset):
      return _Inputs(dataset=return_values)

    return _Inputs(*_Inputs._parse_inputs(return_values))

  @staticmethod
  def _parse_inputs(return_values):
    """Splits `return_values` into (features, labels).

    A tuple is unpacked as exactly two items; anything else is treated as
    features with `None` labels.
    """
    if not isinstance(return_values, tuple):
      return return_values, None
    features, labels = return_values
    return features, labels

  @property
  def is_dataset(self):
    """Returns True if the return value from input_fn is Dataset."""
    return self._dataset is not None

  def dataset_initializer_hook(self):
    """Returns a `SessionRunHook` to initialize this dataset.

    This must be called before `features_and_labels`.
    """
    self._iterator = self._dataset.make_initializable_iterator()
    # pylint: disable=protected-access
    return estimator_util._DatasetInitializerHook(self._iterator)
    # pylint: enable=protected-access

  def features_and_labels(self):
    """Gets `features` and `labels`."""
    if not self.is_dataset:
      return (self._features, self._labels)

    if self._iterator is None:
      raise RuntimeError('Internal error: Must call dataset_initializer_hook '
                         'before calling features_and_labels(). Please file '
                         'a bug!')
    next_element = self._iterator.get_next()
    return _Inputs._parse_inputs(next_element)

  def signals(self):
    """Returns the signals given at construction time."""
    return self._signals

  @property
  def dataset(self):
    """Returns the dataset given at construction time, or `None`."""
    return self._dataset


class _InputsWithStoppingSignals(_Inputs):
  """Inputs with `_StopSignals` inserted into the dataset.

  Every element of the wrapped dataset is tagged with a non-stopping signal,
  and one extra element, built from the first element of the dataset and
  tagged as stopping, is appended at the end.
  """

  def __init__(self, dataset, batch_size, add_padding=False):

    assert dataset is not None

    def _signal_mapper(stop):
      return _InputsWithStoppingSignals.insert_stopping_signal(
          stop=stop, batch_size=batch_size, add_padding=add_padding)

    regular_batches = dataset.map(_signal_mapper(False))
    stopping_batch = dataset.take(1).map(_signal_mapper(True))
    combined = regular_batches.concatenate(stopping_batch).prefetch(2)

    super(_InputsWithStoppingSignals, self).__init__(dataset=combined)
    # Element handed out by `features_and_labels` whose signals have not been
    # read back through `signals` yet.
    self._current_inputs = None

  def features_and_labels(self):
    """Fetches the next element and returns its (features, labels)."""
    if self._current_inputs is not None:
      raise RuntimeError(
          'Internal Error: The previous inputs have not been properly '
          'consumed. First call features_and_labels, then call signals.')

    next_element = self._iterator.get_next()
    features = next_element['features']
    labels = next_element.get('labels')

    self._current_inputs = next_element
    return features, labels

  def signals(self):
    """Returns the `Signals` from `_Inputs`."""
    pending = self._current_inputs
    if pending is None:
      raise RuntimeError(
          'Internal Error: The current inputs have not been properly '
          'generated. First call features_and_labels, then call signals.')
    current_signals = pending['signals']
    self._current_inputs = None
    return current_signals

  @staticmethod
  def insert_stopping_signal(stop, batch_size, add_padding=False):
    """Inserts stopping_signal into dataset via _map_fn.

    Here we change the data structure in the dataset, such that the return value
    is a dictionary now and `features`, `labels`, and `signals` are three
    distinguished keys in that dict. This provides a better structure, which
    eases the process to decompose the inputs (see `features_and_labels`).

    Args:
      stop: bool, state of current stopping signals.
      batch_size: int, batch size.
      add_padding: bool, whether to pad the tensor to full batch size.

    Returns:
      A map_fn passed to dataset.map API.
    """

    def _map_fn(*args):
      """The map fn to insert signals."""
      # A lone argument is the features Tensor/dict itself. This is required
      # for the input_fn returns no labels.
      parsed_args = args[0] if len(args) == 1 else args
      features, labels = _Inputs._parse_inputs(parsed_args)

      padding_mask = None
      if add_padding:
        padding_mask, features, labels = (
            _PaddingSignals.pad_features_and_labels(
                features, labels, batch_size))

      new_input_dict = {'features': features}
      if labels is not None:
        new_input_dict['labels'] = labels

      stop_signals = _StopSignals(
          stop=stop, batch_size=batch_size, padding_mask=padding_mask)
      new_input_dict['signals'] = stop_signals.as_dict()

      return new_input_dict

    return _map_fn


class _StopSignals(object):
  """Signals class holding all logic to handle TPU stopping condition."""

  NON_STOPPING_SIGNAL = False
  STOPPING_SIGNAL = True

  def __init__(self, stop, batch_size, padding_mask=None):
    self._stop = stop
    self._batch_size = batch_size
    self._padding_mask = padding_mask

  def as_dict(self):
    """Returns the signals as Python dict.

    The 'stopping' entry is a bool tensor of shape [batch_size, 1], all ones
    when stopping and all zeros otherwise. A 'padding_mask' entry is added
    only when a mask was given.
    """
    fill_fn = array_ops.ones if self._stop else array_ops.zeros
    signals = {
        'stopping': fill_fn(shape=[self._batch_size, 1], dtype=dtypes.bool)
    }
    if self._padding_mask is not None:
      signals['padding_mask'] = self._padding_mask
    return signals

  @staticmethod
  def as_scalar_stopping_signal(signals):
    """Returns element [0][0] of `signals['stopping']` via an identity op."""
    first_entry = signals['stopping'][0][0]
    return array_ops.identity(first_entry)

  @staticmethod
  def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if not isinstance(scalar_stopping_signal, ops.Tensor):
      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
      # the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal)

    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
    # way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(
        scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL)


class _PaddingSignals(object):
  """Signals class holding all logic to handle padding.

  `pad_features_and_labels` pads a batch up to a fixed `batch_size` and
  produces a mask marking the padded rows; `slice_tensor_or_dict` uses that
  mask to cut the padded rows off again.
  """

  @staticmethod
  def pad_features_and_labels(features, labels, batch_size):
    """Pads out the batch dimension of features and labels.

    Args:
      features: a Tensor or nested structure of Tensors; the first Tensor
        found in it supplies the real batch size.
      labels: a Tensor or nested structure of Tensors, or `None`.
      batch_size: the batch size to pad to.

    Returns:
      A `(padding_mask, features, labels)` tuple, where `features` and
      `labels` are the padded structures (`labels` stays `None` if it was
      `None`).
    """
    # Dynamic size of the first dimension of some Tensor in `features`.
    real_batch_size = array_ops.shape(
        _PaddingSignals._find_any_tensor(features))[0]

    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)

    check_greater = check_ops.assert_greater_equal(
        batch_size_tensor, real_batch_size,
        data=(batch_size_tensor, real_batch_size),
        message='The real batch size should not be greater than batch_size.')

    # The subtraction is gated on the assertion above.
    with ops.control_dependencies([check_greater]):
      missing_count = batch_size_tensor - real_batch_size

    def pad_single_tensor(tensor):
      """Pads out the batch dimension of a tensor to the complete batch_size."""
      rank = len(tensor.shape)
      assert rank > 0
      # Padding spec: [0, missing_count] for the first dimension and [0, 0]
      # for each of the remaining ones.
      padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
      padded_tensor = array_ops.pad(tensor, padding)
      # Record the static shape, whose first dimension is now `batch_size`.
      padded_tensor.set_shape(padded_shape)
      return padded_tensor

    def nest_pad(tensor_or_dict):
      """Applies `pad_single_tensor` to every Tensor in the structure."""
      return nest.map_structure(pad_single_tensor, tensor_or_dict)

    features = nest_pad(features)
    if labels is not None:
      labels = nest_pad(labels)

    padding_mask = _PaddingSignals._padding_mask(
        real_batch_size, missing_count, batch_size)

    return padding_mask, features, labels

  @staticmethod
  def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slice the real Tensors according to padding mask in signals.

    Args:
      tensor_or_dict: a Tensor or nested structure of Tensors to slice.
      signals: dict holding a 'padding_mask' entry and the entries read by
        `_StopSignals.as_scalar_stopping_signal`.

    Returns:
      A structure like `tensor_or_dict`. Each Tensor is returned whole when
      the batch has no padding or carries the stopping signal; otherwise it
      is cut to its first `batch_size - sum(padding_mask)` rows.
    """

    padding_mask = signals['padding_mask']
    batch_size = array_ops.shape(padding_mask)[0]

    def verify_batch_size(tensor):
      """Returns an identity of `tensor` gated on a batch-size comparison."""
      # NOTE(review): `math_ops.equal` looks like it only yields a boolean
      # tensor rather than failing on mismatch -- confirm whether an
      # assertion op such as `check_ops.assert_equal` was intended.
      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
      with ops.control_dependencies([check_batch_size]):
        return array_ops.identity(tensor)

    def slice_single_tensor(tensor):
      """Keeps the first `batch_size - sum(padding_mask)` rows of `tensor`."""
      rank = len(tensor.shape)
      assert rank > 0
      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
      return verify_batch_size(tensor)[0:real_batch_size]

    # As we split the Tensors to all TPU cores and concat them back, it is
    # important to ensure the real data is placed before padded ones, i.e.,
    # order is preserved. By that, the sliced padding mask should have all 0's.
    # If this assertion failed, the slice logic here would not hold.
    # NOTE(review): as in `verify_batch_size`, this is built with
    # `math_ops.equal`, so it may not actually fail at run time -- confirm.
    sliced_padding_mask = slice_single_tensor(padding_mask)
    assert_padding_mask = math_ops.equal(
        math_ops.reduce_sum(sliced_padding_mask), 0)

    with ops.control_dependencies([assert_padding_mask]):
      should_stop = _StopSignals.should_stop(
          _StopSignals.as_scalar_stopping_signal(signals))

    # A mask that sums to zero means no row of the batch is padding.
    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)

    def slice_fn(tensor):
      """Slices `tensor` unless the batch is full or is the stopping batch."""
      # If the current batch is full batch or part of stopping signals, we do
      # not need to slice to save performance.
      return control_flow_ops.cond(
          math_ops.logical_or(should_stop, is_full_batch),
          (lambda: verify_batch_size(tensor)),
          (lambda: slice_single_tensor(tensor)))

    return nest.map_structure(slice_fn, tensor_or_dict)

  @staticmethod
  def _find_any_tensor(batch_features):
    """Returns the first `ops.Tensor` in the flattened `batch_features`.

    Raises:
      ValueError: if the structure contains no `ops.Tensor`.
    """
    tensors = [x for x in nest.flatten(batch_features)
               if isinstance(x, ops.Tensor)]
    if not tensors:
      raise ValueError('Cannot find any Tensor in features dict.')
    return tensors[0]

  @staticmethod
  def _padding_mask(real_batch_size, missing_count, batch_size):
    """Builds an int32 mask: zeros for real rows followed by ones for padding.

    Args:
      real_batch_size: number of leading zero entries.
      missing_count: number of trailing one entries.
      batch_size: static length recorded on the resulting mask.

    Returns:
      The concatenated mask with its static shape set to `(batch_size,)`.
    """
    padding_mask = array_ops.concat(
        [
            array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
            array_ops.ones((missing_count,), dtype=dtypes.int32)
        ],
        axis=0)
    padding_mask.set_shape((batch_size,))
    return padding_mask


class _SignalsHelper(object):
  """A general helper class to handle common signals manipulation.

  Signals are flattened to, and rebuilt from, a list that follows the sorted
  order of the signal keys.
  """

  def __init__(self, signals):
    # Keys of `signals` in sorted order.
    self._signal_keys = sorted(signals.keys())

  @property
  def num_signals(self):
    """Number of signal keys recorded at construction time."""
    return len(self._signal_keys)

  def unflatten(self, tensor_list):
    """Pairs the recorded keys, in sorted order, with `tensor_list` items."""
    return {key: tensor for key, tensor in zip(self._signal_keys, tensor_list)}

  @staticmethod
  def as_tensor_list(signals):
    """Returns the values of `signals` ordered by sorted key."""
    ordered_keys = sorted(signals.keys())
    return [signals[name] for name in ordered_keys]


def _verify_cross_hosts_transfer_size(tensor_dict, message):
  """Raises if the summed size of the tensors in `tensor_dict` is too large.

  The size of each tensor is the product of its `shape` entries multiplied by
  `dtype.size`; the sum over all tensors is compared with `_ONE_GIGABYTE`.

  Args:
    tensor_dict: dict mapping keys to tensors exposing `shape` and `dtype`.
    message: text placed at the start of the error message.

  Raises:
    ValueError: if the summed size is at least `_ONE_GIGABYTE`.
  """
  total_size = 0
  tensor_structure = {}
  for key, tensor in tensor_dict.items():
    shape = tensor.shape
    # `np.prod` is the supported spelling; `np.product` was only a deprecated
    # alias of it and no longer exists in NumPy 2.0.
    size = np.prod(shape) * tensor.dtype.size
    tensor_structure[key] = shape
    total_size += size
  if total_size >= _ONE_GIGABYTE:
    raise ValueError(
        '{} The transfer size is larger than the protobuf limit. Please '
        'consider to use Tensors with smaller shapes or reduce batch '
        'size. Given:\n'
        '{}'.format(message, '\n'.join([
            ' -- Key: {}, Shape: {}'.format(k, v)
            for k, v in tensor_structure.items()])))


def _add_item_to_params(params, key, value):
  """Adds a new item into `params`.

  Args:
    params: an `hparam.HParams` instance, or an object supporting item
      assignment such as a Python dict.
    key: name of the item.
    value: value stored under `key`; an existing entry is overwritten.
  """
  if not isinstance(params, hparam.HParams):
    # Now params is Python dict.
    params[key] = value
    return

  # For HParams, we need to use special API: existing keys are updated and
  # new keys are added through different methods.
  store = params.set_hparam if key in params else params.add_hparam
  store(key, value)


def export_estimator_savedmodel(estimator,
                                export_dir_base,
                                serving_input_receiver_fn,
                                assets_extra=None,
                                as_text=False,
                                checkpoint_path=None,
                                strip_default_attrs=False):
  """Export `Estimator` trained model for TPU inference.

  A `TPUEstimator` is built around the model function, params and model
  directory of `estimator`, and the export is delegated to it.

  Args:
    estimator: `Estimator` with which model has been trained.
    export_dir_base: A string containing a directory in which to create
      timestamped subdirectories containing exported SavedModels.
    serving_input_receiver_fn: A function that takes no argument and
      returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel, or `None` if no extra assets are needed.
    as_text: whether to write the SavedModel proto in text format.
    checkpoint_path: The checkpoint path to export.  If `None` (the default),
      the most recent checkpoint found within the model directory is chosen.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs.

  Returns:
    The string path to the exported directory.
  """
  model_fn = estimator._model_fn  # pylint: disable=protected-access
  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
  # `estimator.config`.
  tpu_run_config = tpu_config.RunConfig(model_dir=estimator.model_dir)
  tpu_estimator = TPUEstimator(
      model_fn,
      config=tpu_run_config,
      params=estimator.params,
      use_tpu=True,
      train_batch_size=2048,  # Does not matter.
      eval_batch_size=2048,  # Does not matter.
  )
  export_args = (
      export_dir_base,
      serving_input_receiver_fn,
      assets_extra,
      as_text,
      checkpoint_path,
      strip_default_attrs,
  )
  return tpu_estimator.export_savedmodel(*export_args)