aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras/python/keras/engine/topology.py
blob: 3d9ed51a1c0b4522ec15dd19909f411d78357a90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
2470
2471
2472
2473
2474
2475
2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Base layer code and base model (Container) code.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os
import re
import warnings

import numpy as np
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.util import tf_inspect


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


class InputSpec(object):
  """Describes the constraints a layer places on one of its input tensors.

  A layer may expose an `input_spec` attribute holding a list of `InputSpec`
  instances, one per input tensor.

  A `None` entry inside `shape` matches any dimension; a `shape` of `None`
  matches any shape.

  Arguments:
      dtype: Expected datatype of the input.
      shape: Shape tuple, expected shape of the input
          (may include None for unchecked axes). When given, the rank
          stored in `ndim` is derived from it and the `ndim` argument
          is ignored.
      ndim: Integer, expected rank of the input. Only used when `shape`
          is None.
      max_ndim: Integer, maximum rank of the input.
      min_ndim: Integer, minimum rank of the input.
      axes: Dictionary mapping integer axes to
          a specific dimension value. A falsy value (None or an empty
          dict) is replaced by a fresh empty dict.
  """

  def __init__(self, dtype=None, shape=None, ndim=None,
               max_ndim=None, min_ndim=None, axes=None):
    self.dtype = dtype
    self.shape = shape
    # An explicit shape fixes the rank, taking precedence over `ndim`.
    self.ndim = ndim if shape is None else len(shape)
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    # Never share a default dict between instances.
    self.axes = axes if axes else {}


