# Source: tensorflow/python/data/ops/dataset_ops.py
# (blob cdb883cac941acb3b43ca5168bd4621884f91299)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import threading
import warnings

import numpy as np
import six

from tensorflow.python.compat import compat
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed as core_random_seed
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.Dataset")
class Dataset(object):
  """Represents a potentially large set of elements.

  A `Dataset` can be used to represent an input pipeline as a
  collection of elements (nested structures of tensors) and a "logical
  plan" of transformations that act on those elements.
  """
  __metaclass__ = abc.ABCMeta

  def __init__(self):
    pass

  def _as_serialized_graph(self):
    """Serializes the graph that defines this dataset.

    The variant tensor obtained from `_as_variant_tensor()` is passed to the
    `dataset_to_graph` op, whose output is returned.

    Returns:
      A scalar `tf.Tensor` of `tf.string` type holding this dataset's
      serialized graph.
    """
    variant_tensor = self._as_variant_tensor()
    return gen_dataset_ops.dataset_to_graph(variant_tensor)

  @abc.abstractmethod
  def _as_variant_tensor(self):
    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.

    Returns:
      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
    """
    raise NotImplementedError("Dataset._as_variant_tensor")

  @abc.abstractmethod
  def _inputs(self):
    """Returns a list of the input datasets of the dataset."""

    raise NotImplementedError("Dataset._inputs")

  def options(self):
    """Returns the options for this dataset.

    Returns:
      A `tf.data.Options` object representing the dataset options.
    """
    for input_dataset in self._inputs():
      options = input_dataset.options()
      if options is not None:
        return options
    return Options()

  def make_initializable_iterator(self, shared_name=None):
    """Creates an `Iterator` for enumerating the elements of this dataset.

    Note: The returned iterator will be in an uninitialized state,
    and you must run the `iterator.initializer` operation before using it:

    ```python
    dataset = ...
    iterator = dataset.make_initializable_iterator()
    # ...
    sess.run(iterator.initializer)
    ```

    The dataset that the initializer feeds into the iterator is `self`,
    wrapped according to `self.options()`. An `_OptimizeDataset` is added
    when static optimizations are requested, and a `_ModelDataset` is added
    when `experimental_autotune` is set.

    Args:
      shared_name: (Optional.) If non-empty, the returned iterator will be
        shared under the given name across multiple sessions that share the
        same devices (e.g. when using a remote server).

    Returns:
      An `Iterator` over the elements of this dataset.

    Raises:
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "dataset.make_initializable_iterator is not supported when eager "
          "execution is enabled.")

    opts = self.options()
    wrapped = self
    requested_optimizations = opts._static_optimizations()  # pylint: disable=protected-access
    if requested_optimizations:
      wrapped = _OptimizeDataset(wrapped, requested_optimizations)
    if opts.experimental_autotune:
      wrapped = _ModelDataset(wrapped)

    resource_name = "" if shared_name is None else shared_name
    # Choose the iterator-creation op based on the forward-compatibility
    # date check.
    if compat.forward_compatible(2018, 8, 3):
      create_iterator = gen_dataset_ops.iterator_v2
    else:
      create_iterator = gen_dataset_ops.iterator
    iterator_resource = create_iterator(
        container="", shared_name=resource_name, **flat_structure(self))

    # Place the initializer op together with the resource it initializes.
    with ops.colocate_with(iterator_resource):
      initializer = gen_dataset_ops.make_iterator(
          wrapped._as_variant_tensor(),  # pylint: disable=protected-access
          iterator_resource)
    return iterator_ops.Iterator(iterator_resource, initializer,
                                 wrapped.output_types, wrapped.output_shapes,
                                 wrapped.output_classes)

  def __iter__(self):
    """Creates an `Iterator` for enumerating the elements of this dataset.

    This makes a `Dataset` usable in a Python `for` loop. It is available only
    under eager execution.

    Returns:
      An `Iterator` over the elements of this dataset.

    Raises:
      RuntimeError: If eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError("dataset.__iter__() is only supported when eager "
                         "execution is enabled.")
    return iterator_ops.EagerIterator(self)

  def make_one_shot_iterator(self):
    """Creates an `Iterator` for enumerating the elements of this dataset.

    Note: The returned iterator will be initialized automatically.
    A "one-shot" iterator does not currently support re-initialization.

    Under eager execution this simply returns an `EagerIterator` over `self`.
    In graph mode, the dataset is built inside a `Defun` (`_make_dataset`).
    That function applies the same option-driven wrapping as
    `make_initializable_iterator()` (static optimizations, then autotune
    modeling) and is passed as the `dataset_factory` of a `one_shot_iterator`
    op.

    Returns:
      An `Iterator` over the elements of this dataset.

    Raises:
      ValueError: If the dataset factory cannot be added to the default graph.
        If the failure is due to capturing a stateful node (e.g. a `Variable`
        or `LookupTable`), the message advises using
        `make_initializable_iterator()` instead. Other `ValueError`s are
        re-raised.
    """
    if context.executing_eagerly():
      return iterator_ops.EagerIterator(self)

    # Read the current graph-level and op-level seeds here, in the enclosing
    # graph, so the Defun below can reinstate them.
    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)

    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
    # a 0-argument function.
    @function.Defun(capture_by_value=True)
    def _make_dataset():
      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
      # enclosing graph, so if a graph-level seed is present we set the local
      # graph seed based on a combination of the graph- and op-level seeds.
      if graph_level_seed is not None:
        assert op_level_seed is not None
        # The modulus keeps the combined seed within [0, 2**63 - 2].
        core_random_seed.set_random_seed(
            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))

      # Apply the dataset options: static graph optimizations first, then the
      # autotune model wrapper.
      dataset = self
      options = self.options()
      static_optimizations = options._static_optimizations()  # pylint: disable=protected-access
      if static_optimizations:
        dataset = _OptimizeDataset(dataset, static_optimizations)
      if options.experimental_autotune:
        dataset = _ModelDataset(dataset)
      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    try:
      # Adding the function to the graph here surfaces capture errors now,
      # so they can be translated into a more actionable message below.
      _make_dataset.add_to_graph(ops.get_default_graph())
    except ValueError as err:
      if "Cannot capture a stateful node" in str(err):
        raise ValueError(
            "Failed to create a one-shot iterator for a dataset. "
            "`Dataset.make_one_shot_iterator()` does not support datasets that "
            "capture stateful objects, such as a `Variable` or `LookupTable`. "
            "In these cases, use `Dataset.make_initializable_iterator()`. "
            "(Original error: %s)" % err)
      else:
        # Any other ValueError is propagated as-is.
        six.reraise(ValueError, err)

    # `None` is passed as the initializer: per the note above, a one-shot
    # iterator is initialized automatically.
    return iterator_ops.Iterator(
        gen_dataset_ops.one_shot_iterator(
            dataset_factory=_make_dataset, **flat_structure(self)),
        None, self.output_types, self.output_shapes, self.output_classes)

  @abc.abstractproperty
  def output_classes(self):
    """Returns the class of each component of an element of this dataset.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    raise NotImplementedError("Dataset.output_classes")

  @abc.abstractproperty
  def output_shapes(self):
    """Returns the shape of each component of an element of this dataset.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    raise NotImplementedError("Dataset.output_shapes")

  @abc.abstractproperty
  def output_types(self):
    """Returns the type of each component of an element of this dataset.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    raise NotImplementedError("Dataset.output_types")

  def __repr__(self):
    """Summarizes the dataset's class, element shapes and element types."""
    def _unquote(structure):
      # Strip the single quotes that `str()` adds around nested strings.
      return str(structure).replace("'", "")

    shapes_text = _unquote(nest.map_structure(str, self.output_shapes))
    types_text = _unquote(nest.map_structure(repr, self.output_types))
    class_name = type(self).__name__
    return "<%s shapes: %s, types: %s>" % (class_name, shapes_text, types_text)

  @staticmethod
  def from_tensors(tensors):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    If `tensors` contains a NumPy array and eager execution is not enabled,
    the values are embedded in the graph as one or more `tf.constant`
    operations. For large datasets (> 1 GB) this can waste memory and run into
    the byte limits of graph serialization. If `tensors` contains one or more
    large NumPy arrays, consider the alternative described in
    [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).

    Args:
      tensors: A nested structure of tensors.

    Returns:
      Dataset: A `Dataset`.
    """
    single_element_dataset = TensorDataset(tensors)
    return single_element_dataset

  @staticmethod
  def from_tensor_slices(tensors):
    """Creates a `Dataset` whose elements are slices of the given tensors.

    If `tensors` contains a NumPy array and eager execution is not enabled,
    the values are embedded in the graph as one or more `tf.constant`
    operations. For large datasets (> 1 GB) this can waste memory and run into
    the byte limits of graph serialization. If `tensors` contains one or more
    large NumPy arrays, consider the alternative described in
    [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).

    Args:
      tensors: A nested structure of tensors, each having the same size in the
        0th dimension.

    Returns:
      Dataset: A `Dataset`.
    """
    sliced_dataset = TensorSliceDataset(tensors)
    return sliced_dataset

  @staticmethod
  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
  def from_sparse_tensor_slices(sparse_tensor):
    """Splits a rank-N `tf.SparseTensor` row-wise into a `Dataset`.

    Each element of the resulting dataset is one row of `sparse_tensor`,
    represented as a rank-(N-1) sparse tensor.

    Args:
      sparse_tensor: A `tf.SparseTensor`.

    Returns:
      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
    """
    row_dataset = SparseTensorSliceDataset(sparse_tensor)
    return row_dataset

  class _GeneratorState(object):
    """Stores outstanding iterators created from a Python generator.

    This class keeps track of potentially multiple iterators that may have
    been created from a generator, e.g. in the case that the dataset is
    repeated, or nested within a parallel computation.

    Each pass over the generator is identified by an integer ID handed out by
    `get_next_id()`. The arguments for that pass are stashed until the first
    `get_iterator()` call for the ID. That call invokes the generator and
    caches its iterator until `iterator_completed()` releases it.
    """

    def __init__(self, generator):
      self._generator = generator
      self._lock = threading.Lock()
      self._next_id = 0  # GUARDED_BY(self._lock)
      # Arguments awaiting their first `get_iterator()` call, keyed by ID.
      self._args = {}
      # Live iterators, keyed by ID.
      self._iterators = {}

    def get_next_id(self, *args):
      """Reserves a fresh iterator ID and stashes `args` for it.

      Args:
        *args: Values passed to the generator when the iterator for the
          returned ID is first requested.

      Returns:
        A scalar `np.int64` array holding the new ID.
      """
      with self._lock:
        ret = self._next_id
        self._next_id += 1
      self._args[ret] = args
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(ret, dtype=np.int64)

    def get_iterator(self, iterator_id):
      """Returns the iterator for `iterator_id`, creating it on first use.

      Args:
        iterator_id: An ID previously returned by `get_next_id()`.

      Returns:
        The iterator obtained by calling `iter()` on the generator's result.

      Raises:
        KeyError: If `iterator_id` was never issued, or has already been
          released by `iterator_completed()`.
      """
      try:
        return self._iterators[iterator_id]
      except KeyError:
        iterator = iter(self._generator(*self._args.pop(iterator_id)))
        self._iterators[iterator_id] = iterator
        return iterator

    def iterator_completed(self, iterator_id):
      """Releases all state held for `iterator_id`.

      This does not raise if no iterator was ever created for `iterator_id`.
      Any arguments still stashed for the ID are dropped as well, because
      `get_iterator()` is the only other place that removes them and they
      would otherwise be retained indefinitely.

      Args:
        iterator_id: An ID previously returned by `get_next_id()`.
      """
      self._iterators.pop(iterator_id, None)
      self._args.pop(iterator_id, None)

  @staticmethod
  def from_generator(generator, output_types, output_shapes=None, args=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that support the `iter()` protocol (e.g. a generator function).
    The elements generated by `generator` must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    For example:

    ```python
    import itertools

    def gen():
      for i in itertools.count(1):
        yield (i, [1] * i)

    ds = Dataset.from_generator(
        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    value = ds.make_one_shot_iterator().get_next()

    sess.run(value)  # (1, array([1]))
    sess.run(value)  # (2, array([1, 1]))
    ```

    NOTE: The current implementation of `Dataset.from_generator()` uses
    `tf.py_func` and inherits the same constraints. In particular, it
    requires the `Dataset`- and `Iterator`-related operations to be placed
    on a device in the same process as the Python program that called
    `Dataset.from_generator()`. The body of `generator` will not be
    serialized in a `GraphDef`, and you should not use this method if you
    need to serialize your model and restore it in a different environment.

    NOTE: If `generator` depends on mutable global variables or other external
    state, be aware that the runtime may invoke `generator` multiple times
    (in order to support repeating the `Dataset`) and at any time
    between the call to `Dataset.from_generator()` and the production of the
    first element from the generator. Mutating global variables or external
    state can cause undefined behavior, and we recommend that you explicitly
    cache any external state in `generator` before calling
    `Dataset.from_generator()`.

    Args:
      generator: A callable object that returns an object that supports the
        `iter()` protocol. If `args` is not specified, `generator` must take
        no arguments; otherwise it must take as many arguments as there are
        values in `args`.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.
      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
        and passed to `generator` as NumPy-array arguments.

    Returns:
      Dataset: A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if args is None:
      args = ()
    else:
      args = tuple(ops.convert_n_to_tensor(args, name="args"))

    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = Dataset._GeneratorState(generator)

    def get_iterator_id_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The returned `iterator_id` disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.py_func(
          generator_state.get_next_id, args, dtypes.int64, stateful=True)

    def generator_next_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
          the iterator in `generator_state` from which to generate an element.

      Returns:
        A nested structure of tensors representing an element from the iterator.
      """

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        # `next()` raises `StopIteration` when there are no more
        # elements remaining to be generated.
        values = next(generator_state.get_iterator(iterator_id))

        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        try:
          flattened_values = nest.flatten_up_to(output_types, values)
        except (TypeError, ValueError):
          raise TypeError(
              "`generator` yielded an element that did not match the expected "
              "structure. The expected structure was %s, but the yielded "
              "element was %s." % (output_types, values))
        ret_arrays = []
        for ret, dtype in zip(flattened_values, flattened_types):
          try:
            ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                ret, dtype=dtype.as_numpy_dtype))
          except (TypeError, ValueError):
            raise TypeError(
                "`generator` yielded an element that could not be converted to "
                "the expected type. The expected type was %s, but the yielded "
                "element was %s." % (dtype.name, ret))

        # Additional type and shape checking to ensure that the components
        # of the generated element match the `output_types` and `output_shapes`
        # arguments.
        for (ret_array, expected_dtype, expected_shape) in zip(
            ret_arrays, flattened_types, flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                "`generator` yielded an element of type %s where an element "
                "of type %s was expected." % (ret_array.dtype,
                                              expected_dtype.as_numpy_dtype))
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise ValueError(
                "`generator` yielded an element of shape %s where an element "
                "of shape %s was expected." % (ret_array.shape, expected_shape))

        return ret_arrays

      flat_values = script_ops.py_func(
          generator_py_func, [iterator_id_t], flattened_types, stateful=True)

      # The `py_func()` op drops the inferred shapes, so we add them back in
      # here.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)

    def finalize_fn(iterator_id_t):
      """Releases host-side state for the iterator with ID `iterator_id_t`."""

      def finalize_py_func(iterator_id):
        generator_state.iterator_completed(iterator_id)
        # We return a dummy value so that the `finalize_fn` has a valid
        # signature.
        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
        # casting in `py_func()` will create an array of `np.int32` on Windows,
        # leading to a runtime error.
        return np.array(0, dtype=np.int64)

      return script_ops.py_func(
          finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(dummy_arg):
      # The `get_iterator_id_fn` gets a unique ID for the current instance of
      # of the generator.
      # The `generator_next_fn` gets the next element from the iterator with the
      # given ID, and raises StopIteration when that iterator contains no
      # more elements.
      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
                               finalize_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

  @staticmethod
  def range(*args):
    """Creates a `Dataset` of a step-separated range of values.

    For example:

    ```python
    Dataset.range(5) == [0, 1, 2, 3, 4]
    Dataset.range(2, 5) == [2, 3, 4]
    Dataset.range(1, 5, 2) == [1, 3]
    Dataset.range(1, 5, -2) == []
    Dataset.range(5, 1) == []
    Dataset.range(5, 1, -2) == [5, 3]
    ```

    Args:
      *args: follow same semantics as python's xrange.
        len(args) == 1 -> start = 0, stop = args[0], step = 1
        len(args) == 2 -> start = args[0], stop = args[1], step = 1
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This behaves like Python's built-in `zip()`. The main difference is that
    `datasets` may be an arbitrary nested structure of `Dataset` objects,
    not only a flat sequence. For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6 }
    c = { (7, 8), (9, 10), (11, 12) }
    d = { 13, 14 }

    # The nested structure of the `datasets` argument determines the
    # structure of elements in the resulting dataset.
    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

    # The `datasets` argument may contain an arbitrary number of
    # datasets.
    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                                (2, 5, (9, 10)),
                                (3, 6, (11, 12)) }

    # The number of elements in the resulting dataset is the same as
    # the size of the smallest dataset in `datasets`.
    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
    ```

    Args:
      datasets: A nested structure of datasets to zip together.

    Returns:
      Dataset: A `Dataset` whose elements follow the nested structure of
        `datasets`.
    """
    zipped = ZipDataset(datasets)
    return zipped

  def concatenate(self, dataset):
    """Creates a `Dataset` that yields this dataset followed by `dataset`.

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6, 7 }

    # Input dataset and dataset to be concatenated should have same
    # nested structures and output types.
    # c = { (8, 9), (10, 11), (12, 13) }
    # d = { 14.0, 15.0, 16.0 }
    # a.concatenate(c) and a.concatenate(d) would result in error.

    a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
    ```

    Args:
      dataset: `Dataset` whose elements are appended after the elements of
        this dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    combined = ConcatenateDataset(self, dataset)
    return combined

  def prefetch(self, buffer_size):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor` giving the maximum number
        of elements to buffer while prefetching.

    Returns:
      Dataset: A `Dataset`.
    """
    prefetched = PrefetchDataset(self, buffer_size)
    return prefetched

  @staticmethod
  def list_files(file_pattern, shuffle=None, seed=None):
    """Creates a dataset of all files matching a pattern.

    NOTE: By default the filenames are produced in a non-deterministic,
    randomly shuffled order. Pass a `seed` or `shuffle=False` to obtain a
    deterministic order.

    If `file_pattern` matches no files, an assertion attached to the list of
    matches fails when the dataset is run.

    Example:
      Suppose the filesystem contains the following files:
        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py
      Passing "/path/to/dir/*.py" as the pattern yields a dataset that
      produces:
        - /path/to/dir/b.py
        - /path/to/dir/c.py

    Args:
      file_pattern: A string or scalar string `tf.Tensor`, representing
        the filename pattern that will be matched.
      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
        Defaults to `True`.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.set_random_seed` for behavior.

    Returns:
     Dataset: A `Dataset` of strings corresponding to file names.
    """
    shuffle = True if shuffle is None else shuffle
    with ops.name_scope("list_files"):
      pattern_t = ops.convert_to_tensor(
          file_pattern, dtype=dtypes.string, name="file_pattern")
      matches = gen_io_ops.matching_files(pattern_t)

      # Guard: fail at run time when the pattern matched nothing.
      has_matches = math_ops.greater(
          array_ops.shape(matches)[0], 0, name="match_not_empty")
      error_message = math_ops.add(
          "No files matched pattern: ",
          string_ops.reduce_join(pattern_t, separator=", "),
          name="message")
      non_empty_check = control_flow_ops.Assert(
          has_matches, [error_message], summarize=1, name="assert_not_empty")
      with ops.control_dependencies([non_empty_check]):
        matches = array_ops.identity(matches)

      dataset = Dataset.from_tensor_slices(matches)
      if not shuffle:
        return dataset

      # The shuffle buffer must hold at least one element, even though the
      # list of matching files might be empty.
      buffer_size = math_ops.maximum(
          array_ops.shape(matches, out_type=dtypes.int64)[0], 1)
      return dataset.shuffle(buffer_size, seed=seed)

  def repeat(self, count=None):
    """Repeats this dataset `count` times.

    NOTE: When this dataset depends on global state (for example a random
    number generator), separate repetitions may yield different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. When `count` is
        `None` or `-1` (the default behavior), the dataset is repeated
        indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """
    repeated = RepeatDataset(self, count)
    return repeated

  def _enumerate(self, start=0):
    """Pairs each element of this dataset with a running integer counter.

    Args:
      start: Value of the counter paired with the first element.

    Returns:
      Dataset: A `Dataset` of `(counter, element)` tuples.
    """
    # The counter's upper bound is the largest int64, so the zip stops when
    # `self` runs out.
    int64_max = np.iinfo(dtypes.int64.as_numpy_dtype).max
    counter = Dataset.range(start, int64_max)
    return Dataset.zip((counter, self))

  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    A buffer is first filled with `buffer_size` elements. Elements are then
    sampled at random from that buffer, and each sampled element is replaced
    by the next input element. A perfect shuffle requires a buffer at least
    as large as the full dataset.

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
        number of elements from this dataset from which the new
        dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.set_random_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Returns:
      Dataset: A `Dataset`.
    """
    shuffled = ShuffleDataset(
        self, buffer_size, seed, reshuffle_each_iteration)
    return shuffled

  def cache(self, filename=""):
    """Caches the elements in this dataset.

    Args:
      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
        directory on the filesystem to use for caching tensors in this Dataset.
        When no filename is given (the default empty string), the dataset is
        cached in memory.

    Returns:
      Dataset: A `Dataset`.
    """
    cached = CacheDataset(self, filename)
    return cached

  def take(self, count):
    """Creates a `Dataset` holding at most `count` elements of this dataset.

    Args:
      count: A `tf.int64` scalar `tf.Tensor` giving how many elements of this
        dataset form the new dataset. When `count` is -1, or exceeds the size
        of this dataset, every element of this dataset is included.

    Returns:
      Dataset: A `Dataset`.
    """
    taken = TakeDataset(self, count)
    return taken

  def skip(self, count):
    """Creates a `Dataset` that drops the first `count` elements of this one.

    Args:
      count: A `tf.int64` scalar `tf.Tensor` giving how many elements of this
        dataset to skip when forming the new dataset. When `count` exceeds
        the size of this dataset, the new dataset is empty. When `count` is
        -1, the entire dataset is skipped.

    Returns:
      Dataset: A `Dataset`.
    """
    remaining = SkipDataset(self, count)
    return remaining

  def shard(self, num_shards, index):
    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.

    This dataset operator is very useful when running distributed training, as
    it allows each worker to read a unique subset.

    When reading a single input file, you can skip elements as follows:

    ```python
    d = tf.data.TFRecordDataset(FLAGS.input_file)
    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    d = d.repeat(FLAGS.num_epochs)
    d = d.shuffle(FLAGS.shuffle_buffer_size)
    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    ```

    Important caveats:

    - Be sure to shard before you use any randomizing operator (such as
      shuffle).
    - Generally it is best if the shard operator is used early in the dataset
      pipeline. For example, when reading from a set of TFRecord files, shard
      before converting the dataset to input samples. This avoids reading every
      file on every worker. The following is an example of an efficient
      sharding strategy within a complete pipeline:

    ```python
    d = Dataset.list_files(FLAGS.pattern)
    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    d = d.repeat(FLAGS.num_epochs)
    d = d.shuffle(FLAGS.shuffle_buffer_size)
    d = d.interleave(tf.data.TFRecordDataset,
                     cycle_length=FLAGS.num_readers, block_length=1)
    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    ```

    Args:
      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
        shards operating in parallel.
      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.

    Returns:
      Dataset: A `Dataset`.

    Raises:
      ValueError: if `num_shards` or `index` are illegal values. Note: error
        checking is done on a best-effort basis, and errors aren't guaranteed
        to be caught upon dataset creation. (e.g. providing a placeholder
        tensor bypasses the early checking, and will instead result in an
        error during a session.run call.)
    """
    num_shards = ops.convert_to_tensor(
        num_shards, name="num_shards", dtype=dtypes.int64)
    num_shards_static = tensor_util.constant_value(num_shards)
    index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
    index_static = tensor_util.constant_value(index)

    # Static validation only happens when the values are known at graph
    # construction time; otherwise `constant_value` returns None.
    if num_shards_static is not None and num_shards_static < 1:
      raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
    if index_static is not None and index_static < 0:
      raise ValueError("index must be >= 0; got: %s" % index_static)
    if (index_static is not None and num_shards_static is not None and
        index_static >= num_shards_static):
      # The valid range is [0, num_shards), so the bound is strict.
      raise ValueError("index must be < num_shards; %s is not < %s" %
                       (index_static, num_shards_static))

    def filter_fn(elem_index, _):
      # Keep every element whose position is congruent to `index`.
      mod_result = math_ops.mod(elem_index, num_shards)
      return math_ops.equal(mod_result, index)

    return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)

  def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    The tensors in the resulting element will have an additional outer
    dimension, which will be `batch_size` (or `N % batch_size` for the last
    element if `batch_size` does not divide the number of input elements `N`
    evenly and `drop_remainder` is `False`). If your program depends on the
    batches having the same outer dimension, you should set the `drop_remainder`
    argument to `True` to prevent the smaller batch from being produced.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.

    Returns:
      Dataset: A `Dataset`.
    """
    return BatchDataset(self, batch_size, drop_remainder)

  def padded_batch(self,
                   batch_size,
                   padded_shapes,
                   padding_values=None,
                   drop_remainder=False):
    """Combines consecutive elements of this dataset into padded batches.

    This transformation combines multiple consecutive elements of the input
    dataset into a single element.

    Like `tf.data.Dataset.batch`, the tensors in the resulting element will
    have an additional outer dimension, which will be `batch_size` (or
    `N % batch_size` for the last element if `batch_size` does not divide the
    number of input elements `N` evenly and `drop_remainder` is `False`). If
    your program depends on the batches having the same outer dimension, you
    should set the `drop_remainder` argument to `True` to prevent the smaller
    batch from being produced.

    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
    different shapes, and this transformation will pad each component to the
    respective shape in `padded_shapes`. The `padded_shapes` argument
    determines the resulting shape for each dimension of each component in an
    output element:

    * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component
      will be padded out to that length in that dimension.
    * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component
      will be padded out to the maximum length of all elements in that
      dimension.

    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
    elements that may have different shapes into a `tf.SparseTensor`.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      padded_shapes: A nested structure of `tf.TensorShape` or
        `tf.int64` vector tensor-like objects representing the shape
        to which the respective component of each input element should
        be padded prior to batching. Any unknown dimensions
        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
        tensor-like object) will be padded to the maximum size of that
        dimension in each batch.
      padding_values: (Optional.) A nested structure of scalar-shaped
        `tf.Tensor`, representing the padding values to use for the
        respective components.  Defaults are `0` for numeric types and
        the empty string for string types.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.

    Returns:
      Dataset: A `Dataset`.
    """
    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
                              drop_remainder)

  def map(self, map_func, num_parallel_calls=None):
    """Maps `map_func` across the elements of this dataset.

    This transformation applies `map_func` to each element of this dataset, and
    returns a new dataset containing the transformed elements, in the same
    order as they appeared in the input.

    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3, 4, 5 }

    a.map(lambda x: x + 1) == { 2, 3, 4, 5, 6 }
    ```

    The input signature of `map_func` is determined by the structure of each
    element in this dataset. For example:

    ```python
    # Each element is a `tf.Tensor` object.
    a = { 1, 2, 3, 4, 5 }
    # `map_func` takes a single argument of type `tf.Tensor` with the same
    # shape and dtype.
    result = a.map(lambda x: ...)

    # Each element is a tuple containing two `tf.Tensor` objects.
    b = { (1, "foo"), (2, "bar"), (3, "baz") }
    # `map_func` takes two arguments of type `tf.Tensor`.
    result = b.map(lambda x_int, y_str: ...)

    # Each element is a dictionary mapping strings to `tf.Tensor` objects.
    c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
    # `map_func` takes a single argument of type `dict` with the same keys as
    # the elements.
    result = c.map(lambda d: ...)
    ```

    The value or values returned by `map_func` determine the structure of each
    element in the returned dataset.

    ```python
    # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
    def f(...):
      return tf.constant(37.0)
    result = dataset.map(f)
    result.output_classes == tf.Tensor
    result.output_types == tf.float32
    result.output_shapes == []  # scalar

    # `map_func` returns two `tf.Tensor` objects.
    def g(...):
      return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
    result = dataset.map(g)
    result.output_classes == (tf.Tensor, tf.Tensor)
    result.output_types == (tf.float32, tf.string)
    result.output_shapes == ([], [3])

    # Python primitives, lists, and NumPy arrays are implicitly converted to
    # `tf.Tensor`.
    def h(...):
      return (37.0, ["Foo", "Bar", "Baz"],
              np.array([1.0, 2.0], dtype=np.float64))
    result = dataset.map(h)
    result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
    result.output_types == (tf.float32, tf.string, tf.float64)
    result.output_shapes == ([], [3], [2])

    # `map_func` can return nested structures.
    def i(...):
      return {"a": 37.0, "b": [42, 16]}, "foo"
    result = dataset.map(i)
    result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
    result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
    result.output_shapes == ({"a": [], "b": [2]}, [])
    ```

    In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
    return `tf.SparseTensor` objects.

    Args:
      map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to another nested structure of tensors.
      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
        representing the number of elements to process in parallel. If not
        specified, elements will be processed sequentially.

    Returns:
      Dataset: A `Dataset`.
    """
    if num_parallel_calls is None:
      return MapDataset(self, map_func)
    return ParallelMapDataset(self, map_func, num_parallel_calls)

  def flat_map(self, map_func):
    """Maps `map_func` across this dataset and flattens the result.

    Use `flat_map` if you want to make sure that the order of your dataset
    stays the same. For example, to flatten a dataset of batches into a
    dataset of their elements:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset. '[...]' represents a tensor.
    a = {[1,2,3,4,5], [6,7,8,9], [10]}

    a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
      { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }
    ```

    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
    `flat_map` produces the same output as
    `tf.data.Dataset.interleave(cycle_length=1)`.

    Args:
      map_func: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to a
        `Dataset`.

    Returns:
      Dataset: A `Dataset`.
    """
    return FlatMapDataset(self, map_func)

  def interleave(self,
                 map_func,
                 cycle_length,
                 block_length=1,
                 num_parallel_calls=None):
    """Maps `map_func` across this dataset, and interleaves the results.

    A typical use is processing many input files concurrently:

    ```python
    # Preprocess 4 files concurrently, and interleave blocks of 16 records from
    # each file.
    filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
    dataset = (Dataset.from_tensor_slices(filenames)
               .interleave(lambda x:
                   TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
                   cycle_length=4, block_length=16))
    ```

    The order of the produced elements is governed by `cycle_length` and
    `block_length`. `cycle_length` sets how many input elements are processed
    at the same time; with `cycle_length` equal to 1, one input element is
    handled at a time and the output matches `tf.data.Dataset.flat_map`.
    More generally, `map_func` is applied to `cycle_length` input elements,
    iterators are opened on the resulting `Dataset` objects, and the
    transformation cycles through those iterators taking `block_length`
    consecutive elements from each. Whenever an iterator is exhausted, the
    next input element is consumed in its place.

    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3, 4, 5 }

    # NOTE: New lines indicate "block" boundaries.
    a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
                 cycle_length=2, block_length=4) == {
        1, 1, 1, 1,
        2, 2, 2, 2,
        1, 1,
        2, 2,
        3, 3, 3, 3,
        4, 4, 4, 4,
        3, 3,
        4, 4,
        5, 5, 5, 5,
        5, 5,
    }
    ```

    NOTE: The output order is deterministic as long as `map_func` is a pure
    function. If `map_func` contains stateful operations, the order in which
    that state is accessed is undefined.

    Args:
      map_func: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to a
        `Dataset`.
      cycle_length: The number of elements from this dataset that will be
        processed concurrently.
      block_length: The number of consecutive elements to produce from each
        input element before cycling to another input element.
      num_parallel_calls: (Optional.) If specified, the implementation creates
        a threadpool, which is used to fetch inputs from cycle elements
        asynchronously and in parallel. The default behavior is to fetch inputs
        from cycle elements synchronously with no parallelism.

    Returns:
      Dataset: A `Dataset`.
    """
    if num_parallel_calls is not None:
      return ParallelInterleaveDataset(self, map_func, cycle_length,
                                       block_length, num_parallel_calls)
    return InterleaveDataset(self, map_func, cycle_length, block_length)

  def filter(self, predicate):
    """Keeps only the elements of this dataset that satisfy `predicate`.

    Args:
      predicate: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to a
        scalar `tf.bool` tensor.

    Returns:
      Dataset: The `Dataset` containing the elements of this dataset for which
          `predicate` is `True`.
    """
    kept = FilterDataset(self, predicate)
    return kept

  def apply(self, transformation_func):
    """Applies a transformation function to this dataset.

    `apply` makes it possible to chain custom `Dataset` transformations,
    expressed as functions that accept a single `Dataset` and return a
    transformed `Dataset`.

    For example:

    ```
    dataset = (dataset.map(lambda x: x ** 2)
               .apply(group_by_window(key_func, reduce_func, window_size))
               .map(lambda x: x ** 3))
    ```

    Args:
      transformation_func: A function that takes one `Dataset` argument and
        returns a `Dataset`.

    Returns:
      Dataset: The `Dataset` returned by applying `transformation_func` to this
          dataset.

    Raises:
      TypeError: if `transformation_func` does not return a `Dataset`.
    """
    result = transformation_func(self)
    if isinstance(result, Dataset):
      # Record this dataset as the input of the transformed one.
      result._input_datasets = [self]  # pylint: disable=protected-access
      return result
    raise TypeError("`transformation_func` must return a Dataset.")

  def window(self, size, shift=None, stride=1, drop_remainder=False):
    """Combines input elements into a dataset of windows.

    Each window is a dataset itself and contains `size` elements (or
    possibly fewer if there are not enough input elements to fill the window
    and `drop_remainder` evaluates to false).

    The `stride` argument determines the stride of the input elements,
    and the `shift` argument determines the shift of the window.

    For example:
    - `tf.data.Dataset.range(7).window(2)` produces
      `{{0, 1}, {2, 3}, {4, 5}, {6}}`
    - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
      `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
    - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
      `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`

    Args:
      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
        of the input dataset to combine into a window.
      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        forward shift of the sliding window in each iteration. Defaults to
        `size`.
      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        stride of the input elements in the sliding window. Defaults to `1`.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether a window should be dropped in case its size is smaller than
        `size`.

    Returns:
      Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
        the same structure as this dataset, but a finite subsequence of its
        elements.
    """
    # Non-overlapping windows by default.
    if shift is None:
      shift = size
    return WindowDataset(self, size, shift, stride, drop_remainder)

  def reduce(self, initial_state, reduce_func):
    """Reduces the input dataset to a single element.

    The transformation calls `reduce_func` successively on every element of
    the input dataset until the dataset is exhausted, aggregating information in
    its internal state. The `initial_state` argument is used for the initial
    state and the final state is returned as the result.

    For example:
    - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
      produces `5`
    - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
      produces `10`

    Args:
      initial_state: A nested structure of tensors, representing the initial
        state of the transformation.
      reduce_func: A function that maps `(old_state, input_element)` to
        `new_state`. It must take two arguments and return a nested structure
        of tensors. The structure of `new_state` must match the structure of
        `initial_state`.

    Returns:
      A nested structure of `tf.Tensor` objects, corresponding to the final
      state of the transformation.

    Raises:
      TypeError: If the state returned by `reduce_func` does not have the same
        structure, element classes, or element types as `initial_state`.
    """

    with ops.name_scope("initial_state"):
      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      initial_state = nest.pack_sequence_as(initial_state, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state classes, shapes and types based on
    # the initial state.
    state_classes = sparse.get_classes(initial_state)
    state_shapes = nest.pack_sequence_as(
        initial_state, [t.get_shape() for t in nest.flatten(initial_state)])
    state_types = nest.pack_sequence_as(
        initial_state, [t.dtype for t in nest.flatten(initial_state)])

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `state_shapes`: whenever the new state is less specific than the current
    # state shapes, the shapes are relaxed and `reduce_func` is traced again.
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = StructuredFunctionWrapper(
          reduce_func,
          "reduce()",
          input_classes=(state_classes, self.output_classes),
          input_shapes=(state_shapes, self.output_shapes),
          input_types=(state_types, self.output_types),
          add_to_graph=False)

      # The per-component checks below use `zip`, which silently ignores extra
      # or missing components, so first verify that the new state is nested
      # the same way as the initial state.
      output_classes = wrapped_func.output_classes
      try:
        nest.assert_same_structure(
            state_classes, output_classes, check_types=False)
      except (ValueError, TypeError):
        raise TypeError(
            "The element structure for the new state must match the initial "
            "state. Expected %s; got %s." % (state_classes, output_classes))

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(output_classes), nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." % (state_classes,
                                               wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
      output_types = wrapped_func.output_types
      for new_state_type, state_type in zip(
          nest.flatten(output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." % (state_types,
                                               wrapped_func.output_types))

      # Extract shape information from the returned values.
      output_shapes = wrapped_func.output_shapes
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Rerun only if some known-rank state shape had to be relaxed.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        state_shapes = nest.pack_sequence_as(state_shapes,
                                             weakened_state_shapes)

    # The wrapper was built with `add_to_graph=False` so that it could be
    # retraced; add the final version to the graph now.
    reduce_func = wrapped_func.function
    reduce_func.add_to_graph(ops.get_default_graph())

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(
            output_types,
            gen_dataset_ops.reduce_dataset(
                self._as_variant_tensor(),  # pylint: disable=protected-access
                nest.flatten(sparse.serialize_sparse_tensors(initial_state)),
                reduce_func.captured_inputs,
                f=reduce_func,
                output_shapes=nest.flatten(
                    sparse.as_dense_shapes(output_shapes, output_classes)),
                output_types=nest.flatten(
                    sparse.as_dense_types(output_types, output_classes)))),
        output_types,
        output_shapes,
        output_classes)

  def with_options(self, options):
    """Returns a new `tf.data.Dataset` with the given options set.

    The options are "global" in the sense they apply to the entire input
    pipeline in which the `with_options` transformation is used. If options are
    set multiple times, they are merged if possible (see
    `tf.data.Options.merge()` for details).

    Args:
      options: A `tf.data.Options` that identifies the options to use.

    Returns:
      Dataset: A `Dataset` with the given options.

    Raises:
      ValueError: if an option is set more than once to different values, so
        that the options cannot be merged.
    """
    return _OptionsDataset(self, options)


@tf_export("data.Options")
class Options(object):
  """Represents options for tf.data.Dataset.

  An `Options` object can be for instance used to control which static
  optimizations to apply or whether to use performance modeling to dynamically
  tune the parallelism of operations such as `tf.data.Dataset.map` or
  `tf.data.Dataset.interleave`.

  Every option is a property that reads as `None` until it is set. Assigning a
  value that is not an instance of the option's declared type raises
  `TypeError`.
  """

  # Supported options as `(name, type, docstring)` triples. This table is the
  # single source of truth: it generates the option properties below and is
  # also what `merge()` iterates over, so a new option only needs to be
  # declared here.
  _OPTION_SPECS = (
      ("experimental_autotune", bool,
       "Whether to dynamically adjust the values of tunable parameters (e.g. "
       "degrees of parallelism)."),
      ("experimental_filter_fusion", bool,
       "Whether to fuse filter transformations."),
      ("experimental_hoist_random_uniform", bool,
       "Whether to hoist `tf.random_uniform()` ops out of map transformations."
      ),
      ("experimental_latency_all_edges", bool,
       "Whether to add latency measurements on all edges."),
      ("experimental_map_and_batch_fusion", bool,
       "Whether to fuse map and batch transformations."),
      ("experimental_map_and_filter_fusion", bool,
       "Whether to fuse map and filter transformations."),
      ("experimental_map_fusion", bool, "Whether to fuse map transformations."),
      ("experimental_map_parallelization", bool,
       "Whether to parallelize stateless map transformations."),
      ("experimental_map_vectorization", bool,
       "Whether to vectorize map transformations."),
      ("experimental_noop_elimination", bool,
       "Whether to eliminate no-op transformations."),
      ("experimental_shuffle_and_repeat_fusion", bool,
       "Whether to fuse shuffle and repeat transformations."),
      ("experimental_numa_aware", bool,
       "Whether to use NUMA-aware operations."),
  )

  for _name, _ty, _docstring in _OPTION_SPECS:

    def _make_getter(name):  # pylint: disable=no-self-argument

      def getter(self):
        return getattr(self, "_" + name)

      return getter

    def _make_setter(name, ty):  # pylint: disable=no-self-argument

      def setter(self, value):
        if not isinstance(value, ty):
          raise TypeError(
              "Attempting to set the option %s to incompatible value: %r" %
              (name, value))
        setattr(self, "_" + name, value)

      return setter

    # The class-level `_<name>` attribute provides the `None` default; the
    # setter stores explicitly assigned values on the instance.
    vars()["_" + _name] = None
    vars()[_name] = property(
        _make_getter(_name), _make_setter(_name, _ty), None, _docstring)

  def __init__(self):
    pass

  def __eq__(self, other):
    # Only explicitly set options live in the instance `__dict__`, so this
    # compares the set options of both objects.
    if isinstance(other, self.__class__):
      return self.__dict__ == other.__dict__
    else:
      return False

  def __ne__(self, other):
    return not self.__eq__(other)

  def _static_optimizations(self):
    """Produces the list of enabled static optimizations.

    Returns:
      A list of optimization names (option names without the `experimental_`
      prefix) for every static optimization option that is set to a truthy
      value, followed by `"map_and_batch_numa_aware_replacement"` if
      `experimental_numa_aware` is truthy.
    """
    experimental_optimizations = [
        "filter_fusion", "hoist_random_uniform", "latency_all_edges",
        "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
        "map_parallelization", "map_vectorization", "noop_elimination",
        "shuffle_and_repeat_fusion"
    ]
    result = []
    for exp_opt in experimental_optimizations:
      if getattr(self, "experimental_" + exp_opt):
        result.append(exp_opt)

    if self.experimental_numa_aware:
      result.append("map_and_batch_numa_aware_replacement")
    return result

  def merge(self, options):
    """Merges itself with the given `tf.data.Options`.

    The given `tf.data.Options` can be merged as long as there does not exist an
    attribute that is set to different values in `self` and `options`.

    Args:
      options: a `tf.data.Options` to merge with

    Raises:
      ValueError: if the given `tf.data.Options` cannot be merged

    Returns:
      New `tf.data.Options()` object which is the result of merging self with
      the input `tf.data.Options`.
    """
    option_names = [spec[0] for spec in Options._OPTION_SPECS]
    result = Options()
    for other in [self, options]:
      for name in option_names:
        this = getattr(result, name)
        that = getattr(other, name)
        # Unset (`None`) options never conflict; set ones must agree.
        if that is not None:
          if this is None:
            setattr(result, name, that)
          elif this != that:
            raise ValueError(
                "Cannot merge incompatible values of option: %s" % (name))
    return result


class DatasetSource(Dataset):
  """Base class for source datasets, i.e. datasets without input datasets."""

  def _inputs(self):
    # A source has nothing upstream, so report a fresh empty list.
    return list()


class UnaryDataset(Dataset):
  """Base class for datasets that consume exactly one input dataset."""

  def __init__(self, input_dataset):
    """Records `input_dataset` as the sole input of this dataset."""
    super(UnaryDataset, self).__init__()
    self._input_dataset = input_dataset

  def _inputs(self):
    # Wrap the single input in a new one-element list.
    single_input = self._input_dataset
    return [single_input]


class TensorDataset(DatasetSource):
  """A `Dataset` whose only element is a nested structure of tensors."""

  def __init__(self, tensors):
    """See `Dataset.from_tensors()` for details."""
    super(TensorDataset, self).__init__()
    with ops.name_scope("tensors"):
      # Sparse values become `SparseTensor`s; every other component is
      # converted to a dense tensor named after its flat position.
      converted = []
      for index, component in enumerate(nest.flatten(tensors)):
        if sparse_tensor_lib.is_sparse(component):
          converted.append(
              sparse_tensor_lib.SparseTensor.from_value(component))
        else:
          converted.append(
              ops.convert_to_tensor(component, name="component_%d" % index))
      tensors = nest.pack_sequence_as(tensors, converted)

    flat_components = nest.flatten(tensors)
    self._tensors = sparse.serialize_sparse_tensors(tensors)
    self._output_classes = sparse.get_classes(tensors)
    self._output_shapes = nest.pack_sequence_as(
        tensors, [component.get_shape() for component in flat_components])
    self._output_types = nest.pack_sequence_as(
        tensors, [component.dtype for component in flat_components])

  def _as_variant_tensor(self):
    flat_values = nest.flatten(self._tensors)
    dense_shapes = sparse.as_dense_shapes(self.output_shapes,
                                          self.output_classes)
    return gen_dataset_ops.tensor_dataset(
        flat_values, output_shapes=nest.flatten(dense_shapes))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class TensorSliceDataset(DatasetSource):
  """A `Dataset` of slices from a nested structure of tensors."""

  def __init__(self, tensors):
    """See `Dataset.from_tensor_slices()` for details.

    Raises:
      ValueError: If `tensors` has no components, if any component is a scalar
        (it has no leading dimension to slice along), or if the leading
        dimensions of the components are incompatible.
    """
    super(TensorSliceDataset, self).__init__()
    with ops.name_scope("tensors"):
      tensors = nest.pack_sequence_as(tensors, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(tensors))
      ])
      flat_tensors = nest.flatten(tensors)

    # Without these checks an empty structure or a scalar component fails with
    # an opaque `IndexError` when indexing the leading dimension below.
    if not flat_tensors:
      raise ValueError(
          "`Dataset.from_tensor_slices()` requires at least one component to "
          "slice.")
    for t in flat_tensors:
      if t.get_shape().ndims == 0:
        raise ValueError(
            "`Dataset.from_tensor_slices()` requires every component to have "
            "rank >= 1, but got a scalar component: %s" % (t,))

    # All components are sliced along their leading dimension, so those
    # dimensions must agree.
    batch_dim = flat_tensors[0].get_shape()[0]
    for t in flat_tensors[1:]:
      batch_dim.assert_is_compatible_with(t.get_shape()[0])
    self._tensors = sparse.serialize_many_sparse_tensors(tensors)
    self._output_classes = sparse.get_classes(tensors)
    self._output_shapes = nest.pack_sequence_as(
        tensors, [t.get_shape()[1:] for t in nest.flatten(tensors)])
    self._output_types = nest.pack_sequence_as(
        tensors, [t.dtype for t in nest.flatten(tensors)])

  def _as_variant_tensor(self):
    return gen_dataset_ops.tensor_slice_dataset(
        nest.flatten(self._tensors),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class SparseTensorSliceDataset(DatasetSource):
  """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""

  def __init__(self, sparse_tensor):
    """See `Dataset.from_sparse_tensor_slices()` for details.

    Raises:
      TypeError: If `sparse_tensor` is not a `tf.SparseTensor`.
    """
    super(SparseTensorSliceDataset, self).__init__()
    if isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
      self._sparse_tensor = sparse_tensor
    else:
      raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")

  def _as_variant_tensor(self):
    st = self._sparse_tensor
    return gen_dataset_ops.sparse_tensor_slice_dataset(
        st.indices, st.values, st.dense_shape)

  @property
  def output_classes(self):
    # Each element is an (indices, values, dense_shape) triple of tensors.
    return (ops.Tensor,) * 3

  @property
  def output_shapes(self):
    st = self._sparse_tensor
    # Slicing off a row removes one dimension. The input rank is read from both
    # the second dimension of `indices` and the length of `dense_shape`, and
    # the two estimates are merged.
    row_rank = (st.indices.get_shape()[1] - 1).merge_with(
        st.dense_shape.get_shape()[0] - 1)
    # The number of values in a row is not known statically.
    unknown_count = tensor_shape.Dimension(None)
    return (tensor_shape.TensorShape([unknown_count, row_rank]),
            tensor_shape.TensorShape([unknown_count]),
            tensor_shape.TensorShape([row_rank]))

  @property
  def output_types(self):
    return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)