class Node(object):
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer.inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer.outbound_nodes`.

  Arguments:
      outbound_layer: the layer that takes
          `input_tensors` and turns them into `output_tensors`
          (the node gets created when the `call`
          method of the layer was called).
      inbound_layers: a list of layers, the same length as `input_tensors`,
          the layers from where `input_tensors` originate.
          Entries may be None; such entries are skipped when registering
          the node.
      node_indices: a list of integers, the same length as `inbound_layers`.
          `node_indices[i]` is the origin node of `input_tensors[i]`
          (necessary since each inbound layer might have several nodes,
          e.g. if the layer is being shared with a different data stream).
      tensor_indices: a list of integers,
          the same length as `inbound_layers`.
          `tensor_indices[i]` is the index of `input_tensors[i]` within the
          output of the inbound layer
          (necessary since each inbound layer might
          have multiple tensor outputs, with each one being
          independently manipulable).
      input_tensors: list of input tensors.
      output_tensors: list of output tensors.
      input_masks: list of input masks (a mask can be a tensor, or None).
      output_masks: list of output masks (a mask can be a tensor, or None).
      arguments: dictionary of keyword arguments that were passed to the
          `call` method of the layer at the call that created the node.

  `node_indices` and `tensor_indices` are basically fine-grained coordinates
  describing the origin of the `input_tensors`.

  A node from layer A to layer B is added to:
      A.outbound_nodes
      B.inbound_nodes
  """

  def __init__(self,
               outbound_layer,
               inbound_layers,
               node_indices,
               tensor_indices,
               input_tensors,
               output_tensors,
               input_masks,
               output_masks,
               arguments=None):
    """Stores the connectivity data and registers the node on its layers.

    Besides recording the arguments as attributes, this computes
    `input_shapes` / `output_shapes` with `K.int_shape` and appends the new
    node to `outbound_nodes` of every non-None inbound layer and to
    `inbound_nodes` of `outbound_layer`.
    """
    # Layer instance (NOT a list).
    # This is the layer that takes a list of input tensors
    # and turns them into a list of output tensors.
    # The current node will be added to
    # the inbound_nodes of outbound_layer.
    self.outbound_layer = outbound_layer

    # The following 3 properties describe where
    # the input tensors come from: which layers,
    # and for each layer, which node and which
    # tensor output of each node.

    # List of layer instances.
    self.inbound_layers = inbound_layers
    # List of integers, 1:1 mapping with inbound_layers.
    self.node_indices = node_indices
    # List of integers, 1:1 mapping with inbound_layers.
    self.tensor_indices = tensor_indices

    # Following 2 properties:
    # tensor inputs and outputs of outbound_layer.

    # List of tensors. 1:1 mapping with inbound_layers.
    self.input_tensors = input_tensors
    # List of tensors, created by outbound_layer.call().
    self.output_tensors = output_tensors

    # Following 2 properties: input and output masks.
    # List of tensors, 1:1 mapping with input_tensor.
    self.input_masks = input_masks
    # List of tensors, created by outbound_layer.compute_mask().
    self.output_masks = output_masks

    # Following 2 properties: input and output shapes.

    # List of shape tuples, shapes of input_tensors.
    self.input_shapes = [K.int_shape(x) for x in input_tensors]
    # List of shape tuples, shapes of output_tensors.
    self.output_shapes = [K.int_shape(x) for x in output_tensors]

    # Optional keyword arguments to layer's `call`.
    self.arguments = arguments

    # Add nodes to all layers involved.
    for layer in inbound_layers:
      if layer is not None:
        layer.outbound_nodes.append(self)
    outbound_layer.inbound_nodes.append(self)

  def get_config(self):
    """Returns a dictionary summarizing this node's connectivity.

    Layers are referenced by their `name`; a missing (None) layer is
    reported as None.

    Returns:
        A dict with the keys `outbound_layer`, `inbound_layers`,
        `node_indices` and `tensor_indices`.
    """
    # Test for identity with None (as `__init__` does) rather than
    # truthiness, so that a layer object which happens to evaluate as
    # falsy still has its name recorded.
    inbound_names = [
        layer.name if layer is not None else None
        for layer in self.inbound_layers
    ]
    if self.outbound_layer is not None:
      outbound_name = self.outbound_layer.name
    else:
      outbound_name = None
    return {
        'outbound_layer': outbound_name,
        'inbound_layers': inbound_names,
        'node_indices': self.node_indices,
        'tensor_indices': self.tensor_indices
    }


class Layer(tf_base_layers.Layer):
  """Abstract base layer class.

  # Properties
      name: String, must be unique within a model.
      input_spec: List of InputSpec class instances
          each entry describes one required input:
              - ndim
              - dtype
          A layer with `n` input tensors must have
          an `input_spec` of length `n`.
      trainable: Boolean, whether the layer weights
          will be updated during training.
      uses_learning_phase: Whether any operation
          of the layer uses `K.in_training_phase()`
          or `K.in_test_phase()`.
      input_shape: Shape tuple. Provided for convenience,
          but note that there may be cases in which this
          attribute is ill-defined (e.g. a shared layer
          with multiple input shapes), in which case
          requesting `input_shape` will raise an Exception.
          Prefer using `layer.get_input_shape_for(input_shape)`,
          or `layer.get_input_shape_at(node_index)`.
      output_shape: Shape tuple. See above.
      inbound_nodes: List of nodes.
      outbound_nodes: List of nodes.
      input, output: Input/output tensor(s). Note that if the layer is used
          more than once (shared layer), this is ill-defined
          and will raise an exception. In such cases, use
          `layer.get_input_at(node_index)`.
      input_mask, output_mask: Same as above, for masks.
      trainable_weights: List of variables.
      non_trainable_weights: List of variables.
      weights: The concatenation of the lists trainable_weights and
          non_trainable_weights (in this order).
      constraints: Dict mapping weights to constraints.

  # Methods
      call(x, mask=None): Where the layer's logic lives.
      __call__(x, mask=None): Wrapper around the layer logic (`call`).
          If x is a Keras tensor:
              - Connect current layer with last layer from tensor:
                  `self._add_inbound_node(last_layer)`
              - Add layer to tensor history
          If layer is not built:
              - Build from inputs shape
      get_weights()
      set_weights(weights)
      get_config()
      count_params()
      _compute_output_shape(input_shape)
      compute_mask(x, mask)
      get_input_at(node_index)
      get_output_at(node_index)
      get_input_shape_at(node_index)
      get_output_shape_at(node_index)
      get_input_mask_at(node_index)
      get_output_mask_at(node_index)

  # Class Methods
      from_config(config)

  # Internal methods:
      build(input_shape)
      _add_inbound_node(input_tensors, output_tensors,
                        input_masks, output_masks, arguments=None)
      assert_input_compatibility(inputs)
  """

  def __init__(self, **kwargs):
    """Sets up the state shared by all Keras layers.

    Arguments:
        **kwargs: Optional settings. Accepted keys are `input_shape`,
            `batch_input_shape`, `batch_size`, `dtype`, `name`,
            `trainable` and `weights`. Note that `dtype`, `input_shape`
            and `batch_input_shape` are only applicable to input layers:
            do not pass these keywords to non-input layers.

    Raises:
        TypeError: if a keyword argument outside the accepted set is passed.
    """
    supported = {
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'dtype',
        'name',
        'trainable',
        'weights',
    }
    # Reject anything that is not an accepted keyword.
    for key in kwargs:
      if key not in supported:
        raise TypeError('Keyword argument not understood:', key)

    # Use `K.floatx()` when no dtype was given.
    dtype = kwargs.get('dtype')
    if dtype is None:
      dtype = K.floatx()

    # The parent constructor sets all properties common to Keras layers
    # and core TF layers. `trainable` defaults to True.
    super(Layer, self).__init__(
        name=kwargs.get('name'),
        dtype=dtype,
        trainable=kwargs.get('trainable', True))

    # Properties that are Keras-only for now.
    self.input_spec = None
    self.supports_masking = False
    self._constraints = {}  # dict {tensor: constraint instance}

    # Filled via successive calls to self._add_inbound_node().
    self.inbound_nodes = []
    self.outbound_nodes = []

    # Record input shape information if passed; in this case an input layer
    # will later be created and inserted before the current layer.
    # `batch_input_shape` takes precedence over `input_shape`.
    if 'batch_input_shape' in kwargs:
      self.batch_input_shape = tuple(kwargs['batch_input_shape'])
    elif 'input_shape' in kwargs:
      # The batch axis is `batch_size` when given, None otherwise.
      self.batch_input_shape = (
          (kwargs.get('batch_size'),) + tuple(kwargs['input_shape']))

    # Initial weight values, if passed (None otherwise).
    self._initial_weights = kwargs.get('weights')

  @property
  def constraints(self):
    return self._constraints

  @constraints.setter
  def constraints(self, constraints):
    self._constraints = constraints

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=True,
                 constraint=None):
    """Creates a weight variable and optionally attaches a constraint to it.

    The variable is created through `self.add_variable`. When a
    constraint is given it is stored in `self.constraints`, keyed by
    the created variable.

    Arguments:
        name: String, the name for the weight variable.
        shape: The shape tuple of the weight.
        dtype: The dtype of the weight. `K.floatx()` is used when None.
        initializer: An Initializer instance (callable).
        regularizer: An optional Regularizer instance.
        trainable: A boolean, whether the weight should
            be trained via backprop or not (assuming
            that the layer itself is also trainable).
        constraint: An optional Constraint instance.

    Returns:
        The created weight variable.
    """
    variable = self.add_variable(
        name,
        shape,
        dtype=K.floatx() if dtype is None else dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable)
    if constraint is None:
      return variable
    self.constraints[variable] = constraint
    return variable

  def assert_input_compatibility(self, inputs):
    """Validates `inputs` against the layer's `input_spec`.

    Nothing is checked when `input_spec` is unset (falsy). Otherwise the
    number of inputs must equal the number of specs, and every input
    paired with a non-None spec is checked, in this order, against the
    spec's `ndim`, `max_ndim`, `min_ndim`, `dtype`, `axes` and `shape`.
    A constraint that is not set on the spec is skipped.

    Arguments:
        inputs: input tensor or list of input tensors.

    Raises:
        ValueError: in case of mismatch between
            the provided inputs and the expectations of the layer.
    """
    if not self.input_spec:
      return
    specs = self.input_spec
    if not isinstance(specs, (list, tuple)):
      specs = _to_list(specs)
    inputs = _to_list(inputs)
    if len(inputs) != len(specs):
      raise ValueError('Layer ' + self.name + ' expects ' + str(len(specs)) +
                       ' inputs, but it received ' + str(len(inputs)) +
                       ' input tensors. Input received: ' + str(inputs))

    def mismatch(index, detail):
      # Every per-input error message shares the same prefix.
      return ValueError('Input ' + str(index) +
                        ' is incompatible with layer ' + self.name + ': ' +
                        detail)

    def static_shape(tensor):
      # A TypeError from `K.int_shape` is treated as "shape unknown".
      try:
        return K.int_shape(tensor)
      except TypeError:
        return None

    for index, (x, spec) in enumerate(zip(inputs, specs)):
      if spec is None:
        continue

      # Check ndim (exact, then upper bound, then lower bound).
      if spec.ndim is not None and K.ndim(x) != spec.ndim:
        raise mismatch(index, 'expected ndim=' + str(spec.ndim) +
                       ', found ndim=' + str(K.ndim(x)))
      if spec.max_ndim is not None:
        ndim = K.ndim(x)
        if ndim is not None and ndim > spec.max_ndim:
          raise mismatch(index, 'expected max_ndim=' + str(spec.max_ndim) +
                         ', found ndim=' + str(K.ndim(x)))
      if spec.min_ndim is not None:
        ndim = K.ndim(x)
        if ndim is not None and ndim < spec.min_ndim:
          raise mismatch(index, 'expected min_ndim=' + str(spec.min_ndim) +
                         ', found ndim=' + str(K.ndim(x)))

      # Check dtype.
      if spec.dtype is not None and K.dtype(x) != spec.dtype:
        raise mismatch(index, 'expected dtype=' + str(spec.dtype) +
                       ', found dtype=' + str(K.dtype(x)))

      # Check specific shape axes.
      if spec.axes:
        x_shape = static_shape(x)
        if x_shape is not None:
          for axis, value in spec.axes.items():
            # Unwrap objects that carry their size in a `value` attribute.
            if hasattr(value, 'value'):
              value = value.value
            if value is None:
              continue
            # An unknown (None) dimension on the input is accepted.
            if x_shape[int(axis)] not in {value, None}:
              raise mismatch(index, 'expected axis ' + str(axis) +
                             ' of input shape to have value ' + str(value) +
                             ' but got shape ' + str(x_shape))

      # Check shape.
      if spec.shape is not None:
        x_shape = static_shape(x)
        if x_shape is not None:
          for spec_dim, dim in zip(spec.shape, x_shape):
            if hasattr(spec_dim, 'value'):
              spec_dim = spec_dim.value
            # Only dimensions known on both sides are compared.
            if spec_dim is None or dim is None:
              continue
            if spec_dim != dim:
              raise mismatch(index, 'expected shape=' + str(spec.shape) +
                             ', found shape=' + str(x_shape))

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """Holds the layer's logic.

    This base implementation is the identity: `inputs` is handed back
    unchanged and `kwargs` is ignored.

    Arguments:
        inputs: Input tensor, or list/tuple of input tensors.
        **kwargs: Additional keyword arguments.

    Returns:
        A tensor or list/tuple of tensors.
    """
    outputs = inputs
    return outputs

  def __call__(self, inputs, **kwargs):
    """Wraps `self.call()` and records the call on the layer.

    Steps, in order:
        - check `inputs` against the layer's `input_spec`;
        - collect the mask(s) attached to the inputs and forward them
          to `call` as `mask` when `call` accepts that argument and
          the caller did not pass one;
        - delegate to the parent `__call__`;
        - compute the output mask and register the call through
          `self._add_inbound_node()`, which also updates the
          `_keras_history` of the output tensor(s);
        - apply the weights passed at instantiation, if any.

    Arguments:
        inputs: Can be a tensor or list/tuple of tensors.
        **kwargs: Additional keyword arguments to be passed to `call()`.

    Returns:
        Output of the layer's `call` method.

    Raises:
        ValueError: in case the layer is missing shape information
            for its `build` call.
    """
    if isinstance(inputs, list):
      # Work on a shallow copy of the caller's list.
      inputs = inputs[:]

    # Raise exceptions in case the input is not compatible
    # with the input_spec set at build time.
    # TODO(fchollet): call after the layer is built, too.
    self.assert_input_compatibility(inputs)

    # Handle mask propagation.
    previous_mask = _collect_previous_mask(inputs)
    # Snapshot the caller's keyword arguments before a mask may be
    # injected below; the snapshot is what the node records.
    user_kwargs = copy.copy(kwargs)
    has_previous_mask = not _is_all_none(previous_mask)
    # A mask passed explicitly to __call__ is never overridden by the mask
    # generated by the previous layer.
    if (has_previous_mask and
        'mask' in tf_inspect.getargspec(self.call).args and
        'mask' not in kwargs):
      kwargs['mask'] = previous_mask

    # Actually call the layer (optionally building it).
    output = super(Layer, self).__call__(inputs, **kwargs)

    # Handle mask computation.
    with K.name_scope(self.name):
      output_mask = self.compute_mask(inputs, previous_mask)

    # Record the call as an inbound node of the layer.
    self._add_inbound_node(
        input_tensors=inputs,
        output_tensors=output,
        input_masks=previous_mask,
        output_masks=output_mask,
        arguments=user_kwargs)

    # Optionally load weight values that were specified at layer
    # instantiation; the attribute is removed once they are applied.
    initial_weights = getattr(self, '_initial_weights', None)
    if initial_weights is not None:
      self.set_weights(initial_weights)
      del self._initial_weights
    return output

  def _add_inbound_node(self,
                        input_tensors,
                        output_tensors,
                        input_masks,
                        output_masks,
                        arguments=None):
    """Internal method to create an inbound node for the layer.

    A `Node` is created from the given tensors and masks. Each output
    tensor is then tagged with `_keras_history` (this layer, the index
    of the last entry of `self.inbound_nodes`, the position of the
    tensor among the outputs) and with `_uses_learning_phase`.

    Arguments:
        input_tensors: list of input tensors.
        output_tensors: list of output tensors.
        input_masks: list of input masks (a mask can be a tensor, or None).
        output_masks: list of output masks (a mask can be a tensor, or None).
        arguments: dictionary of keyword arguments that were passed to the
            `call` method of the layer at the call that created the node.
    """
    input_tensors = _to_list(input_tensors)
    output_tensors = _to_list(output_tensors)
    input_masks = _to_list(input_masks)
    output_masks = _to_list(output_masks)

    # Collect input tensor(s) coordinates. Tensors without a Keras history
    # contribute None placeholders so the three lists stay aligned with
    # `input_tensors`.
    inbound_layers = []
    node_indices = []
    tensor_indices = []
    for x in input_tensors:
      if hasattr(x, '_keras_history'):
        inbound_layer, node_index, tensor_index = x._keras_history
      else:
        inbound_layer = node_index = tensor_index = None
      inbound_layers.append(inbound_layer)
      node_indices.append(node_index)
      tensor_indices.append(tensor_index)

    # Create node, add it to inbound nodes.
    # NOTE(review): the index computed below relies on the `Node`
    # constructor appending the new node to `self.inbound_nodes`.
    Node(
        self,
        inbound_layers=inbound_layers,
        node_indices=node_indices,
        tensor_indices=tensor_indices,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        input_masks=input_masks,
        output_masks=output_masks,
        arguments=arguments)

    # The learning-phase contribution of the inputs and of the layer itself
    # is the same for every output tensor, and so is the node index:
    # compute both once instead of once per output.
    uses_lp = any(
        getattr(x, '_uses_learning_phase', False) for x in input_tensors)
    uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
    new_node_index = len(self.inbound_nodes) - 1

    # Update tensor history and `_uses_learning_phase`.
    for i, tensor in enumerate(output_tensors):
      tensor._uses_learning_phase = getattr(
          tensor, '_uses_learning_phase', False) or uses_lp
      tensor._keras_history = (self, new_node_index, i)

  def _compute_output_shape(self, input_shape):
    """Computes the output shape of the layer.

    This base implementation hands the input shape(s) back, wrapped in
    `tensor_shape.TensorShape`.

    Arguments:
        input_shape: Shape tuple (tuple of integers)
            or list of shape tuples (one per input tensor of the layer).
            Shape tuples can include None for free dimensions,
            instead of an integer.

    Returns:
        A `TensorShape`, or a list of `TensorShape` when `input_shape`
        is a list.
    """
    if not isinstance(input_shape, list):
      return tensor_shape.TensorShape(input_shape)
    return [tensor_shape.TensorShape(shape) for shape in input_shape]

  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.

    A layer with `supports_masking` set carries the input mask over
    unchanged. Any other layer returns None, and refuses a mask that is
    not None (for a list: a list holding at least one non-None entry).

    Arguments:
        inputs: Tensor or list of tensors.
        mask: Tensor or list of tensors.

    Returns:
        None or a tensor (or list of tensors,
            one per output tensor of the layer).

    Raises:
        TypeError: if the layer does not support masking but
            received a mask.
    """
    if self.supports_masking:
      # Masking is explicitly supported: by default carry over the input mask.
      return mask
    if mask is None:
      return None
    # A list is only rejected when it holds an actual mask; any other
    # non-None value is always rejected.
    if isinstance(mask, list):
      rejected = any(m is not None for m in mask)
    else:
      rejected = True
    if rejected:
      raise TypeError('Layer ' + self.name + ' does not support masking, '
                      'but was passed an input_mask: ' + str(mask))
    # Masking not explicitly supported: return None as mask.
    return None

  def build(self, input_shape):  # pylint: disable=unused-argument
    """Creates the layer weights.

    Must be implemented on all layers that have weights. This base
    implementation creates nothing and only flags the layer as built.

    Arguments:
        input_shape: Reference used by implementations for weight shape
            computations; unused by this base implementation.
    """
    self.built = True

  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
    """Retrieves an attribute (e.g. input_tensors) from a node.

    This is used to implement the methods:
        - get_input_shape_at
        - get_output_shape_at
        - get_input_at
        etc...

    Arguments:
        node_index: Integer index of the node from which
            to retrieve the attribute.
        attr: Exact node attribute name.
        attr_name: Human-readable attribute name, for error messages.

    Returns:
        The layer's attribute `attr` at the node of index `node_index`.

    Raises:
        RuntimeError: If the layer has no inbound nodes.
        ValueError: If the index is does not match any node.
    """
    if not self.inbound_nodes:
      raise RuntimeError('The layer has never been called '
                         'and thus has no defined ' + attr_name + '.')
    if not len(self.inbound_nodes) > node_index:
      raise ValueError('Asked to get ' + attr_name + ' at node ' +
                       str(node_index) + ', but the layer has only ' +
                       str(len(self.inbound_nodes)) + ' inbound nodes.')
    values = getattr(self.inbound_nodes[node_index], attr)
    if len(values) == 1:
      return values[0]
    else:
      return values

  def get_input_shape_at(self, node_index):
    """Returns the input shape(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A shape tuple
        (or list of shape tuples if the layer has multiple inputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'input_shapes', 'input shape')

  def get_output_shape_at(self, node_index):
    """Returns the output shape(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A shape tuple
        (or list of shape tuples if the layer has multiple outputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'output_shapes', 'output shape')

  def get_input_at(self, node_index):
    """Returns the input tensor(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A tensor (or list of tensors if the layer has multiple inputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'input_tensors', 'input')

  def get_output_at(self, node_index):
    """Returns the output tensor(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A tensor (or list of tensors if the layer has multiple outputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'output_tensors', 'output')

  def get_input_mask_at(self, node_index):
    """Returns the input mask tensor(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A mask tensor
        (or list of tensors if the layer has multiple inputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'input_masks', 'input mask')

  def get_output_mask_at(self, node_index):
    """Returns the output mask tensor(s) recorded on one inbound node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        A mask tensor
        (or list of tensors if the layer has multiple outputs).
    """
    return self._get_node_attribute_at_index(
        node_index, 'output_masks', 'output mask')

  @property
  def input(self):
    """Retrieves the input tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Input tensor or list of input tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if len(self.inbound_nodes) > 1:
      raise AttributeError('Layer ' + self.name +
                           ' has multiple inbound nodes, '
                           'hence the notion of "layer input" '
                           'is ill-defined. '
                           'Use `get_input_at(node_index)` instead.')
    elif not self.inbound_nodes:
      raise AttributeError('Layer ' + self.name +
                           ' is not connected, no input to return.')
    return self._get_node_attribute_at_index(0, 'input_tensors', 'input')

  @property
  def output(self):
    """Retrieves the output tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Output tensor or list of output tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if not self.inbound_nodes:
      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
    if len(self.inbound_nodes) > 1:
      raise AttributeError('Layer ' + self.name +
                           ' has multiple inbound nodes, '
                           'hence the notion of "layer output" '
                           'is ill-defined. '
                           'Use `get_output_at(node_index)` instead.')
    return self._get_node_attribute_at_index(0, 'output_tensors', 'output')

  @property
  def input_mask(self):
    """Retrieves the input mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Input mask tensor (potentially None) or list of input
        mask tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if len(self.inbound_nodes) != 1:
      raise AttributeError('Layer ' + self.name +
                           ' has multiple inbound nodes, ' +
                           'hence the notion of "layer input mask" '
                           'is ill-defined. '
                           'Use `get_input_mask_at(node_index)` '
                           'instead.')
    return self._get_node_attribute_at_index(0, 'input_masks', 'input mask')

  @property
  def output_mask(self):
    """Retrieves the output mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Output mask tensor (potentially None) or list of output
        mask tensors.

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if len(self.inbound_nodes) != 1:
      raise AttributeError('Layer ' + self.name +
                           ' has multiple inbound nodes, '
                           'hence the notion of "layer output mask" '
                           'is ill-defined. '
                           'Use `get_output_mask_at(node_index)` '
                           'instead.')
    return self._get_node_attribute_at_index(0, 'output_masks', 'output mask')

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
        Input shape, as `TensorShape`
        (or list of `TensorShape`, one tuple per input tensor).

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if not self.inbound_nodes:
      raise AttributeError('The layer has never been called '
                           'and thus has no defined input shape.')
    all_input_shapes = set(
        [str(node.input_shapes) for node in self.inbound_nodes])
    if len(all_input_shapes) == 1:
      input_shapes = self.inbound_nodes[0].input_shapes
      if len(input_shapes) == 1:
        return tuple(tensor_shape.TensorShape(input_shapes[0]).as_list())
      else:
        return [
            tuple(tensor_shape.TensorShape(shape).as_list())
            for shape in input_shapes
        ]
    else:
      raise AttributeError('The layer "' + str(self.name) +
                           ' has multiple inbound nodes, '
                           'with different input shapes. Hence '
                           'the notion of "input shape" is '
                           'ill-defined for the layer. '
                           'Use `get_input_shape_at(node_index)` '
                           'instead.')

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only applicable if the layer has one inbound node,
    or if all inbound nodes have the same output shape.

    Returns:
        Output shape, as `TensorShape`
        (or list of `TensorShape`, one tuple per output tensor).

    Raises:
        AttributeError: if the layer is connected to
        more than one incoming layers.
    """
    if not self.inbound_nodes:
      raise AttributeError('The layer has never been called '
                           'and thus has no defined output shape.')
    all_output_shapes = set(
        [str(node.output_shapes) for node in self.inbound_nodes])
    if len(all_output_shapes) == 1:
      output_shapes = self.inbound_nodes[0].output_shapes
      if len(output_shapes) == 1:
        return tuple(tensor_shape.TensorShape(output_shapes[0]).as_list())
      else:
        return [
            tuple(tensor_shape.TensorShape(shape).as_list())
            for shape in output_shapes
        ]
    else:
      raise AttributeError('The layer "' + str(self.name) +
                           ' has multiple inbound nodes, '
                           'with different output shapes. Hence '
                           'the notion of "output shape" is '
                           'ill-defined for the layer. '
                           'Use `get_output_shape_at(node_index)` '
                           'instead.')

  def set_weights(self, weights):
    """Assigns new values to the weights of the layer.

    Arguments:
        weights: a list of Numpy arrays. The number
            of arrays and their shape must match
            number of the dimensions of the weights
            of the layer (i.e. it should match the
            output of `get_weights`).

    Raises:
        ValueError: If the provided weights list does not match the
            layer's specifications.
    """
    params = self.weights
    expected = len(params)
    received = len(weights)
    if expected != received:
      raise ValueError('You called `set_weights(weights)` on layer "' +
                       self.name + '" with a  weight list of length ' +
                       str(received) + ', but the layer was expecting ' +
                       str(expected) + ' weights. Provided weights: ' +
                       str(weights)[:50] + '...')
    if not params:
      # Nothing to assign.
      return
    # Current values are fetched only to compare their shapes with the
    # shapes of the provided values.
    current_values = K.batch_get_value(params)
    assignments = []
    for current, param, new_value in zip(current_values, params, weights):
      if current.shape != new_value.shape:
        raise ValueError('Layer weight shape ' + str(current.shape) +
                         ' not compatible with '
                         'provided weight shape ' + str(new_value.shape))
      assignments.append((param, new_value))
    K.batch_set_value(assignments)

  def get_weights(self):
    """Fetches the current values of the layer's weights.

    Returns:
        Weights values as a list of numpy arrays.
    """
    return K.batch_get_value(self.weights)

  def get_config(self):
    """Returns the config of the layer.

    A layer config is a Python dictionary (serializable)
    containing the configuration of a layer.
    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    The config of a layer does not include connectivity
    information, nor the layer class name. These are handled
    by `Container` (one layer of abstraction above).

    Returns:
        Python dictionary. It always holds `name` and `trainable`;
        `batch_input_shape` and `dtype` are included only when the
        layer has those attributes.
    """
    config = {'name': self.name, 'trainable': self.trainable}
    for optional_attr in ('batch_input_shape', 'dtype'):
      if hasattr(self, optional_attr):
        config[optional_attr] = getattr(self, optional_attr)
    return config

  @classmethod
  def from_config(cls, config):
    """Builds a layer instance from a config dictionary.

    This method is the reverse of `get_config`,
    capable of instantiating the same layer from the config
    dictionary. It does not handle layer connectivity
    (handled by Container), nor weights (handled by `set_weights`).

    Arguments:
        config: A Python dictionary, typically the
            output of get_config. Its entries are passed to the
            class constructor as keyword arguments.

    Returns:
        A layer instance.
    """
    layer = cls(**config)
    return layer

  def count_params(self):
    """Counts the total number of scalars composing the weights.

    An unbuilt instance of a class named `Sequential` is built first;
    any other unbuilt layer is rejected.

    Returns:
        An integer count.

    Raises:
        RuntimeError: if the layer isn't yet built
            (in which case its weights aren't yet defined).
    """
    if not self.built:
      if self.__class__.__name__ != 'Sequential':
        raise RuntimeError('You tried to call `count_params` on ' + self.name +
                           ', but the layer isn\'t built. '
                           'You can build it manually via: `' + self.name +
                           '.build(batch_input_shape)`.')
      self.build()  # pylint: disable=no-value-for-parameter
    return sum(K.count_params(p) for p in self.weights)


class InputLayer(Layer):
  """Layer that serves as an entry point into a graph of layers.

  The layer either wraps an existing tensor (given as `input_tensor`)
  or creates a placeholder tensor of its own (from `input_shape`
  or `batch_input_shape`, together with `dtype`).

  Arguments:
      input_shape: Shape tuple, not including the batch axis.
      batch_size: Optional input batch size (integer or None). Only used
          to build the batch shape when `input_shape` is given.
      batch_input_shape: Shape tuple, including the batch axis.
      dtype: Datatype of the input. When not given, it is taken from
          `input_tensor` if one is passed, else from `K.floatx()`.
      input_tensor: Optional tensor to use as layer input
          instead of creating a placeholder.
      sparse: Boolean, whether the placeholder created
          is meant to be sparse.
      name: Name of the layer (string). Autogenerated when not given.

  Raises:
      ValueError: if both `input_shape` and `batch_input_shape` are given,
          or if no shape is given and none can be inferred.
  """

  def __init__(self,
               input_shape=None,
               batch_size=None,
               batch_input_shape=None,
               dtype=None,
               input_tensor=None,
               sparse=False,
               name=None):
    if not name:
      name = 'input_' + str(K.get_uid('input'))
    if not dtype:
      dtype = K.floatx() if input_tensor is None else K.dtype(input_tensor)
    super(InputLayer, self).__init__(dtype=dtype, name=name)
    self.built = True
    self.sparse = sparse

    if input_shape and batch_input_shape:
      raise ValueError('Only provide the input_shape OR '
                       'batch_input_shape argument to '
                       'InputLayer, not both at the same time.')

    wraps_tensor = input_tensor is not None
    if wraps_tensor:
      # Try to read the shape off the wrapped tensor; an explicit shape
      # argument is only required when that fails.
      try:
        batch_input_shape = K.int_shape(input_tensor)
      except TypeError:
        if not (input_shape or batch_input_shape):
          raise ValueError('InputLayer was provided '
                           'an input_tensor argument, '
                           'but its input shape cannot be '
                           'automatically inferred. '
                           'You should pass an input_shape or '
                           'batch_input_shape argument.')

    if batch_input_shape:
      batch_input_shape = tuple(batch_input_shape)
    elif input_shape:
      batch_input_shape = (batch_size,) + tuple(input_shape)
    else:
      raise ValueError('An Input layer should be passed either '
                       'a `batch_input_shape` or an `input_shape`.')
    self.batch_input_shape = batch_input_shape

    self.is_placeholder = not wraps_tensor
    if not wraps_tensor:
      input_tensor = K.placeholder(
          shape=batch_input_shape,
          dtype=dtype,
          sparse=self.sparse,
          name=self.name)

    # Tag the tensor with its Keras metadata, then register the node that
    # has this tensor as both its input and its output.
    input_tensor._uses_learning_phase = False
    input_tensor._keras_history = (self, 0, 0)
    Node(
        self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=[input_tensor],
        output_tensors=[input_tensor],
        input_masks=[None],
        output_masks=[None])

  def get_config(self):
    """Returns the serializable config of the input layer.

    Returns:
        Python dictionary with the keys `batch_input_shape`, `dtype`,
        `sparse` and `name`.
    """
    return {
        'batch_input_shape': self.batch_input_shape,
        'dtype': self.dtype,
        'sparse': self.sparse,
        'name': self.name
    }