class _NestedDatasetComponent(object):
  """The structure of a `Dataset` nested in a component of another `Dataset`.

  A `StructuredFunctionWrapper` around a function that returns a `Dataset` as
  one of its components will have a `NestedDatasetComponent` in the
  corresponding position in the `output_classes`, `output_shapes`, and
  `output_types` properties.

  NOTE(mrry): This class is not currently exposed via the public API. Support
  for nested datasets can be enabled on a function-by-function basis by setting
  `experimental_nested_dataset_support=True` in the `StructuredFunctionWrapper`
  initializer.

  TODO(b/110122868): Add this class, or something equivalent, to the public API.
  We are considering revising the public API for accessing Dataset structure
  (`output_classes` etc.) based on experience with nested datasets and other
  custom component types.
  """

  def __init__(self,
               dataset=None,
               output_shapes=None,
               output_types=None,
               output_classes=None):
    if dataset is None:
      if (output_classes is None or output_shapes is None or
          output_types is None):
        raise ValueError(
            "Either `dataset`, or all of `output_classes`, "
            "`output_shapes`, and `output_types` must be specified.")
      self._output_classes = output_classes
      self._output_shapes = output_shapes
      self._output_types = output_types
    else:
      if not (output_classes is None and output_shapes is None and
              output_types is None):
        raise ValueError(
            "Either `dataset`, or all of `output_classes`, "
            "`output_shapes`, and `output_types` must be specified.")
      self._output_classes = dataset.output_classes
      self._output_shapes = dataset.output_shapes
      self._output_types = dataset.output_types

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class _VariantDataset(Dataset):
  """Exposes a `tf.variant` dataset handle as a `Dataset` with known structure.

  The element structure is not derived from the variant; it is taken from the
  `structure` object, which must provide `output_classes`, `output_shapes`,
  and `output_types`.
  """

  def __init__(self, dataset_variant, structure):
    super(_VariantDataset, self).__init__()
    self._dataset_variant = dataset_variant
    self._structure = structure

  def _as_variant_tensor(self):
    # The wrapped value already is the variant; hand it back unchanged.
    return self._dataset_variant

  def _inputs(self):
    # No upstream datasets are tracked for a wrapped variant.
    return list()

  @property
  def output_classes(self):
    structure = self._structure
    return structure.output_classes

  @property
  def output_shapes(self):
    structure = self._structure
    return structure.output_shapes

  @property
  def output_types(self):
    structure = self._structure
    return structure.output_types