def Input(  # pylint: disable=invalid-name
    shape=None,
    batch_shape=None,
    name=None,
    dtype=None,
    sparse=False,
    tensor=None):
  """`Input()` is used to instantiate a Keras tensor.

  A Keras tensor is a tensor object from the underlying backend
  (Theano or TensorFlow), which we augment with certain
  attributes that allow us to build a Keras model
  just by knowing the inputs and outputs of the model.

  For instance, if a, b and c are Keras tensors,
  it becomes possible to do:
  `model = Model(inputs=[a, b], outputs=c)`

  The added Keras attribute is:
      `_keras_history`: Last layer applied to the tensor.
          the entire layer graph is retrievable from that layer,
          recursively.

  Arguments:
      shape: A shape tuple (integer), not including the batch size.
          For instance, `shape=(32,)` indicates that the expected input
          will be batches of 32-dimensional vectors.
      batch_shape: A shape tuple (integer), including the batch size.
          For instance, `batch_shape=(10, 32)` indicates that
          the expected input will be batches of 10 32-dimensional vectors.
          `batch_shape=(None, 32)` indicates batches of an arbitrary number
          of 32-dimensional vectors.
      name: An optional name string for the layer.
          Should be unique in a model (do not reuse the same name twice).
          It will be autogenerated if it isn't provided.
      dtype: The data type expected by the input, as a string
          (`float32`, `float64`, `int32`...). If not provided, defaults
          to the value of `K.floatx()` at the time of the call.
      sparse: A boolean specifying whether the placeholder
          to be created is sparse.
      tensor: Optional existing tensor to wrap into the `Input` layer.
          If set, the layer will not create a placeholder tensor.

  Returns:
      A tensor.

  Raises:
      AssertionError: if none of `shape`, `batch_shape` and `tensor`
          is provided.

  Example:

      ```python
      # this is a logistic regression in Keras
      x = Input(shape=(32,))
      y = Dense(16, activation='softmax')(x)
      model = Model(x, y)
      ```
  """
  # Resolve the default float type at call time: a `K.floatx()` default in
  # the signature would be frozen when this module is first imported and
  # would ignore any later change of the backend float type.
  if not dtype:
    dtype = K.floatx()
  if not batch_shape and tensor is None:
    assert shape, ('Please provide to Input either a `shape`'
                   ' or a `batch_shape` argument. Note that '
                   '`shape` does not include the batch '
                   'dimension.')
  if shape and not batch_shape:
    # Prepend an unspecified batch dimension.
    batch_shape = (None,) + tuple(shape)
  input_layer = InputLayer(
      batch_input_shape=batch_shape,
      name=name,
      dtype=dtype,
      sparse=sparse,
      input_tensor=tensor)
  # Return tensor including `_keras_history`.
  # Note that in this case train_output and test_output are the same pointer.
  outputs = input_layer.inbound_nodes[0].output_tensors
  if len(outputs) == 1:
    return outputs[0]
  return outputs


class Container(Layer):
  """A Container is a directed acyclic graph of layers.

  It is the topological form of a "model". A Model
  is simply a Container with added training routines.

  # Properties
      name
      inputs
      outputs
      input_layers
      output_layers
      input_spec (list of class instances)
          each entry describes one required input:
              - ndim
              - dtype
      trainable (boolean)
      input_shape
      output_shape
      inbound_nodes: list of nodes
      outbound_nodes: list of nodes
      trainable_weights (list of variables)
      non_trainable_weights (list of variables)
      constraints (list of tuples (weight, constraint))

  # Methods
      summary
      get_layer
      get_weights
      set_weights
      get_config
      compute_output_shape

  # Class Methods
      from_config
  """

  def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
    """Builds the graph of layers that connects `inputs` to `outputs`.

    Arguments:
        inputs: A Keras tensor or a list/tuple of Keras tensors
            (tensors carrying a `_keras_history` attribute).
        outputs: A Keras tensor or a list/tuple of Keras tensors.
        name: Optional name. If not provided, it is generated from the
            lowercased class name and a unique id.

    Raises:
        ValueError: if the same tensor appears more than once in `inputs`.
        TypeError: if an input or output tensor has no `_keras_history`.
        AssertionError: if an input tensor's history has a non-zero
            node index or tensor index.
        RuntimeError: if a tensor needed by the graph cannot be computed
            from `inputs` (disconnected graph), or if layer names
            are not unique.
    """
    # Handle `name` argument.
    if not name:
      prefix = self.__class__.__name__.lower()
      name = prefix + '_' + str(K.get_uid(prefix))
    self.name = name
    self.supports_masking = False
    self.trainable = True
    self._per_input_losses = {}
    self._per_input_updates = {}

    # The following properties are not actually used by Keras;
    # they exist for compatibility with TF.
    self._updates = []
    self._scope = None
    self._reuse = None
    self._base_name = name
    self._graph = ops.get_default_graph()

    # Container-specific properties.
    # `inputs` / `outputs` are normalized to lists, whether a single tensor
    # or a list/tuple of tensors was passed.
    if isinstance(inputs, (list, tuple)):
      self.inputs = list(inputs)  # Tensor or list of tensors.
    else:
      self.inputs = [inputs]
    if isinstance(outputs, (list, tuple)):
      self.outputs = list(outputs)
    else:
      self.outputs = [outputs]

    # Check for redundancy in inputs.
    inputs_set = set(self.inputs)
    if len(inputs_set) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    # List of initial layers (1 to 1 mapping with self.inputs,
    # hence the same layer might appear twice)
    self.input_layers = []
    self.input_layers_node_indices = []
    self.input_layers_tensor_indices = []
    # List of output layers (1 to 1 mapping with self.outputs,
    # hence the same layer might appear twice)
    self.output_layers = []
    self.output_layers_node_indices = []
    self.output_layers_tensor_indices = []
    # All layers in order of horizontal graph traversal.
    # Entries are unique. Includes input and output layers.
    self.layers = []

    # This is for performance optimization
    # when calling the Container on new inputs.
    # Every time the Container is called on a set of input tensors,
    # we compute the output tensors,
    # output masks and output shapes in one pass,
    # then cache them here. When one of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # User-provided arguments validation.
    for x in self.inputs:
      # Check that x is a Keras tensor.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise TypeError('Input tensors to a ' + cls_name + ' ' +
                        'must be Keras tensors. Found: ' + str(x) +
                        ' (missing Keras metadata).')
      # Check that x is an input tensor.
      # Note that a violation only triggers a warning, not an error.
      layer, node_index, tensor_index = x._keras_history
      if len(layer.inbound_nodes) > 1 or (
          layer.inbound_nodes and layer.inbound_nodes[0].inbound_layers):
        cls_name = self.__class__.__name__
        warnings.warn(cls_name + ' inputs must come from '
                      'a Keras Input layer, '
                      'they cannot be the output of '
                      'a previous non-Input layer. '
                      'Here, a tensor specified as '
                      'input to "' + self.name + '" was not an Input tensor, '
                      'it was generated by layer ' + layer.name + '.\n'
                      'Note that input tensors are '
                      'instantiated via `tensor = Input(shape)`.\n'
                      'The tensor that caused the issue was: ' + str(x.name))
    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise TypeError('Output tensors to a ' + cls_name + ' must be '
                        'Keras tensors. Found: ' + str(x))
    # Build self.output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history
      self.output_layers.append(layer)
      self.output_layers_node_indices.append(node_index)
      self.output_layers_tensor_indices.append(tensor_index)

    # Fill in the output mask cache.
    # The key is built from the ids of the input tensors and of their masks,
    # in the same format used by `compute_mask` for its cache lookup.
    masks = []
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history
      node = layer.inbound_nodes[node_index]
      mask = node.output_masks[tensor_index]
      masks.append(mask)
    mask_cache_key = ','.join([str(id(x)) for x in self.inputs])
    mask_cache_key += '_' + ','.join([str(id(x)) for x in masks])
    # The cached value is the mask(s) of the output tensors.
    masks = []
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history
      node = layer.inbound_nodes[node_index]
      mask = node.output_masks[tensor_index]
      masks.append(mask)
    if len(masks) == 1:
      mask = masks[0]
    else:
      mask = masks
    self._output_mask_cache[mask_cache_key] = mask

    # Build self.input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self.input_layers.append(layer)
      self.input_layers_node_indices.append(node_index)
      self.input_layers_tensor_indices.append(tensor_index)

    # Build self.input_names and self.output_names.
    # The `_feed_*` lists only cover input layers flagged `is_placeholder`.
    self.input_names = []
    self.output_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for i, layer in enumerate(self.input_layers):
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        self._feed_inputs.append(layer.input)
        self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
    for layer in self.output_layers:
      self.output_names.append(layer.name)

    self.internal_input_shapes = [K.int_shape(x) for x in self.inputs]
    self.internal_output_shapes = [K.int_shape(x) for x in self.outputs]

    # Container_nodes: set of nodes included in the graph
    # (not all nodes included in the layers
    # are relevant to the current graph).
    container_nodes = set()  # keys ("<layer name>_ib-<node index>") of all nodes relevant to the Container
    nodes_depths = {}  # dict {node: depth value}
    layers_depths = {}  # dict {layer: depth value}
    layer_indices = {}  # dict {layer: index in traversal}

    def make_node_marker(node, depth):
      # A marker identifies a (node, depth) pair: the same node may be
      # visited again at a different depth.
      return str(id(node)) + '-' + str(depth)

    def build_map_of_graph(tensor,
                           seen_nodes=None,
                           depth=0,
                           layer=None,
                           node_index=None,
                           tensor_index=None):
      """Builds a map of the graph of layers.

      This recursively updates the maps `nodes_depths`,
      `layers_depths`, `layer_indices` and the set `container_nodes`.

      Does not try to detect cycles in the graph.

      Arguments:
          tensor: Some tensor in a graph.
          seen_nodes: Set of node markers (as returned by
              `make_node_marker`, i.e. node id and depth)
              of nodes seen so far. Useful to prevent infinite loops.
          depth: Current depth in the graph (0 = last output).
          layer: Layer from which `tensor` comes from. If not provided,
              will be obtained from `tensor._keras_history`.
          node_index: Node index from which `tensor` comes from.
          tensor_index: Tensor_index from which `tensor` comes from.
      """
      seen_nodes = seen_nodes or set()
      if not layer or node_index is None or tensor_index is None:
        layer, node_index, tensor_index = tensor._keras_history
      node = layer.inbound_nodes[node_index]

      # Prevent cycles.
      seen_nodes.add(make_node_marker(node, depth))

      node_key = layer.name + '_ib-' + str(node_index)
      # Update container_nodes.
      container_nodes.add(node_key)
      # Update nodes_depths (a node keeps the largest depth it was seen at).
      node_depth = nodes_depths.get(node)
      if node_depth is None:
        nodes_depths[node] = depth
      else:
        nodes_depths[node] = max(depth, node_depth)
      # Update layers_depths (a layer keeps the largest depth it was seen at).
      previously_seen_depth = layers_depths.get(layer)
      if previously_seen_depth is None:
        current_depth = depth
      else:
        current_depth = max(depth, previously_seen_depth)
      layers_depths[layer] = current_depth
      # Record the order in which layers are first encountered.
      if layer not in layer_indices:
        layer_indices[layer] = len(layer_indices)

      # Propagate to all previous tensors connected to this node.
      # Note: `layer`, `node_index` and `tensor_index` are rebound below
      # to describe each inbound tensor in turn.
      for i in range(len(node.inbound_layers)):
        x = node.input_tensors[i]
        layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        tensor_index = node.tensor_indices[i]
        next_node = layer.inbound_nodes[node_index]
        # use node_marker to prevent cycles
        node_marker = make_node_marker(next_node, current_depth + 1)
        if node_marker not in seen_nodes:
          build_map_of_graph(x, seen_nodes, current_depth + 1, layer,
                             node_index, tensor_index)

    # Walk the graph backwards from every output tensor.
    for x in self.outputs:
      seen_nodes = set()
      build_map_of_graph(x, seen_nodes, depth=0)

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = {}
    for node, depth in nodes_depths.items():
      if depth not in nodes_by_depth:
        nodes_by_depth[depth] = []
      nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of layers with this depth}
    layers_by_depth = {}
    for layer, depth in layers_depths.items():
      if depth not in layers_by_depth:
        layers_by_depth[depth] = []
      layers_by_depth[depth].append(layer)

    # Get sorted list of layer depths (largest depth first).
    depth_keys = list(layers_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.layers and self.layers_by_depth.
    layers = []
    for depth in depth_keys:
      layers_for_depth = layers_by_depth[depth]
      # Container.layers needs to have a deterministic order:
      # here we order them by traversal order.
      layers_for_depth.sort(key=lambda x: layer_indices[x])
      for layer in layers_for_depth:
        layers.append(layer)
    self.layers = layers
    self.layers_by_depth = layers_by_depth

    # Get sorted list of node depths (largest depth first).
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = []
    for x in self.inputs:
      computable_tensors.append(x)

    layers_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
      for node in nodes_by_depth[depth]:
        layer = node.outbound_layer
        if layer:
          for x in node.input_tensors:
            if x not in computable_tensors:
              raise RuntimeError('Graph disconnected: '
                                 'cannot obtain value for tensor ' + str(x) +
                                 ' at layer "' + layer.name + '". '
                                 'The following previous layers '
                                 'were accessed without issue: ' +
                                 str(layers_with_complete_input))
          for x in node.output_tensors:
            computable_tensors.append(x)
          layers_with_complete_input.append(layer.name)

    # Set self.container_nodes and self.nodes_by_depth.
    self.container_nodes = container_nodes
    self.nodes_by_depth = nodes_by_depth

    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to layers by their name).
    all_names = [layer.name for layer in self.layers]
    for name in all_names:
      if all_names.count(name) != 1:
        raise RuntimeError('The name "' + name + '" is used ' +
                           str(all_names.count(name)) + ' times in the model. '
                           'All layer names should be unique.')

    # Layer parameters.
    # The new container starts with a single inbound node
    # for its inputs, and no outbound nodes.
    self.outbound_nodes = []  # Will be appended to by future calls to __call__
    self.inbound_nodes = [
    ]  # Will be appended to below, and by future calls to __call__
    # Create the node linking internal inputs to internal outputs.
    Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self.inputs,
        output_tensors=self.outputs,
        # No container-level masking for now.
        input_masks=[None for _ in self.inputs],
        output_masks=[None for _ in self.outputs])
    self.built = True

    # The following are implemented as property functions:
    # self.constraints
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec

  def get_layer(self, name=None, index=None):
    """Retrieves a layer based on either its name (unique) or index.

    Indices are based on order of horizontal graph traversal (bottom-up).
    When `index` is provided it takes precedence and `name` is ignored.

    Arguments:
        name: String, name of layer.
        index: Integer, index of layer.

    Returns:
        A layer instance.

    Raises:
        ValueError: In case of invalid layer name or index, or if
            neither a name nor an index is provided.
    """
    # It would be unreliable to build a dictionary
    # based on layer names, because names can potentially
    # be changed at any point by the user
    # without the container being notified of it.
    if index is not None:
      if len(self.layers) <= index:
        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                         ' but model only has ' + str(len(self.layers)) +
                         ' layers.')
      return self.layers[index]
    if not name:
      raise ValueError('Provide either a layer name or layer index.')
    for layer in self.layers:
      if layer.name == name:
        return layer
    # No layer matched. Raise unconditionally: testing the loop variable
    # here would be wrong, since after the loop it holds the last layer
    # and is therefore truthy for any non-empty model.
    raise ValueError('No such layer: ' + str(name))

  @property
  def updates(self):
    """Retrieve the model's updates.

    Will only include updates that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include updates that depend on tensors
    that aren't inputs to this model).

    Returns:
        A list of update ops.
    """
    collected = []
    for layer in self.layers:
      if not hasattr(layer, 'updates'):
        continue
      # Updates conditional on inputs: only for the layer nodes
      # that belong to this model.
      for idx, node in enumerate(layer.inbound_nodes):
        if layer.name + '_ib-' + str(idx) in self.container_nodes:
          collected.extend(layer.get_updates_for(node.input_tensors))
      # Unconditional updates of the layer.
      collected.extend(layer.get_updates_for(None))
    return collected

  @property
  def losses(self):
    """Retrieve the model's losses.

    Will only include losses that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include losses that depend on tensors
    that aren't inputs to this model).

    Returns:
        A list of loss tensors.
    """
    collected = []
    # Losses of all internal layers.
    for layer in self.layers:
      if not hasattr(layer, 'losses'):
        continue
      # Losses conditional on inputs: only for the layer nodes
      # that belong to this model.
      for idx, node in enumerate(layer.inbound_nodes):
        if layer.name + '_ib-' + str(idx) in self.container_nodes:
          collected.extend(layer.get_losses_for(node.input_tensors))
      # Unconditional losses of the layer.
      collected.extend(layer.get_losses_for(None))
    # Any unconditional loss attached to the model itself comes last.
    collected.extend(self.get_losses_for(None))
    return collected

  @property
  def uses_learning_phase(self):
    """Whether any output tensor is flagged `_uses_learning_phase`."""
    flags = [tensor._uses_learning_phase for tensor in self.outputs]
    return any(flags)

  @property
  def stateful(self):
    """Whether at least one layer has a truthy `stateful` attribute."""
    flags = [getattr(layer, 'stateful', False) for layer in self.layers]
    return any(flags)

  def reset_states(self):
    """Calls `reset_states` on every stateful layer that provides it."""
    for layer in self.layers:
      if not getattr(layer, 'stateful', False):
        continue
      if hasattr(layer, 'reset_states'):
        layer.reset_states()

  @property
  def state_updates(self):
    """Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
        A list of update ops.
    """
    collected = []
    for layer in self.layers:
      is_stateful = getattr(layer, 'stateful', False)
      if is_stateful and hasattr(layer, 'updates'):
        collected.extend(layer.updates)
    return collected

  @property
  def constraints(self):
    """Merges the `constraints` dictionaries of all layers.

    Returns:
        A dictionary built from the entries of every layer's
        `constraints` dictionary.

    Raises:
        ValueError: if two layers map the same key to different values.
    """
    merged = {}
    for layer in self.layers:
      for weight, constraint in layer.constraints.items():
        if weight in merged:
          if merged[weight] != constraint:
            raise ValueError('Received multiple constraints '
                             'for one weight tensor: ' + str(weight))
        merged[weight] = constraint
    return merged

  @property
  def trainable_weights(self):
    """Trainable weights of all layers, in layer order.

    Returns:
        A list of weights; empty when the container is not trainable.
    """
    collected = []
    if self.trainable:
      for layer in self.layers:
        collected.extend(layer.trainable_weights)
    return collected

  @property
  def non_trainable_weights(self):
    """Non-trainable weights of all layers, in layer order.

    Returns:
        A list of weights. When the container itself is not trainable,
        the layers' trainable weights are included as well, placed
        before the layers' non-trainable weights.
    """
    frozen = []
    for layer in self.layers:
      frozen.extend(layer.non_trainable_weights)
    if self.trainable:
      return frozen
    # The whole container is frozen: its layers' trainable weights
    # count as non-trainable too.
    unfrozen = []
    for layer in self.layers:
      unfrozen.extend(layer.trainable_weights)
    return unfrozen + frozen

  def get_weights(self):
    """Retrieves the weights of the model.

    The weights of all layers are gathered in layer order and their
    values fetched with a single `K.batch_get_value` call.

    Returns:
        A flat list of Numpy arrays.
    """
    variables = []
    for layer in self.layers:
      variables.extend(layer.weights)
    return K.batch_get_value(variables)

  def set_weights(self, weights):
    """Sets the weights of the model.

    Values are consumed from `weights` in layer order, each layer taking
    as many values as it has weights, and are assigned with a single
    `K.batch_set_value` call.

    Arguments:
        weights: A list of Numpy arrays with shapes and types matching
            the output of `model.get_weights()`.
    """
    assignments = []
    offset = 0
    for layer in self.layers:
      variables = layer.weights
      count = len(variables)
      # Pair each variable of this layer with its slice of `weights`.
      assignments.extend(zip(variables, weights[offset:offset + count]))
      offset += count
    K.batch_set_value(assignments)

  @property
  def input_spec(self):
    """Gets the model's input specs.

    Returns:
        A list with the `input_spec` entries of the input layers
            (`None` for an input layer that has no spec),
            or a single entry if the list holds exactly one element.

    Raises:
        TypeError: if an input layer has an `input_spec`
            that is neither `None` nor a list.
    """
    collected = []
    # `input_layers` may not be set yet, hence the default.
    for layer in getattr(self, 'input_layers', []):
      layer_spec = layer.input_spec
      if layer_spec is None:
        collected.append(None)
        continue
      if not isinstance(layer_spec, list):
        raise TypeError('Layer ' + layer.name +
                        ' has an input_spec attribute that '
                        'is not a list. We expect a list. '
                        'Found input_spec = ' + str(layer_spec))
      collected.extend(layer_spec)
    return collected[0] if len(collected) == 1 else collected

  def call(self, inputs, mask=None):
    """Call the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (e.g. build a new computational graph from the provided inputs).
    Outputs already present in the output tensor cache for the same
    input and mask objects are returned directly.

    A model is callable on non-Keras tensors.

    Arguments:
        inputs: A tensor or list of tensors.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there are more than one outputs.
    """
    inputs = _to_list(inputs)
    masks = [None] * len(inputs) if mask is None else _to_list(mask)
    # The cache is keyed on the identities of the input and mask objects.
    input_ids = ','.join([str(id(tensor)) for tensor in inputs])
    mask_ids = ','.join([str(id(m)) for m in masks])
    cache_key = input_ids + '_' + mask_ids
    if cache_key not in self._output_tensor_cache:
      output_tensors, _, _ = self.run_internal_graph(inputs, masks)
      return output_tensors
    return self._output_tensor_cache[cache_key]

  def compute_mask(self, inputs, mask):
    """Computes the output mask(s) for the given inputs and masks.

    The result is read from the output mask cache when the same input
    and mask objects were seen before; otherwise the internal graph is run.

    Arguments:
        inputs: A tensor or list of tensors.
        mask: A mask or list of masks, or None (no mask on any input).

    Returns:
        The cached mask entry, or the masks returned by
        `run_internal_graph`.
    """
    inputs = _to_list(inputs)
    masks = [None] * len(inputs) if mask is None else _to_list(mask)
    # The cache is keyed on the identities of the input and mask objects.
    input_ids = ','.join([str(id(tensor)) for tensor in inputs])
    mask_ids = ','.join([str(id(m)) for m in masks])
    cache_key = input_ids + '_' + mask_ids
    if cache_key not in self._output_mask_cache:
      _, output_masks, _ = self.run_internal_graph(inputs, masks)
      return output_masks
    return self._output_mask_cache[cache_key]

  def _compute_output_shape(self, input_shape):
    """Computes the output shape(s) of the model for given input shape(s).

    Results are cached in `self._output_shape_cache`, keyed on the
    string form of the normalized input shapes.

    Arguments:
        input_shape: A shape, or a list of shapes (one per input layer).
            Entries may be None; other entries are converted through
            `tensor_shape.TensorShape`.

    Returns:
        A `TensorShape` if there is a single output shape, otherwise
        a list of `TensorShape` instances.

    Raises:
        ValueError: if the number of shapes provided does not match
            the number of input layers.
    """
    # Normalize the argument to a list of shape tuples (or None entries).
    if isinstance(input_shape, list):
      input_shapes = []
      for shape in input_shape:
        if shape is not None:
          input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
        else:
          input_shapes.append(None)
    else:
      if input_shape is not None:
        input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
      else:
        input_shapes = [None]

    if len(input_shapes) != len(self.input_layers):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self.input_layers)) +
                       ' tensor inputs.')

    cache_key = ','.join([str(x) for x in input_shapes])
    if cache_key in self._output_shape_cache:
      output_shapes = self._output_shape_cache[cache_key]
      # NOTE(review): this method only stores lists in the cache; the
      # non-list branch presumably covers entries written elsewhere.
      if isinstance(output_shapes, list):
        if len(output_shapes) == 1:
          return tensor_shape.TensorShape(output_shapes[0])
        else:
          return [tensor_shape.TensorShape(shape) for shape in output_shapes]
      else:
        return tensor_shape.TensorShape(output_shapes)
    else:
      # Bad luck, we have to run the graph manually.
      # Maps "<layer name>_<node index>_<tensor index>" to a shape tuple.
      layers_to_output_shapes = {}
      for i in range(len(input_shapes)):
        layer = self.input_layers[i]
        input_shape = input_shapes[i]
        # It's an input layer: compute_output_shape is identity,
        # and there is only one node and one tensor output.
        shape_key = layer.name + '_0_0'
        layers_to_output_shapes[shape_key] = input_shape

      depth_keys = list(self.nodes_by_depth.keys())
      depth_keys.sort(reverse=True)
      # Iterate over nodes, by depth level.
      # With a single depth level the walk is skipped entirely.
      if len(depth_keys) > 1:
        for depth in depth_keys:
          nodes = self.nodes_by_depth[depth]
          for node in nodes:
            # This is always a single layer, never a list.
            layer = node.outbound_layer
            if layer in self.input_layers:
              # We've already covered the input layers
              # a few lines above.
              continue
            # Potentially redundant list,
            # same size of node.input_tensors.
            # Note: this rebinds `input_shapes` (and `input_shape` below);
            # the method's normalized inputs are no longer needed here.
            input_shapes = []
            for j in range(len(node.inbound_layers)):
              inbound_layer = node.inbound_layers[j]
              node_index = node.node_indices[j]
              tensor_index = node.tensor_indices[j]
              shape_key = inbound_layer.name + '_%s_%s' % (node_index,
                                                           tensor_index)
              input_shape = layers_to_output_shapes[shape_key]
              input_shapes.append(input_shape)

            # A layer with a single input receives the bare shape,
            # not a one-element list.
            if len(input_shapes) == 1:
              output_shape = layer._compute_output_shape(input_shapes[0])
            else:
              output_shape = layer._compute_output_shape(input_shapes)
            if isinstance(output_shape, list):
              output_shapes = [
                  tuple(tensor_shape.TensorShape(shape).as_list())
                  for shape in output_shape
              ]
            else:
              output_shapes = [
                  tuple(tensor_shape.TensorShape(output_shape).as_list())
              ]

            # Record each output shape under the key of the tensor it
            # belongs to, so that downstream nodes can look it up.
            node_index = layer.inbound_nodes.index(node)
            for j in range(len(output_shapes)):
              shape_key = layer.name + '_%s_%s' % (node_index, j)
              layers_to_output_shapes[shape_key] = output_shapes[j]

      # Read final output shapes from layers_to_output_shapes.
      output_shapes = []
      output_shape_keys = []
      for i in range(len(self.output_layers)):
        layer = self.output_layers[i]
        node_index = self.output_layers_node_indices[i]
        tensor_index = self.output_layers_tensor_indices[i]
        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
        output_shape_keys.append(shape_key)

      # The index `i` is not used in this loop.
      for i, key in enumerate(output_shape_keys):
        assert key in layers_to_output_shapes
        output_shapes.append(layers_to_output_shapes[key])
      # Store in cache.
      self._output_shape_cache[cache_key] = output_shapes
      # `output_shapes` is always a list at this point, so the final
      # `else` branch below is not reachable from this code path.
      if isinstance(output_shapes, list):
        if len(output_shapes) == 1:
          return tensor_shape.TensorShape(output_shapes[0])
        else:
          return [tensor_shape.TensorShape(shape) for shape in output_shapes]
      else:
        return tensor_shape.TensorShape(output_shapes)

  def run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    # Note:
        - Expects `inputs` to be a list (potentially with 1 element).
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: List of tensors
        masks: List of masks (tensors or None).

    Returns:
        A tuple `(output_tensors, output_masks, output_shapes)`. Each entry
        is a list with one element per model output, except that a list
        of length 1 is unwrapped and the bare element is returned instead.
        The same values are also stored in the output tensor / mask / shape
        caches of this container.
    """
    if masks is None:
      masks = [None for _ in range(len(inputs))]

    # Dictionary mapping reference tensors to tuples
    # (computed tensor, compute mask)
    # we assume a 1:1 mapping from tensor to mask
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    # Visit the nodes from the largest depth key down to the smallest, so
    # that a node's inputs are already in `tensor_map` when it is reached.
    depth_keys = list(self.nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
      nodes = self.nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # If all previous input tensors are available in tensor_map,
        # then call node.inbound_layer on them.
        computed_data = []  # List of tuples (input, mask).
        for x in reference_input_tensors:
          if str(id(x)) in tensor_map:
            computed_data.append(tensor_map[str(id(x))])

        # Nodes with at least one missing input are silently skipped.
        if len(computed_data) == len(reference_input_tensors):
          # call layer
          with K.name_scope(layer.name):
            # Work on a copy: the mask is injected into `kwargs` below, and
            # writing it into `node.arguments` itself would make it stick
            # to the node (stale mask on the next run, and a tensor among
            # the arguments that `get_config` tries to serialize).
            if node.arguments:
              kwargs = dict(node.arguments)
            else:
              kwargs = {}
            if len(computed_data) == 1:
              computed_tensor, computed_mask = computed_data[0]
              # Only pass a mask if `call` accepts one and the caller did
              # not already provide one explicitly.
              if 'mask' in tf_inspect.getargspec(layer.call).args:
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_mask
              output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
              output_masks = _to_list(
                  layer.compute_mask(computed_tensor, computed_mask))
              computed_tensors = [computed_tensor]
              computed_masks = [computed_mask]
            else:
              computed_tensors = [x[0] for x in computed_data]
              computed_masks = [x[1] for x in computed_data]
              if 'mask' in tf_inspect.getargspec(layer.call).args:
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_masks
              output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
              output_masks = _to_list(
                  layer.compute_mask(computed_tensors, computed_masks))

            # Apply activity regularizer if any:
            # NOTE(review): the regularizer is evaluated on the layer's
            # *input* tensors (`computed_tensors`), not on its outputs --
            # confirm this is intended.
            if hasattr(layer, 'activity_regularizer'
                      ) and layer.activity_regularizer is not None:
              regularization_losses = [
                  layer.activity_regularizer(x) for x in computed_tensors
              ]
              layer.add_loss(regularization_losses, computed_tensors)

          # Update model updates and losses:
          # Keep track of updates that depend on the inputs
          # (e.g. BN updates).
          self.add_update(layer.get_updates_for(computed_tensors), inputs)
          # Keep track of unconditional updates (e.g. a counter).
          self.add_update(layer.get_updates_for(None), None)
          # Keep track of losses that depend on the inputs
          # (e.g. activity regularizers).
          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
          # Keep track of unconditional losses
          # (e.g. weight regularizers).
          self.add_loss(layer.get_losses_for(None), None)

          # Update `_uses_learning_phase`: an output uses the learning phase
          # if it already did or if any of the layer's inputs does.
          if len(computed_tensors) == 1:
            uses_learning_phase = getattr(computed_tensors[0],
                                          '_uses_learning_phase', False)
          else:
            uses_learning_phase = any([
                getattr(x, '_uses_learning_phase', False)
                for x in computed_tensors
            ])
          for x in output_tensors:
            x._uses_learning_phase = getattr(x, '_uses_learning_phase',
                                             False) or uses_learning_phase

          # Update tensor_map.
          for x, y, mask in zip(reference_output_tensors, output_tensors,
                                output_masks):
            tensor_map[str(id(x))] = (y, mask)

    # Collect the computed value, mask and static shape of every output.
    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      # The shape is read from the reference output tensor `x`, not from
      # the freshly computed tensor.
      output_shapes.append(K.int_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    # Update cache;
    # keys are based on ids on input tensors and inputs masks.
    cache_key = ','.join([str(id(x)) for x in inputs])
    cache_key += '_' + ','.join([str(id(x)) for x in masks])

    # Single-element results are unwrapped before being cached/returned.
    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
    self._output_tensor_cache[cache_key] = output_tensors

    if len(output_masks) == 1:
      output_masks = output_masks[0]
    self._output_mask_cache[cache_key] = output_masks

    if output_shapes is not None:
      # The shape cache is keyed on the input shapes rather than on ids.
      input_shapes = [K.int_shape(x) for x in inputs]
      cache_key = ','.join([str(x) for x in input_shapes])
      if len(output_shapes) == 1:
        output_shapes = output_shapes[0]
      self._output_shape_cache[cache_key] = output_shapes
    return output_tensors, output_masks, output_shapes

  def get_config(self):
    """Builds a serializable description of this container.

    Only nodes whose key is listed in `self.container_nodes` are kept;
    they are renumbered so that the stored node indices are relative to
    the nodes that belong to this container.

    Returns:
        A deep copy of a dict with the keys `name`, `layers`,
        `input_layers` and `output_layers`.
    """

    def make_node_key(layer_name, index):
      return layer_name + '_ib-' + str(index)

    # First pass: map every kept node key to its index among kept nodes.
    node_conversion_map = {}
    for layer in self.layers:
      # Containers start with a pre-existing node
      # linking their input to output.
      kept_nodes = 1 if issubclass(layer.__class__, Container) else 0
      for original_node_index, _ in enumerate(layer.inbound_nodes):
        key = make_node_key(layer.name, original_node_index)
        if key in self.container_nodes:
          node_conversion_map[key] = kept_nodes
          kept_nodes += 1

    # Second pass: serialize each layer with its relevant inbound nodes.
    layer_configs = []
    for layer in self.layers:  # From the earliest layers on.
      layer_class_name = layer.__class__.__name__
      layer_config = layer.get_config()
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer.inbound_nodes):
        own_key = make_node_key(layer.name, original_node_index)
        if own_key not in self.container_nodes:
          # Node is not part of this model.
          continue

        # Call-time keyword arguments are kept only if JSON-serializable.
        kwargs = {}
        if node.arguments:
          try:
            json.dumps(node.arguments)
          except TypeError:
            warnings.warn('Layer ' + layer.name +
                          ' was passed non-serializable keyword arguments: ' +
                          str(node.arguments) + '. They will not be included '
                          'in the serialized model (and thus will be missing '
                          'at deserialization time).')
          else:
            kwargs = node.arguments

        if not node.inbound_layers:
          continue
        node_data = []
        for i, inbound_layer in enumerate(node.inbound_layers):
          node_index = node.node_indices[i]
          tensor_index = node.tensor_indices[i]
          inbound_key = make_node_key(inbound_layer.name, node_index)
          # Unknown inbound nodes fall back to index 0.
          new_node_index = node_conversion_map.get(inbound_key, 0)
          node_data.append(
              [inbound_layer.name, new_node_index, tensor_index, kwargs])
        filtered_inbound_nodes.append(node_data)

      layer_configs.append({
          'name': layer.name,
          'class_name': layer_class_name,
          'config': layer_config,
          'inbound_nodes': filtered_inbound_nodes,
      })

    config = {
        'name': self.name,
    }
    config['layers'] = layer_configs

    # Gather info about inputs and outputs.
    model_inputs = []
    for i, layer in enumerate(self.input_layers):
      node_index = self.input_layers_node_indices[i]
      new_node_index = node_conversion_map[make_node_key(layer.name,
                                                         node_index)]
      tensor_index = self.input_layers_tensor_indices[i]
      model_inputs.append([layer.name, new_node_index, tensor_index])
    config['input_layers'] = model_inputs

    model_outputs = []
    for i, layer in enumerate(self.output_layers):
      node_index = self.output_layers_node_indices[i]
      new_node_index = node_conversion_map[make_node_key(layer.name,
                                                         node_index)]
      tensor_index = self.output_layers_tensor_indices[i]
      model_outputs.append([layer.name, new_node_index, tensor_index])
    config['output_layers'] = model_outputs
    return copy.deepcopy(config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from a config dict (as produced by `get_config()`).

    Arguments:
        config: Model config dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A model instance.

    Raises:
        ValueError: In case of improperly formatted config dict.
    """
    # Layers instantiated so far during graph reconstruction, by name.
    created_layers = {}

    def process_layer(layer_data):
      """Instantiates one layer and replays its inbound calls.

      Arguments:
          layer_data: layer config dict.

      Raises:
          ValueError: In case of improperly formatted `layer_data` dict.
      """
      layer_name = layer_data['name']

      # Instantiate layer.
      from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

      # Replay every recorded call of the layer on its inputs.
      for node_data in layer_data['inbound_nodes']:
        input_tensors = []
        for input_data in node_data:
          inbound_layer_name = input_data[0]
          inbound_node_index = input_data[1]
          inbound_tensor_index = input_data[2]
          # Entries are either [name, node, tensor] or
          # [name, node, tensor, kwargs].
          num_fields = len(input_data)
          if num_fields == 3:
            kwargs = {}
          elif num_fields == 4:
            kwargs = input_data[3]
          else:
            raise ValueError('Improperly formatted model config.')
          if inbound_layer_name not in created_layers:
            raise ValueError('Missing layer: ' + inbound_layer_name)
          source_layer = created_layers[inbound_layer_name]
          source_node = source_layer.inbound_nodes[inbound_node_index]
          input_tensors.append(source_node.output_tensors[inbound_tensor_index])

        if not input_tensors:
          continue
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed. The kwargs of the last
        # entry of the node are the ones passed to the call.
        call_inputs = input_tensors[0] if len(input_tensors) == 1 else input_tensors
        layer(call_inputs, **kwargs)

    def lookup_tensor(tensor_ref):
      """Resolves a [layer_name, node_index, tensor_index] reference."""
      layer_name, node_index, tensor_index = tensor_ref
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer.inbound_nodes[node_index].output_tensors
      return layer_output_tensors[tensor_index]

    for layer_data in config['layers']:
      process_layer(layer_data)

    name = config.get('name')
    input_tensors = [lookup_tensor(ref) for ref in config['input_layers']]
    output_tensors = [lookup_tensor(ref) for ref in config['output_layers']]
    return cls(inputs=input_tensors, outputs=output_tensors, name=name)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    """Writes the whole model to a single HDF5 file.

    The work is delegated to `keras.models.save_model`. The savefile
    includes:
        - The model architecture, allowing to re-instantiate the model.
        - The model weights.
        - The state of the optimizer, allowing to resume training
            exactly where you left off.

    Saved models can be reinstantiated via `keras.models.load_model`.
    The model returned by `load_model` is a compiled model ready to be
    used (unless the saved model was never compiled in the first place).

    Arguments:
        filepath: String, path of the file the model is saved to.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        include_optimizer: If True, save optimizer's state together.

    Example:

    ```python
    from keras.models import load_model

    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
    del model  # deletes the existing model

    # returns a compiled model
    # identical to the previous one
    model = load_model('my_model.h5')
    ```
    """
    # Function-level import (see the pylint pragma on the line below).
    from tensorflow.contrib.keras.python.keras.models import save_model  # pylint: disable=g-import-not-at-top

    save_model(self, filepath, overwrite, include_optimizer)

  def save_weights(self, filepath, overwrite=True):
    """Dumps all layer weights to a HDF5 file.

    The weight file has:
        - `layer_names` (attribute), a list of strings
            (ordered names of model layers).
        - For every layer, a `group` named `layer.name`
            - For every such layer group, a group attribute `weight_names`,
                a list of strings
                (ordered names of weights tensor of the layer).
            - For every weight in the layer, a dataset
                storing the weight value, named after the weight tensor.

    Arguments:
        filepath: String, path to the file to save the weights to.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.

    Raises:
        ImportError: If h5py is not available.
    """
    if h5py is None:
      raise ImportError('`save_weights` requires h5py.')
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(filepath):
      proceed = ask_to_proceed_with_overwrite(filepath)
      if not proceed:
        return
    # The context manager guarantees the file handle is released even if
    # writing one of the layers fails part-way through.
    with h5py.File(filepath, 'w') as f:
      save_weights_to_hdf5_group(f, self.layers)
      f.flush()

  def load_weights(self, filepath, by_name=False):
    """Loads all layer weights from a HDF5 save file.

    If `by_name` is False (default) weights are loaded
    based on the network's topology, meaning the architecture
    should be the same as when the weights were saved.
    Note that layers that don't have weights are not taken
    into account in the topological ordering, so adding or
    removing layers is fine as long as they don't have weights.

    If `by_name` is True, weights are loaded into layers
    only if they share the same name. This is useful
    for fine-tuning or transfer-learning models where
    some of the layers have changed.

    If the file has no `layer_names` attribute at its root but contains
    a `model_weights` group, the weights are read from that group.

    Arguments:
        filepath: String, path to the weights file to load.
        by_name: Boolean, whether to load weights by name
            or by topological order.

    Raises:
        ImportError: If h5py is not available.
    """
    if h5py is None:
      raise ImportError('`load_weights` requires h5py.')
    # Keep the file object separate from the group that is read from:
    # the file must be closed in every case, including when the weights
    # live in the `model_weights` sub-group or when loading raises.
    with h5py.File(filepath, mode='r') as f:
      weights_group = f
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        weights_group = f['model_weights']
      if by_name:
        load_weights_from_hdf5_group_by_name(weights_group, self.layers)
      else:
        load_weights_from_hdf5_group(weights_group, self.layers)

  def _updated_config(self):
    """Util shared between different serialization methods.

    Wraps the output of `get_config()` together with the class name of
    the model, the Keras version and the name of the active backend.

    Returns:
        Model config with Keras version information added.
    """
    from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    inner_config = self.get_config()
    return {
        'class_name': self.__class__.__name__,
        'config': inner_config,
        'keras_version': keras_version,
        'backend': K.backend()
    }

  def to_json(self, **kwargs):
    """Serializes the network configuration to a JSON string.

    To load a network from a JSON save file, use
    `keras.models.model_from_json(json_string, custom_objects={})`.

    Arguments:
        **kwargs: Additional keyword arguments
            to be passed to `json.dumps()`.

    Returns:
        A JSON string.
    """

    def fallback_encoder(value):
      """Converts values that `json` cannot encode natively."""
      value_type = type(value)
      if value_type.__module__ == np.__name__:
        # Any numpy type: convert through `.item()`.
        return value.item()
      if value_type.__name__ == type.__name__:
        # A python 'type': serialize it by name.
        return value.__name__
      raise TypeError('Not JSON Serializable:', value)

    return json.dumps(
        self._updated_config(), default=fallback_encoder, **kwargs)

  def to_yaml(self, **kwargs):
    """Serializes the network configuration to a YAML string.

    To load a network from a yaml save file, use
    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

    `custom_objects` should be a dictionary mapping
    the names of custom losses / layers / etc to the corresponding
    functions / classes.

    Arguments:
        **kwargs: Additional keyword arguments
            to be passed to `yaml.dump()`.

    Returns:
        A YAML string.

    Raises:
        ImportError: if yaml module is not found.
    """
    if yaml is None:
      raise ImportError('Requires yaml module installed.')
    model_config = self._updated_config()
    return yaml.dump(model_config, **kwargs)

  def summary(self, line_length=None, positions=None):
    """Prints a summary of this model via `print_layer_summary`.

    Arguments:
        line_length: Forwarded unchanged to `print_layer_summary`.
        positions: Forwarded unchanged to `print_layer_summary`.
    """
    print_layer_summary(
        self,
        line_length=line_length,
        positions=positions)