class StructuredFunctionWrapper(object):
  """A wrapper for `Defun` that supports structured arguments and return values.

  The user function `func` is traced into a `function.Defun` whose arguments
  and results are flat lists of tensors.

  - Arguments: flat arguments are re-packed into the input structure.
    `tf.SparseTensor` components are deserialized and nested-dataset variants
    are wrapped in a `_VariantDataset`.
  - Results: the nested return value is flattened. `tf.SparseTensor`
    components are serialized and nested datasets are converted to variant
    tensors.

  The recorded output structure is available through `output_classes`,
  `output_shapes`, and `output_types`, and the traced function through
  `function`.
  """

  def __init__(self, func, transformation_name, dataset=None,
               input_classes=None, input_shapes=None, input_types=None,
               add_to_graph=True, experimental_nested_dataset_support=False):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    `func` is traced while the wrapper is constructed; the `output_*`
    properties are only populated as a side effect of that tracing.

    Args:
      func: A function from a nested structure to another nested structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A nested structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A nested structure of `tf.DType`. If given, this
        argument defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph. If `False`, the function is only defined, so that the
        caller can inspect its output structure (and, e.g., retrace with
        different input shapes) before adding it to a graph.
      experimental_nested_dataset_support: (Optional.) If `True`, the function
        will support `tf.data.Dataset` objects as arguments and return values.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
      NotImplementedError: If a nested dataset appears in the inputs or the
        outputs of `func` while `experimental_nested_dataset_support` is
        `False`.
      TypeError: If `func` returns a value that cannot be converted to a
        tensor (raised while tracing `func`).
    """
    if dataset is None:
      if input_classes is None or input_shapes is None or input_types is None:
        raise ValueError("Either `dataset`, or all of `input_classes`, "
                         "`input_shapes`, and `input_types` must be specified.")
      self._input_shapes = input_shapes
      self._input_types = input_types
      self._input_classes = input_classes
    else:
      if not (input_classes is None and input_shapes is None and
              input_types is None):
        raise ValueError("Either `dataset`, or all of `input_classes`, "
                         "`input_shapes`, and `input_types` must be specified.")
      self._input_shapes = dataset.output_shapes
      self._input_types = dataset.output_types
      self._input_classes = dataset.output_classes

    self._transformation_name = transformation_name

    # TODO(b/110122868): Enable this support for all `tf.data` functions.
    self._nested_dataset_support = experimental_nested_dataset_support

    # `_defun_args()` is evaluated here, so unsupported nested-dataset inputs
    # are rejected before `func` is traced.
    @function.Defun(*self._defun_args())
    def tf_data_structured_function_wrapper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      # Rebuild each flat argument into the component `func` expects.
      flat_args = []
      for arg, arg_class, arg_shape, arg_type in zip(
          args,
          nest.flatten(self._input_classes),
          nest.flatten(self._input_shapes),
          nest.flatten(self._input_types)):
        # TODO(b/110122868): Add a registration mechanism for new component
        # types.
        if arg_class is sparse_tensor_lib.SparseTensor:
          # Sparse components arrive serialized as variants; restore them and
          # pin the statically known rank on `indices` and `dense_shape`.
          arg = sparse.deserialize_sparse_tensors(
              arg, arg_type, arg_shape, arg_class)
          arg.indices.set_shape([None, arg_shape.ndims])
          arg.dense_shape.set_shape([arg_shape.ndims])
        elif isinstance(arg_class, _NestedDatasetComponent):
          # Guaranteed by `_defun_args()`, which raises otherwise.
          assert self._nested_dataset_support
          arg = _VariantDataset(arg, arg_class)
        else:
          arg.set_shape(arg_shape)
        flat_args.append(arg)
      nested_args = nest.pack_sequence_as(self._input_classes, flat_args)
      # Tuple-structured inputs are spread as positional arguments; anything
      # else is passed as a single argument.
      if not _should_unpack_args(nested_args):
        nested_args = (nested_args,)

      ret = func(*nested_args)
      # If `func` returns a list of tensors, `nest.flatten()` and
      # `ops.convert_to_tensor()` would conspire to attempt to stack
      # those tensors into a single tensor, because the customized
      # version of `nest.flatten()` does not recurse into lists. Since
      # it is more likely that the list arose from returning the
      # result of an operation (such as `tf.py_func()`) that returns a
      # list of not-necessarily-stackable tensors, we treat the
      # returned value is a `tuple` instead. A user wishing to pack
      # the return value into a single tensor can use an explicit
      # `tf.stack()` before returning.
      if isinstance(ret, list):
        ret = tuple(ret)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      flat_ret = []
      flat_classes = []
      flat_shapes = []
      flat_types = []
      for t in nest.flatten(ret):
        # TODO(b/110122868): Add a registration mechanism for new component
        # types.
        if sparse_tensor_lib.is_sparse(t):
          t = sparse_tensor_lib.SparseTensor.from_value(t)
          flat_ret.append(sparse.serialize_sparse_tensors(t))
          flat_classes.append(sparse_tensor_lib.SparseTensor)
          flat_shapes.append(t.get_shape())
          flat_types.append(t.dtype)
        elif isinstance(t, Dataset):
          if not self._nested_dataset_support:
            raise NotImplementedError(
                "The %s transformation does not currently support nested "
                "datasets as outputs." % self._transformation_name)

          # A nested dataset is returned as its variant tensor; the same
          # `_NestedDatasetComponent` records its class, shape and type
          # structure.
          flat_ret.append(t._as_variant_tensor())  # pylint: disable=protected-access
          component = _NestedDatasetComponent(t)
          flat_classes.append(component)
          flat_shapes.append(component)
          flat_types.append(component)
          if t.options() != Options():
            warnings.warn("Encountered a nested dataset with non-default "
                          "options. These options will not be propagated to "
                          "the outer dataset.")
        else:
          try:
            t = ops.convert_to_tensor(t)
          except (ValueError, TypeError):
            raise TypeError("Unsupported return value from function passed to "
                            "%s: %s." % (transformation_name, t))
          flat_ret.append(t)
          flat_classes.append(ops.Tensor)
          flat_shapes.append(t.get_shape())
          flat_types.append(t.dtype)

      # Record the output structure as a side effect of tracing; the
      # `output_*` properties read these attributes.
      ret = nest.pack_sequence_as(ret, flat_ret)
      self._output_classes = nest.pack_sequence_as(ret, flat_classes)
      self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
      self._output_types = nest.pack_sequence_as(ret, flat_types)

      _warn_if_collections(transformation_name)

      return flat_ret

    self._function = tf_data_structured_function_wrapper
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())
    else:
      # Use the private method that will execute
      # `tf_data_structured_function_wrapper` but delay adding it to the graph
      # in case (e.g.) we need to rerun the function.
      self._function._create_definition_if_needed()  # pylint: disable=protected-access

  def _defun_args(self):
    """Returns a flat list of `tf.DType` for the input element structure.

    Sparse tensors and nested datasets are passed into the `Defun` as
    `tf.variant` tensors.

    Raises:
      NotImplementedError: If the input structure contains a nested dataset
        and nested dataset support is disabled.
    """
    ret = []
    for input_type, input_class in zip(nest.flatten(self._input_types),
                                       nest.flatten(self._input_classes)):
      # TODO(b/110122868): Add a registration mechanism for new component types.
      if input_class is sparse_tensor_lib.SparseTensor:
        ret.append(dtypes.variant)
      elif isinstance(input_class, _NestedDatasetComponent):
        if not self._nested_dataset_support:
          raise NotImplementedError(
              "The %s transformation does not currently support nested "
              "datasets as inputs." % self._transformation_name)
        ret.append(dtypes.variant)
      else:
        assert isinstance(input_type, dtypes.DType)
        ret.append(input_type)
    return ret

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types

  @property
  def function(self):
    return self._function


def flat_structure(dataset):
  """Builds the flat `output_shapes` and `output_types` attrs for Dataset ops.

  Dataset op constructors typically take `output_shapes` and `output_types`
  arguments that describe the flattened structure of one element. This helper
  returns those attrs as a keyword-argument dictionary, so that
  `Dataset._as_variant_tensor()` implementations can simply pass
  `**flat_structure(self)` to the op constructor.

  Any `_NestedDatasetComponent` placeholder in the structure is replaced by
  the structure of the nested dataset it describes.

  Args:
    dataset: A `tf.data.Dataset`.

  Returns:
    A dictionary with `"output_shapes"` and `"output_types"` keys that can be
    passed as keyword arguments to many Dataset op constructors.
  """
  class_list, shape_list, type_list = [], [], []
  for element_class, element_shape, element_type in zip(
      nest.flatten(dataset.output_classes),
      nest.flatten(dataset.output_shapes),
      nest.flatten(dataset.output_types)):
    is_nested = isinstance(element_class, _NestedDatasetComponent)
    class_list.append(
        element_class.output_classes if is_nested else element_class)
    shape_list.append(
        element_shape.output_shapes if is_nested else element_shape)
    type_list.append(element_type.output_types if is_nested else element_type)

  packed_classes = nest.pack_sequence_as(dataset.output_classes, class_list)
  packed_shapes = nest.pack_sequence_as(dataset.output_shapes, shape_list)
  packed_types = nest.pack_sequence_as(dataset.output_types, type_list)

  dense_shapes = sparse.as_dense_shapes(packed_shapes, packed_classes)
  dense_types = sparse.as_dense_types(packed_types, packed_classes)
  return {
      "output_shapes": nest.flatten(dense_shapes),
      "output_types": nest.flatten(dense_types),
  }


class _GeneratorDataset(DatasetSource):
  """A `Dataset` that generates elements by invoking a function."""

  def __init__(self, init_args, init_func, next_func, finalize_func):
    """Constructs a `_GeneratorDataset`.

    Args:
      init_args: A nested structure representing the arguments to `init_func`.
      init_func: A TensorFlow function that will be called on `init_args` each
        time a C++ iterator over this dataset is constructed. Returns a nested
        structure representing the "state" of the dataset.
      next_func: A TensorFlow function that will be called on the result of
        `init_func` to produce each element, and that raises `OutOfRangeError`
        to terminate iteration.
      finalize_func: A TensorFlow function that will be called on the result of
        `init_func` immediately before a C++ iterator over this dataset is
        destroyed. The return value is ignored.
    """
    super(_GeneratorDataset, self).__init__()
    self._init_args = init_args

    # Describe `init_args` so that `init_func` can be traced against matching
    # argument classes, shapes and types.
    init_args_classes = sparse.get_classes(init_args)
    init_args_shapes = nest.pack_sequence_as(
        init_args, [t.get_shape() for t in nest.flatten(init_args)])
    init_args_types = nest.pack_sequence_as(
        init_args, [t.dtype for t in nest.flatten(init_args)])

    # The output structure of `init_func` defines the state structure, which
    # is the input structure of both `next_func` and `finalize_func`.
    wrapped_init_func = StructuredFunctionWrapper(
        init_func, "GeneratorDataset", input_classes=init_args_classes,
        input_shapes=init_args_shapes, input_types=init_args_types)
    self._state_classes = wrapped_init_func.output_classes
    self._state_shapes = wrapped_init_func.output_shapes
    self._state_types = wrapped_init_func.output_types
    self._init_func = wrapped_init_func.function

    # The output structure of `next_func` is the element structure.
    wrapped_next_func = StructuredFunctionWrapper(
        next_func, "GeneratorDataset", input_classes=self._state_classes,
        input_shapes=self._state_shapes, input_types=self._state_types)
    self._output_classes = wrapped_next_func.output_classes
    self._output_shapes = wrapped_next_func.output_shapes
    self._output_types = wrapped_next_func.output_types
    self._next_func = wrapped_next_func.function

    wrapped_finalize_func = StructuredFunctionWrapper(
        finalize_func, "GeneratorDataset", input_classes=self._state_classes,
        input_shapes=self._state_shapes, input_types=self._state_types)
    self._finalize_func = wrapped_finalize_func.function

  def _as_variant_tensor(self):
    # Each function receives its captured inputs; `init_func` additionally
    # receives the flattened `init_args`.
    return gen_dataset_ops.generator_dataset(
        nest.flatten(self._init_args) + self._init_func.captured_inputs,
        self._next_func.captured_inputs,
        self._finalize_func.captured_inputs,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **flat_structure(self))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class ZipDataset(Dataset):
  """A `Dataset` that zips its inputs together."""

  def __init__(self, datasets):
    """See `Dataset.zip()` for details.

    Raises:
      TypeError: If any leaf of `datasets` is not a `Dataset`.
    """
    super(ZipDataset, self).__init__()
    for component in nest.flatten(datasets):
      if isinstance(component, Dataset):
        continue
      message = ("The argument to `Dataset.zip()` must be a nested "
                 "structure of `Dataset` objects.")
      # `nest` treats Python lists as leaves, so a list shows up here whole.
      if isinstance(component, list):
        message += (" Nested structures do not support Python lists; please "
                    "use a tuple instead.")
      raise TypeError(message)
    self._datasets = datasets

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    input_variants = [
        ds._as_variant_tensor() for ds in nest.flatten(self._datasets)
    ]
    # pylint: enable=protected-access
    return gen_dataset_ops.zip_dataset(input_variants, **flat_structure(self))

  def _inputs(self):
    return nest.flatten(self._datasets)

  def _pack_input_attribute(self, attribute):
    """Packs `attribute` of every input dataset into the `datasets` structure."""
    return nest.pack_sequence_as(
        self._datasets,
        [getattr(ds, attribute) for ds in nest.flatten(self._datasets)])

  @property
  def output_classes(self):
    return self._pack_input_attribute("output_classes")

  @property
  def output_shapes(self):
    return self._pack_input_attribute("output_shapes")

  @property
  def output_types(self):
    return self._pack_input_attribute("output_types")


class ConcatenateDataset(Dataset):
  """A `Dataset` that concatenates its input with given dataset.

  Both inputs must agree on `output_types` and `output_classes`. Their
  `output_shapes` are merged component-wise into the most specific shape
  compatible with both.
  """

  def __init__(self, input_dataset, dataset_to_concatenate):
    """See `Dataset.concatenate()` for details.

    Raises:
      TypeError: If the two datasets have different `output_types` or
        different `output_classes`.
    """
    super(ConcatenateDataset, self).__init__()
    self._input_dataset = input_dataset
    self._dataset_to_concatenate = dataset_to_concatenate
    head_types = input_dataset.output_types
    tail_types = dataset_to_concatenate.output_types
    if head_types != tail_types:
      raise TypeError(
          "Two datasets to concatenate have different types %s and %s" %
          (head_types, tail_types))
    head_classes = input_dataset.output_classes
    tail_classes = dataset_to_concatenate.output_classes
    if head_classes != tail_classes:
      raise TypeError(
          "Two datasets to concatenate have different classes %s and %s" %
          (head_classes, tail_classes))
    self._input_datasets = [input_dataset, dataset_to_concatenate]

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    head_variant = self._input_dataset._as_variant_tensor()
    tail_variant = self._dataset_to_concatenate._as_variant_tensor()
    # pylint: enable=protected-access
    return gen_dataset_ops.concatenate_dataset(head_variant, tail_variant,
                                               **flat_structure(self))

  def _inputs(self):
    return [self._input_dataset, self._dataset_to_concatenate]

  @property
  def output_classes(self):
    """Taken from the first input (both inputs are checked to match)."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Per-component most specific shape compatible with both inputs."""
    head_shapes = self._input_dataset.output_shapes
    tail_flat = nest.flatten(self._dataset_to_concatenate.output_shapes)
    merged = [
        head.most_specific_compatible_shape(tail)
        for head, tail in zip(nest.flatten(head_shapes), tail_flat)
    ]
    return nest.pack_sequence_as(head_shapes, merged)

  @property
  def output_types(self):
    """Taken from the first input (both inputs are checked to match)."""
    return self._input_dataset.output_types