def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  The graph is walked backwards through the inbound nodes until nodes
  without inbound layers are reached. Each source tensor appears once in
  the result; duplicates are detected by object identity.

  Arguments:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors (potentially with 1 element). As an exception,
      a `tensor` without a `_keras_history` attribute is returned as is,
      not wrapped in a list.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer.inbound_nodes:
    return [tensor]

  node = layer.inbound_nodes[node_index]
  if not node.inbound_layers:
    # Reached an Input layer, stop recursion.
    return node.input_tensors

  source_tensors = []
  for i in range(len(node.inbound_layers)):
    previous_sources = get_source_inputs(node.input_tensors[i],
                                         node.inbound_layers[i],
                                         node.node_indices[i])
    # A tensor without history comes back bare; treat it as one source
    # instead of trying to iterate over the tensor itself.
    if not isinstance(previous_sources, list):
      previous_sources = [previous_sources]
    # Avoid input redundancy. Compare by identity: `in` would fall back
    # to `==`, which tensor-like types may overload element-wise.
    for candidate in previous_sources:
      if all(candidate is not known for known in source_tensors):
        source_tensors.append(candidate)
  return source_tensors


def _to_list(x):
  """Normalizes a list/tensor into a list.

  If a tensor is passed, we return
  a list of size 1 containing the tensor.

  Arguments:
      x: target object to be normalized.

  Returns:
      A list.
  """
  if isinstance(x, list):
    return x
  return [x]


def _object_list_uid(object_list):
  """Builds a string identifying the objects by their `id()`.

  Arguments:
      object_list: a single object or a list of objects.

  Returns:
      The absolute values of the objects' ids, joined with ', '.
  """
  uids = [str(abs(id(obj))) for obj in _to_list(object_list)]
  return ', '.join(uids)


def _is_all_none(iterable_or_element):
  if not isinstance(iterable_or_element, (list, tuple)):
    iterable = [iterable_or_element]
  else:
    iterable = iterable_or_element
  for element in iterable:
    if element is not None:
      return False
  return True


def _collect_previous_mask(input_tensors):
  """Looks up the mask(s) produced together with the given tensor(s).

  For a tensor carrying `_keras_history`, the mask is read from the
  `output_masks` of the node that produced it; other tensors get `None`.

  Arguments:
      input_tensors: A tensor or list of tensors.

  Returns:
      A mask tensor or list of mask tensors.
  """

  def mask_for(tensor):
    if not hasattr(tensor, '_keras_history'):
      return None
    inbound_layer, node_index, tensor_index = tensor._keras_history
    producer = inbound_layer.inbound_nodes[node_index]
    return producer.output_masks[tensor_index]

  masks = [mask_for(tensor) for tensor in _to_list(input_tensors)]
  # A single mask is returned bare rather than in a list.
  return masks[0] if len(masks) == 1 else masks