class RepeatDataset(UnaryDataset):
  """A `Dataset` that repeats its input several times."""

  def __init__(self, input_dataset, count):
    """See `Dataset.repeat()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are repeated.
      count: The repetition count, converted to a `tf.int64` tensor named
        "count". When `None`, the constant `-1` is used instead.
    """
    super(RepeatDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    if count is not None:
      self._count = ops.convert_to_tensor(
          count, dtype=dtypes.int64, name="count")
    else:
      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.repeat_dataset(
        input_variant, count=self._count, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class RangeDataset(DatasetSource):
  """A `Dataset` of a step separated range of values."""

  def __init__(self, *args):
    """See `Dataset.range()` for details."""
    super(RangeDataset, self).__init__()
    self._parse_args(*args)

  def _parse_args(self, *args):
    """Parses `args` with the same arity rules as the `range()` builtin.

    Accepts `(stop)`, `(start, stop)` or `(start, stop, step)`. Omitted
    values default to `start=0` and `step=1`.

    Raises:
      ValueError: If `args` does not hold one, two or three values.
    """
    arg_count = len(args)
    if arg_count == 1:
      start, stop, step = 0, args[0], 1
    elif arg_count == 2:
      start, stop, step = args[0], args[1], 1
    elif arg_count == 3:
      start, stop, step = args
    else:
      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
    # Convert in start/stop/step order so tensor creation order is stable.
    self._start = self._build_tensor(start, "start")
    self._stop = self._build_tensor(stop, "stop")
    self._step = self._build_tensor(step, "step")

  def _build_tensor(self, int64_value, name):
    """Converts `int64_value` to a `tf.int64` tensor called `name`."""
    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)

  def _as_variant_tensor(self):
    return gen_dataset_ops.range_dataset(
        start=self._start, stop=self._stop, step=self._step,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Every element is a single `tf.Tensor`."""
    return ops.Tensor

  @property
  def output_shapes(self):
    """Every element is a scalar."""
    return tensor_shape.scalar()

  @property
  def output_types(self):
    """Every element has dtype `tf.int64`."""
    return dtypes.int64


class CacheDataset(UnaryDataset):
  """A `Dataset` that caches elements of its input."""

  def __init__(self, input_dataset, filename):
    """See `Dataset.cache()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are cached.
      filename: Converted to a `tf.string` tensor named "filename".
    """
    super(CacheDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._filename = ops.convert_to_tensor(
        filename, dtype=dtypes.string, name="filename")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.cache_dataset(
        input_variant, filename=self._filename, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class ShuffleDataset(UnaryDataset):
  """A `Dataset` that randomly shuffles the elements of its input."""

  def __init__(self,
               input_dataset,
               buffer_size,
               seed=None,
               reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    Args:
      input_dataset: The input dataset.
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
        number of elements from this dataset from which the new
        dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.set_random_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. `None` is treated as `True`.

    Returns:
      A `Dataset`.

    Raises:
      ValueError: if invalid arguments are provided.
    """
    super(ShuffleDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    # `random_seed.get_seed` yields the (seed, seed2) pair that
    # `_as_variant_tensor` passes to the shuffle op.
    self._seed, self._seed2 = random_seed.get_seed(seed)
    self._reshuffle_each_iteration = (
        True if reshuffle_each_iteration is None else reshuffle_each_iteration)

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.shuffle_dataset(
        input_variant,
        buffer_size=self._buffer_size,
        seed=self._seed,
        seed2=self._seed2,
        reshuffle_each_iteration=self._reshuffle_each_iteration,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class TakeDataset(UnaryDataset):
  """A `Dataset` containing the first `count` elements from its input."""

  def __init__(self, input_dataset, count):
    """See `Dataset.take()` for details.

    Args:
      input_dataset: The source `Dataset`.
      count: Converted to a `tf.int64` tensor named "count".
    """
    super(TakeDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._count = ops.convert_to_tensor(
        count, dtype=dtypes.int64, name="count")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.take_dataset(
        input_variant, count=self._count, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class SkipDataset(UnaryDataset):
  """A `Dataset` skipping the first `count` elements from its input."""

  def __init__(self, input_dataset, count):
    """See `Dataset.skip()` for details.

    Args:
      input_dataset: The source `Dataset`.
      count: Converted to a `tf.int64` tensor named "count".
    """
    super(SkipDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._count = ops.convert_to_tensor(
        count, dtype=dtypes.int64, name="count")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.skip_dataset(
        input_variant, count=self._count, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class BatchDataset(UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, drop_remainder):
    """See `Dataset.batch()` for details.

    Args:
      input_dataset: The `Dataset` to batch.
      batch_size: Converted to a `tf.int64` tensor named "batch_size".
      drop_remainder: Converted to a `tf.bool` tensor named "drop_remainder".
    """
    super(BatchDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  def _as_variant_tensor(self):
    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
    # The v1 op is used only when `drop_remainder` is statically `False`;
    # when it is statically `True` or unknown, the v2 op receives the flag.
    statically_false = (
        smart_cond.smart_constant_value(self._drop_remainder) is False)
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    if not statically_false:
      return gen_dataset_ops.batch_dataset_v2(
          input_variant,
          batch_size=self._batch_size,
          drop_remainder=self._drop_remainder,
          **flat_structure(self))
    return gen_dataset_ops.batch_dataset(
        input_variant, batch_size=self._batch_size, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Input component shapes, each with a leading batch dimension."""
    input_shapes = self._input_dataset.output_shapes

    def _prepend_batch_dim(component_shape):
      # The batch dimension can only be static when `drop_remainder` is known
      # to be true; otherwise it is left unknown (`None`).
      if smart_cond.smart_constant_value(self._drop_remainder):
        batch_dim = tensor_util.constant_value(self._batch_size)
      else:
        batch_dim = None
      return tensor_shape.vector(batch_dim).concatenate(component_shape)

    return nest.pack_sequence_as(
        input_shapes,
        [_prepend_batch_dim(s) for s in nest.flatten(input_shapes)])

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.

  Unknown ranks and unknown dimensions are treated optimistically. The check
  fails only when the ranks differ, or when some known padded dimension is
  smaller than the corresponding known input dimension.

  Args:
    padded_shape: A `tf.TensorShape`.
    input_component_shape: A `tf.TensorShape`.

  Returns:
    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
    `False`.
  """
  padded_dims = padded_shape.dims
  input_dims = input_component_shape.dims
  if padded_dims is None or input_dims is None:
    return True
  if len(padded_dims) != len(input_dims):
    return False
  return not any(
      padded.value is not None and actual.value is not None and
      padded.value < actual.value
      for padded, actual in zip(padded_dims, input_dims))


def _padded_shape_to_tensor(padded_shape, input_component_shape):
  """Converts `padded_shape` to a `tf.Tensor` representing that shape.

  Two conversions are attempted. First, `padded_shape` is read as a
  `tf.TensorShape` and emitted in canonical form, in which unknown dimensions
  are written as `-1`. If that raises `TypeError` or `ValueError`, it is
  instead passed through the generic tensor-conversion machinery and checked
  to be a 1-D `tf.int64` tensor. In both cases the resulting shape must pass
  `_is_padded_shape_compatible_with` against `input_component_shape`.

  Args:
    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
      be compatible.

  Returns:
    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.

  Raises:
    ValueError: If `padded_shape` is not a shape or not compatible with
      `input_component_shape`.
    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
  """
  try:
    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
    canonical_dims = [
        -1 if dim is None else dim for dim in padded_shape_as_shape.as_list()
    ]
    shape_tensor = ops.convert_to_tensor(canonical_dims, dtype=dtypes.int64)
  except (TypeError, ValueError):
    # Not trivially shape-like: convert it as a general tensor and validate
    # its rank and dtype explicitly.
    shape_tensor = ops.convert_to_tensor(
        padded_shape, preferred_dtype=dtypes.int64)
    tensor_dims = shape_tensor.shape.dims
    if tensor_dims is not None and len(tensor_dims) != 1:
      raise ValueError(
          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
          "shape was %s." % (padded_shape, shape_tensor.shape))
    if shape_tensor.dtype != dtypes.int64:
      raise TypeError(
          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
          "element type was %s." % (padded_shape, shape_tensor.dtype.name))
    padded_shape_as_shape = tensor_util.constant_value_as_shape(shape_tensor)

  compatible = _is_padded_shape_compatible_with(padded_shape_as_shape,
                                                input_component_shape)
  if not compatible:
    raise ValueError("The padded shape %s is not compatible with the "
                     "corresponding input component shape %s."
                     % (padded_shape_as_shape, input_component_shape))
  return shape_tensor


def _padding_value_to_tensor(value, output_type):
  """Converts a padding value to a tensor and validates it.

  Args:
    value: The padding value; anything `ops.convert_to_tensor` accepts.
    output_type: The dtype the converted value is required to have.

  Returns:
    A `Tensor` whose shape is compatible with a scalar.

  Raises:
    ValueError: If the converted value's shape is not compatible with a scalar.
    TypeError: If the converted value's dtype differs from `output_type`.
  """
  padding_tensor = ops.convert_to_tensor(value, name="padding_value")
  is_scalar_like = padding_tensor.shape.is_compatible_with(
      tensor_shape.scalar())
  if not is_scalar_like:
    raise ValueError(
        "Padding value should be a scalar, but is not: %s" % padding_tensor)
  if padding_tensor.dtype != output_type:
    raise TypeError(
        "Padding value tensor (%s) does not match output type: %s" %
        (padding_tensor, output_type))
  return padding_tensor


def _default_padding(input_dataset):
  """Returns default padding values in a structure matching `input_dataset`.

  Each component of `input_dataset.output_types` is mapped as follows:
  strings become `""`, and every other type becomes `np.zeros_like` of its
  numpy dtype.

  Raises:
    TypeError: If any component has base dtype `tf.variant`, which cannot be
      padded.
  """
  def _zero_value(component_type):
    base_type = component_type.base_dtype
    if base_type == dtypes.variant:
      raise TypeError("Unable to create padding for field of type 'variant'")
    if base_type == dtypes.string:
      return ""
    return np.zeros_like(component_type.as_numpy_dtype())

  return nest.map_structure(_zero_value, input_dataset.output_types)


class PaddedBatchDataset(UnaryDataset):
  """A `Dataset` that batches and pads contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
               drop_remainder):
    """See `Dataset.padded_batch()` for details.

    Args:
      input_dataset: The `Dataset` to batch. Sparse components are rejected.
      batch_size: Converted to a `tf.int64` tensor named "batch_size".
      padded_shapes: A structure matching (up to) `input_dataset.output_shapes`
        that gives the padded shape of each component.
      padding_values: A structure of padding values, or `None` to use
        `_default_padding(input_dataset)`.
      drop_remainder: Converted to a `tf.bool` tensor named "drop_remainder".

    Raises:
      TypeError: If `input_dataset` has sparse components, or if a padded shape
        or padding value has the wrong dtype.
      ValueError: If a padded shape is malformed or incompatible with its input
        component, or if a padding value is not scalar.
    """
    super(PaddedBatchDataset, self).__init__(input_dataset)
    if sparse.any_sparse(input_dataset.output_classes):
      # TODO(b/63669786): support batching of sparse tensors
      raise TypeError(
          "Batching of padded sparse tensors is not currently supported")
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    padding_values = (
        padding_values
        if padding_values is not None else _default_padding(input_dataset))

    # Validate every padded shape against its input component and convert it
    # to the canonical 1-D int64 tensor form.
    flat_padded_shapes = nest.flatten_up_to(input_dataset.output_shapes,
                                            padded_shapes)
    flat_padded_shapes_as_tensors = [
        _padded_shape_to_tensor(padded_shape, input_component_shape)
        for input_component_shape, padded_shape in zip(
            nest.flatten(input_dataset.output_shapes), flat_padded_shapes)
    ]
    self._padded_shapes = nest.pack_sequence_as(input_dataset.output_shapes,
                                                flat_padded_shapes_as_tensors)

    self._padding_values = nest.map_structure_up_to(
        input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
        input_dataset.output_types)
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  def _as_variant_tensor(self):
    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
    # The v1 op is used only when `drop_remainder` is statically `False`.
    use_v1 = smart_cond.smart_constant_value(self._drop_remainder) is False
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    # Arguments shared by both op versions.
    common_kwargs = dict(
        batch_size=self._batch_size,
        padded_shapes=[
            ops.convert_to_tensor(s, dtype=dtypes.int64)
            for s in nest.flatten(self._padded_shapes)
        ],
        padding_values=nest.flatten(self._padding_values),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    if use_v1:
      return gen_dataset_ops.padded_batch_dataset(input_variant,
                                                  **common_kwargs)
    return gen_dataset_ops.padded_batch_dataset_v2(
        input_variant, drop_remainder=self._drop_remainder, **common_kwargs)

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """The padded shapes, each with a leading batch dimension."""

    def _padded_shape_to_batch_shape(s):
      # The batch dimension can only be static when `drop_remainder` is known
      # to be true; otherwise it is left unknown (`None`).
      return tensor_shape.vector(
          tensor_util.constant_value(self._batch_size) if smart_cond.
          smart_constant_value(self._drop_remainder) else None).concatenate(
              tensor_util.constant_value_as_shape(s))

    return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes)

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


def _should_unpack_args(args):
  """Returns `True` if `args` should be `*args` when passed to a callable.

  Only an object whose exact type is `tuple` qualifies. Tuple subclasses such
  as namedtuples, and all other sequences, are passed as a single argument.
  """
  if type(args) is tuple:  # pylint: disable=unidiomatic-typecheck
    return True
  return False


def _warn_if_collections(transformation_name):
  """Prints warning message if the current graph uses common graph collections.

  NOTE(mrry): Currently a warning is only generated for lookup tables. Any
  variables created will be automatically hoisted out to the outermost scope
  using `init_scope()`. Some collections (such as for control-flow contexts)
  are benign and should not generate a warning.

  Args:
    transformation_name: A human-readable name for the transformation.
  """
  table_initializers = ops.get_default_graph().get_collection(
      ops.GraphKeys.TABLE_INITIALIZERS)
  if not table_initializers:
    return
  warnings.warn("Creating lookup tables inside a function passed to %s is not"
                " supported. Create each table outside the function, and "
                "capture it inside the function to use it."
                % transformation_name)


class MapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details.

    Args:
      input_dataset: The `Dataset` to transform.
      map_func: The per-element function, wrapped by
        `StructuredFunctionWrapper`.
      use_inter_op_parallelism: Passed through to the map op.
    """
    super(MapDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    wrapper = StructuredFunctionWrapper(
        map_func, "Dataset.map()", input_dataset)
    # The output element structure is the one reported by the wrapped function.
    self._output_classes, self._output_shapes, self._output_types = (
        wrapper.output_classes, wrapper.output_shapes, wrapper.output_types)
    self._map_func = wrapper.function

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.map_dataset(
        input_variant,
        self._map_func.captured_inputs,
        f=self._map_func,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Output classes of the wrapped `map_func`."""
    return self._output_classes

  @property
  def output_shapes(self):
    """Output shapes of the wrapped `map_func`."""
    return self._output_shapes

  @property
  def output_types(self):
    """Output types of the wrapped `map_func`."""
    return self._output_types


class ParallelMapDataset(MapDataset):
  """A `Dataset` that maps a function over elements in its input in parallel.

  Function wrapping and the element structure are inherited from `MapDataset`.
  This subclass only adds the `num_parallel_calls` argument.
  """

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    super(ParallelMapDataset, self).__init__(
        input_dataset, map_func, use_inter_op_parallelism)
    # Converted as int32, unlike the int64 size arguments of other datasets here.
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    input_variant = self._input_dataset._as_variant_tensor()
    # pylint: enable=protected-access
    return gen_dataset_ops.parallel_map_dataset(
        input_variant,
        self._map_func.captured_inputs,
        f=self._map_func,
        num_parallel_calls=self._num_parallel_calls,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **flat_structure(self))


class FlatMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details.

    Raises:
      TypeError: If `map_func` does not return a `Dataset`.
    """
    super(FlatMapDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset

    # Subclasses override `_transformation_name()` so the wrapper is labelled
    # with their own API name.
    wrapper = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        input_dataset,
        experimental_nested_dataset_support=True)
    nested_classes = wrapper.output_classes
    if not isinstance(nested_classes, _NestedDatasetComponent):
      raise TypeError("`map_func` must return a `Dataset` object.")
    # Elements take the structure of the datasets returned by `map_func`.
    self._output_classes = nested_classes.output_classes
    self._output_types = wrapper.output_types.output_types
    self._output_shapes = wrapper.output_shapes.output_shapes
    self._map_func = wrapper.function

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.flat_map_dataset(
        input_variant,
        self._map_func.captured_inputs,
        f=self._map_func,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Output classes of the datasets produced by `map_func`."""
    return self._output_classes

  @property
  def output_shapes(self):
    """Output shapes of the datasets produced by `map_func`."""
    return self._output_shapes

  @property
  def output_types(self):
    """Output types of the datasets produced by `map_func`."""
    return self._output_types

  def _transformation_name(self):
    return "Dataset.flat_map()"


class InterleaveDataset(FlatMapDataset):
  """A `Dataset` that maps a function over its input and interleaves the result.

  Wrapping of `map_func` and the element structure come from `FlatMapDataset`.
  This subclass adds the `cycle_length` and `block_length` arguments.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    super(InterleaveDataset, self).__init__(input_dataset, map_func)
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.interleave_dataset(
        input_variant,
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        f=self._map_func,
        **flat_structure(self))

  def _transformation_name(self):
    return "Dataset.interleave()"


class ParallelInterleaveDataset(FlatMapDataset):
  """A `Dataset` that maps a function over its input and interleaves the result.

  Like `InterleaveDataset`, but also takes a `num_parallel_calls` argument and
  builds its graph with the `parallel_interleave_dataset_v2` op.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               num_parallel_calls):
    """See `Dataset.interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.parallel_interleave_dataset_v2(
        input_variant,
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._num_parallel_calls,
        f=self._map_func,
        **flat_structure(self))

  def _transformation_name(self):
    return "Dataset.interleave()"


class FilterDataset(UnaryDataset):
  """A `Dataset` that filters its input according to a predicate function."""

  def __init__(self, input_dataset, predicate):
    """See `Dataset.filter()` for details.

    Raises:
      ValueError: If `predicate` does not return a scalar `tf.bool` tensor.
    """
    super(FilterDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    wrapper = StructuredFunctionWrapper(
        predicate, "Dataset.filter()", input_dataset)
    returns_bool = wrapper.output_types == dtypes.bool
    # The shape is only inspected once the type is known to be a single bool.
    if not returns_bool or not wrapper.output_shapes.is_compatible_with(
        tensor_shape.scalar()):
      raise ValueError("`predicate` must return a scalar boolean tensor.")
    self._predicate = wrapper.function

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.filter_dataset(
        input_variant,
        other_arguments=self._predicate.captured_inputs,
        predicate=self._predicate,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class PrefetchDataset(UnaryDataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size):
    """See `Dataset.prefetch()` for details.

    Args:
      input_dataset: The `Dataset` to prefetch from.
      buffer_size: Converted to a `tf.int64` tensor named "buffer_size".
        `None` is replaced by `-1`, the sentinel for auto-tuning.
    """
    super(PrefetchDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    effective_size = -1 if buffer_size is None else buffer_size
    self._buffer_size = ops.convert_to_tensor(
        effective_size, dtype=dtypes.int64, name="buffer_size")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.prefetch_dataset(
        input_variant, buffer_size=self._buffer_size, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class WindowDataset(UnaryDataset):
  """A dataset that creates window datasets from the input elements."""

  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
    """See `Dataset.window()` for more details.

    Args:
      input_dataset: The `Dataset` whose elements are grouped into windows.
      size: Converted to a `tf.int64` tensor named "size".
      shift: Converted to a `tf.int64` tensor named "shift".
      stride: Converted to a `tf.int64` tensor named "stride".
      drop_remainder: Converted to a `tf.bool` tensor named "drop_remainder".
    """
    super(WindowDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    # Each input component is described by a `_NestedDatasetComponent` that
    # carries that component's class, shape and type.
    nested_components = [
        _NestedDatasetComponent(  # pylint: disable=protected-access
            output_classes=output_class,
            output_shapes=output_shape,
            output_types=output_type)
        for output_class, output_shape, output_type in zip(
            nest.flatten(input_dataset.output_classes),
            nest.flatten(input_dataset.output_shapes),
            nest.flatten(input_dataset.output_types))
    ]
    self._output_classes = nest.pack_sequence_as(input_dataset.output_classes,
                                                 nested_components)
    # The same nested-component structure is reported by all three `output_*`
    # properties.
    self._output_shapes = self._output_classes
    self._output_types = self._output_classes

  def _as_variant_tensor(self):
    return gen_dataset_ops.window_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._size,
        self._shift,
        self._stride,
        self._drop_remainder,
        **flat_structure(self))

  @property
  def output_classes(self):
    """Nested-component structure built in the constructor."""
    return self._output_classes

  @property
  def output_shapes(self):
    """Same nested-component structure as `output_classes`."""
    return self._output_shapes

  @property
  def output_types(self):
    """Same nested-component structure as `output_classes`."""
    return self._output_types


class _OptionsDataset(UnaryDataset):
  """An identity `Dataset` that stores options."""

  def __init__(self, input_dataset, options):
    """Wraps `input_dataset` and attaches `options` to it.

    If `input_dataset.options()` is truthy, `options` is merged into it via
    its `merge()` method. Otherwise `options` is stored as is.
    """
    super(_OptionsDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    inherited = input_dataset.options()
    self._options = inherited.merge(options) if inherited else options

  def _as_variant_tensor(self):
    # Adds no op of its own; the input's variant tensor is returned directly.
    return self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access

  def options(self):
    """Returns the (possibly merged) options stored by the constructor."""
    return self._options

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class _ModelDataset(UnaryDataset):
  """A `Dataset` that acts as an identity, and models performance."""

  def __init__(self, input_dataset):
    """See `optimize()` for details."""
    super(_ModelDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.model_dataset(input_variant, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types


class _OptimizeDataset(UnaryDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

  def __init__(self, input_dataset, optimizations):
    """See `optimize()` for details.

    Args:
      input_dataset: The `Dataset` to optimize.
      optimizations: Optimization names, converted to a `tf.string` tensor
        named "optimizations". `None` is treated as an empty list.
    """
    super(_OptimizeDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    requested = [] if optimizations is None else optimizations
    self._optimizations = ops.convert_to_tensor(
        requested, dtype=dtypes.string, name="optimizations")

  def _as_variant_tensor(self):
    input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.optimize_dataset(
        input_variant, self._optimizations, **flat_structure(self))

  @property
  def output_classes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Forwarded unchanged from the input dataset."""
    return self._input_dataset.output_types