def _to_snake_case(name):
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  if insecure[0] != '_':
    return insecure
  return 'private' + insecure


def _collect_input_shape(input_tensors):
  """Gathers the shape(s) of the given tensor(s) via `K.int_shape`.

  Arguments:
      input_tensors: list of input tensors (or single input tensor).

  Returns:
      List of shape tuples (or single tuple), one tuple per input.
  """
  shapes = [K.int_shape(tensor) for tensor in _to_list(input_tensors)]
  # A single shape is returned bare rather than in a list.
  return shapes[0] if len(shapes) == 1 else shapes


def save_weights_to_hdf5_group(f, layers):
  """Writes the weights of `layers` into the HDF5 group `f`.

  The group receives the attributes `layer_names`, `backend` and
  `keras_version`, plus one sub-group per layer holding a `weight_names`
  attribute and one dataset per weight.

  Arguments:
      f: HDF5 group (or file) to write to.
      layers: list of layers whose weights are saved.
  """
  from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
  f.attrs['backend'] = K.backend().encode('utf8')
  f.attrs['keras_version'] = str(keras_version).encode('utf8')

  for layer in layers:
    group = f.create_group(layer.name)
    symbolic_weights = layer.weights
    weight_values = K.batch_get_value(symbolic_weights)

    # Weights without a usable name get a positional `param_<i>` name.
    weight_names = []
    for index, (weight, _) in enumerate(zip(symbolic_weights, weight_values)):
      if hasattr(weight, 'name') and weight.name:
        raw_name = str(weight.name)
      else:
        raw_name = 'param_' + str(index)
      weight_names.append(raw_name.encode('utf8'))
    group.attrs['weight_names'] = weight_names

    for encoded_name, value in zip(weight_names, weight_values):
      dataset = group.create_dataset(
          encoded_name, value.shape, dtype=value.dtype)
      if value.shape:
        dataset[:] = value
      else:
        # scalar
        dataset[()] = value


def preprocess_weights_for_loading(layer,
                                   weights,
                                   original_keras_version=None,
                                   original_backend=None):
  """Converts layers weights from Keras 1 format to Keras 2.

  Conversions are selected by the class name of `layer`. The list passed
  as `weights` is not modified; a new list is returned.

  Arguments:
      layer: Layer instance.
      weights: List of weights values (Numpy arrays).
      original_keras_version: Keras version for the weights, as a string.
          The Keras 1 conversions only run when this is exactly `'1'`.
      original_backend: Keras backend the weights were trained with,
          as a string. Convolution kernels are converted when it is set
          and differs from the current backend.

  Returns:
      A list of weights values (Numpy arrays).
  """
  # Work on a shallow copy so the caller's list is never mutated; callers
  # may reuse the same saved values for several layers.
  weights = list(weights)
  layer_class = layer.__class__.__name__

  if original_keras_version == '1':
    if layer_class == 'Conv1D':
      shape = weights[0].shape
      # Handle Keras 1.1 format
      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
        # Legacy shape:
        # (filters, input_dim, filter_length, 1)
        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
                                                           1)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
      # Drop the axis of size 1 to obtain a 3D kernel.
      weights[0] = weights[0][:, 0, :, :]

    if layer_class == 'Conv2D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

    if layer_class == 'Conv2DTranspose':
      if layer.data_format == 'channels_last':
        # old: (kernel_rows, kernel_cols, stack_size, filters)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

    if layer_class == 'Conv3D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, ...)
        # new: (..., stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

    if layer_class == 'GRU':
      if len(weights) == 9:
        # Nine separate arrays (kernel, recurrent kernel, bias per gate)
        # are merged into three by concatenating along the last axis.
        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[4], weights[7]], axis=-1)
        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer_class == 'LSTM':
      if len(weights) == 12:
        # old: i, c, f, o
        # new: i, f, c, o
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer_class == 'ConvLSTM2D':
      if len(weights) == 12:
        # Same gate reordering as for LSTM above.
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        if layer.data_format == 'channels_first':
          # old: (filters, stack_size, kernel_rows, kernel_cols)
          # new: (kernel_rows, kernel_cols, stack_size, filters)
          kernel = np.transpose(kernel, (2, 3, 1, 0))
          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
        weights = [kernel, recurrent_kernel, bias]

  # Weights saved under a different backend need their kernels converted.
  if original_backend and K.backend() != original_backend:
    conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose']
    if layer_class in conv_layers:
      weights[0] = conv_utils.convert_kernel(weights[0])
    if layer_class == 'ConvLSTM2D':
      weights[0] = conv_utils.convert_kernel(weights[0])
      weights[1] = conv_utils.convert_kernel(weights[1])
  return weights


def load_weights_from_hdf5_group(f, layers):
  """Loads weights by position (topological order) from a HDF5 group.

  Layers without weights, in the model as well as in the file, are
  ignored; the remaining ones are paired up in order.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  # Files without version information are treated as Keras 1 files.
  original_keras_version = (
      f.attrs['keras_version'].decode('utf8')
      if 'keras_version' in f.attrs else '1')
  original_backend = (
      f.attrs['backend'].decode('utf8') if 'backend' in f.attrs else None)

  filtered_layers = [layer for layer in layers if layer.weights]

  saved_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
  layer_names = []
  for name in saved_layer_names:
    saved_weight_names = [
        n.decode('utf8') for n in f[name].attrs['weight_names']
    ]
    if saved_weight_names:
      layer_names.append(name)

  if len(layer_names) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(layer_names)) +
                     ' layers into a model with ' + str(len(filtered_layers)) +
                     ' layers.')

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, (name, layer) in enumerate(zip(layer_names, filtered_layers)):
    group = f[name]
    weight_names = [n.decode('utf8') for n in group.attrs['weight_names']]
    weight_values = [group[weight_name] for weight_name in weight_names]
    symbolic_weights = layer.weights
    weight_values = preprocess_weights_for_loading(
        layer, weight_values, original_keras_version, original_backend)
    if len(weight_values) != len(symbolic_weights):
      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                       '" in the current model) was found to '
                       'correspond to layer ' + name + ' in the save file. '
                       'However the new layer ' + layer.name + ' expects ' +
                       str(len(symbolic_weights)) +
                       ' weights, but the saved weights have ' +
                       str(len(weight_values)) + ' elements.')
    weight_value_tuples.extend(zip(symbolic_weights, weight_values))
  K.batch_set_value(weight_value_tuples)


def load_weights_from_hdf5_group_by_name(f, layers):
  """Implements name-based weight loading.

  (instead of topological weight loading).

  Layers that have no matching name are skipped. When several target
  layers share a name, each of them receives the saved weights of that
  name, converted independently for that layer.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  # Files without version information are treated as Keras 1 files.
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version'].decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend'].decode('utf8')
  else:
    original_backend = None

  # New file format.
  layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

  # Reverse index of layer name to list of layers with name.
  index = {}
  for layer in layers:
    if layer.name:
      index.setdefault(layer.name, []).append(layer)

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
    saved_values = [g[weight_name] for weight_name in weight_names]

    for layer in index.get(name, []):
      symbolic_weights = layer.weights
      # Convert a fresh copy of the saved values for every layer, so a
      # second layer with the same name does not receive weights that
      # were already converted (e.g. transposed) for the first one.
      layer_values = preprocess_weights_for_loading(
          layer, list(saved_values), original_keras_version, original_backend)
      if len(layer_values) != len(symbolic_weights):
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                         '") expects ' + str(len(symbolic_weights)) +
                         ' weight(s), but the saved weights' + ' have ' +
                         str(len(layer_values)) + ' element(s).')
      # Set values.
      weight_value_tuples.extend(zip(symbolic_weights, layer_values))
  K.batch_set_value(weight_value_tuples)