aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/base.py
blob: db608aa79affa36db8d2f52ec2c4663bcf448832 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# pylint: disable=unused-import,g-bad-import-order
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import re
import weakref

import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


class Layer(object):
  """Base layer class.

  This is the class from which all layers inherit, implementing common
  infrastructure functionality.

  A layer is a class implementing common neural networks operations, such
  as convolution, batch norm, etc. These operations require managing variables,
  losses, and updates, as well as applying TensorFlow ops to input tensors.

  Users will just instantiate it and then treat it as a callable.

  We recommend that descendants of Layer implement the following methods:
  * `__init__()`: Save configuration in member variables
  * `build()`: Called once from `__call__`, when we know the shapes of inputs
    and `dtype`. Should have the calls to `add_variable()`, and then
    call the super's `build()` (which sets `self.built = True`, which is
    nice in case the user wants to call `build()` manually before the
    first `__call__`).
  * `call()`: Called in `__call__` after making sure `build()` has been called
    once. Should actually perform the logic of applying the layer to the
    input tensors (which should be passed in as the first argument).

  Read-only properties:
    `name`: The name of the layer (string).
    `dtype`: Default dtype of the layer (default of `None` means use the
      type of the first input).
    `trainable_variables`: List of trainable variables.
    `non_trainable_variables`: List of non-trainable variables.
    `variables`: List of all variables of this layer, trainable and
      non-trainable.
    `updates`: List of update ops of this layer.
    `losses`: List of losses added by this layer.

  Mutable properties:
    `trainable`: Whether the layer should be trained (boolean).
    `input_spec`: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               activity_regularizer=None, **kwargs):
    # We use a kwargs dict here because these kwargs only exist
    # for compatibility reasons.
    # The list of kwargs is subject to changes in the future.
    # We do not want to commit to it or to expose the list to users at all.
    # Note this is exactly as safe as defining kwargs in the function signature,
    # the only difference being that the list of valid kwargs is defined
    # below rather rather in the signature, and default values are defined
    # in calls to kwargs.get().
    allowed_kwargs = {
        '_scope',
        '_reuse',
        'input_shape',  # For compatibility with Keras `Sequential` model.
        'batch_size',  # For compatibility with Keras `Sequential` model.
    }
    for kwarg in kwargs:
      if kwarg not in allowed_kwargs:
        raise TypeError('Keyword argument not understood:', kwarg)

    # Mutable properties
    self.trainable = trainable
    self.built = False
    self.input_spec = None

    self._activity_regularizer = activity_regularizer
    self._trainable_weights = []
    self._non_trainable_weights = []
    self._updates = []
    self._losses = []
    self._reuse = kwargs.get('_reuse')
    self._graph = ops.get_default_graph()
    self._per_input_losses = {}
    self._per_input_updates = {}
    self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
    call_fn_args = estimator_util.fn_args(self.call)
    self._compute_previous_mask = ('mask' in call_fn_args or
                                   hasattr(self, 'compute_mask'))
    self._call_has_scope_arg = 'scope' in call_fn_args

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    self._inbound_nodes = []
    self._outbound_nodes = []

    self._init_set_name(name)

    # Determine variable scope.
    scope = kwargs.get('_scope')
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None

    # Set `_batch_input_shape` attribute
    # for compatibility with Keras `Sequential` model.
    if 'input_shape' in kwargs:
      batch_size = kwargs.get('batch_size')
      self._batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])

  def _init_set_name(self, name):
    """Sets `self._name` and `self._base_name` from the `name` argument.

    Arguments:
      name: A string, a `VariableScope`, or an empty value (e.g. `None`).
        When it is empty, both names come from `_make_unique_name()`.
    """
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      # Also record the scope's string name as the layer name so that
      # `self._name` is defined whichever kind of `name` was passed.
      self._name = base_name
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  @property
  def dtype(self):
    return self._dtype

  @property
  def name(self):
    return self._name

  @property
  def activity_regularizer(self):
    """Optional regularizer function for the output of this layer."""
    return self._activity_regularizer

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name ' +
                       ' is determined the first time the layer instance is ' +
                       'called. You must therefore call the layer before ' +
                       'querying `scope_name`.')
    return self._scope.name

  @property
  def trainable_weights(self):
    return self._trainable_weights if self.trainable else []

  @property
  def non_trainable_weights(self):
    if self.trainable:
      return self._non_trainable_weights
    else:
      return self._trainable_weights + self._non_trainable_weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  @property
  def weights(self):
    """Returns the list of all layer variables/weights.

    Returns:
      A list of variables.
    """
    return self.trainable_weights + self.non_trainable_weights

  @property
  def variables(self):
    """Returns the list of all layer variables/weights.

    Returns:
      A list of variables.
    """
    return self.weights

  @property
  def updates(self):
    """List of update ops recorded on this layer.

    Raises:
      RuntimeError: If accessed in Eager mode.
    """
    if not context.in_eager_mode():
      return self._updates
    raise RuntimeError('Layer.updates not supported in Eager mode.')

  def add_update(self, updates, inputs=None):
    """Records update op(s), optionally tied to specific layer inputs.

    Weight updates (for instance, the updates of the moving mean and variance
    in a BatchNormalization layer) may depend on the inputs the layer was
    called with. When one layer is reused on different inputs `a` and `b`,
    some entries of `layer.updates` may therefore depend on `a` and others on
    `b`. This method keeps track of which is which, and `get_updates_for`
    retrieves the updates belonging to a given set of inputs.

    This call is ignored in Eager mode.

    Arguments:
      updates: Update op, or list/tuple of update ops.
      inputs: Optional input tensor(s) that the update(s) depend on. Must
        match the `inputs` argument passed to the `__call__` method at the time
        the updates are created. If `None` is passed, the updates are assumed
        to be unconditional, and will apply across all dataflows of the layer.
    """
    if context.in_eager_mode():
      # Updates already applied when in eager mode.
      return
    update_ops = _to_list(updates)
    if not update_ops:
      return
    self._updates += update_ops

    flat_inputs = nest.flatten(inputs) if inputs is not None else None
    # An empty input structure is handled like `None` (unconditional).
    # Otherwise the key is an order-sensitive ID that uniquely identifies the
    # list of tensors.
    key = _object_list_uid(flat_inputs) if flat_inputs else None
    self._per_input_updates.setdefault(key, []).extend(update_ops)

  def get_updates_for(self, inputs):
    """Looks up the updates recorded for a specific set of inputs.

    Arguments:
      inputs: Input tensor or list/tuple of input tensors.
        Must match the `inputs` argument passed to the `__call__` method
        at the time the updates were created.
        If you pass `inputs=None`, unconditional updates are returned.

    Returns:
      List of update ops of the layer that depend on `inputs`; an empty list
      when nothing was recorded for them.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
    flat_inputs = None if inputs is None else nest.flatten(inputs)
    # An empty input structure is looked up under the same key as `None`.
    key = _object_list_uid(flat_inputs) if flat_inputs else None
    return self._per_input_updates.get(key, [])

  @property
  def losses(self):
    """List of loss tensors recorded on this layer.

    Raises:
      RuntimeError: If accessed in Eager mode.
    """
    if not context.in_eager_mode():
      return self._losses
    raise RuntimeError('Layer.losses not supported in Eager mode.')

  def add_loss(self, losses, inputs=None):
    """Records loss tensor(s), optionally tied to specific layer inputs.

    Some losses (for instance, activity regularization losses) may depend on
    the inputs the layer was called with. When one layer is reused on
    different inputs `a` and `b`, some entries of `layer.losses` may therefore
    depend on `a` and others on `b`. This method keeps track of which is
    which, and `get_losses_for` retrieves the losses belonging to a given set
    of inputs.

    The losses are also added to the `REGULARIZATION_LOSSES` collection.

    Arguments:
      losses: Loss tensor, or list/tuple of tensors.
      inputs: Optional input tensor(s) that the loss(es) depend on. Must
        match the `inputs` argument passed to the `__call__` method at the time
        the losses are created. If `None` is passed, the losses are assumed
        to be unconditional, and will apply across all dataflows of the layer
        (e.g. weight regularization losses).

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.add_loss not supported in Eager mode.')
    loss_tensors = _to_list(losses)
    if not loss_tensors:
      return
    self._losses += loss_tensors

    flat_inputs = nest.flatten(inputs) if inputs is not None else None
    # An empty input structure is handled like `None` (unconditional).
    # Otherwise the key is an order-sensitive ID that uniquely identifies the
    # list of tensors.
    key = _object_list_uid(flat_inputs) if flat_inputs else None
    self._per_input_losses.setdefault(key, []).extend(loss_tensors)
    _add_elements_to_collection(
        loss_tensors, ops.GraphKeys.REGULARIZATION_LOSSES)

  def get_losses_for(self, inputs):
    """Looks up the losses recorded for a specific set of inputs.

    Arguments:
      inputs: Input tensor or list/tuple of input tensors.
        Must match the `inputs` argument passed to the `__call__`
        method at the time the losses were created.
        If you pass `inputs=None`, unconditional losses are returned,
        such as weight regularization losses.

    Returns:
      List of loss tensors of the layer that depend on `inputs`; an empty list
      when nothing was recorded for them.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
    flat_inputs = None if inputs is None else nest.flatten(inputs)
    # An empty input structure is looked up under the same key as `None`.
    key = _object_list_uid(flat_inputs) if flat_inputs else None
    return self._per_input_losses.get(key, [])

  def build(self, _):
    """Creates the variables of the layer."""
    self.built = True

  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """The logic of the layer lives here.

    Arguments:
      inputs: input tensor(s).
      **kwargs: additional keyword arguments.

    Returns:
      Output tensor(s).
    """
    return inputs

  def _compute_output_shape(self, input_shape):
    """Computes the output shape of the layer given the input shape.

    Assumes that the layer will be built to match that input shape.
    If this method is not implemented by child classes, the default
    assumption will be that the layer does not alter the shape of the tensors
    passing through it.

    Args:
      input_shape: A (possibly nested tuple of) `TensorShape`.  It need not
        be fully defined (e.g. the batch size may be unknown).

    Returns:
      A (possibly nested tuple of) `TensorShape`.

    Raises:
      TypeError: if `input_shape` is not a (possibly nested tuple of)
        `TensorShape`.
      ValueError: if `input_shape` is incomplete or is incompatible with the
        the layer.
    """
    return input_shape

  def _make_unique_name(self, name_uid_map=None, avoid_names=None):
    """Derives a layer name from the class name.

    Arguments:
      name_uid_map: Forwarded to `_unique_layer_name`.
      avoid_names: Forwarded to `_unique_layer_name`.

    Returns:
      A `(name, base_name)` tuple, where `base_name` is the snake-cased class
      name and `name` is what `_unique_layer_name` returns for it.
    """
    snake_cased = _to_snake_case(self.__class__.__name__)
    unique_name = _unique_layer_name(
        snake_cased, name_uid_map=name_uid_map, avoid_names=avoid_names)
    return (unique_name, snake_cased)

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope

  def add_variable(self, name, shape, dtype=None,
                   initializer=None, regularizer=None,
                   trainable=True, constraint=None,
                   partitioner=None):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    The variable is obtained with `get_variable` inside the layer's variable
    scope. A variable that was not already among the global variables is
    appended to the layer's trainable or non-trainable weight list and, when
    a `regularizer` is given, its regularization loss is recorded with
    `add_loss`.

    Arguments:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable.
      constraint: constraint instance (callable).
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.fixed_size_partitioner` and
        `tf.variable_axis_size_partitioner`.  For more details, see the
        documentation of `tf.get_variable` and the  "Variable Partitioners
        and Sharding" section of the API guide.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called in Eager mode with regularizers.
    """
    # Note that we currently don't support variable regularization in Eager
    # mode. An alternative is for users to directly compute these losses before
    # performing a backward pass.
    if context.in_graph_mode():
      # Snapshot the global variables that exist before `get_variable` runs,
      # so that a variable that was already there can be recognised below.
      existing_variables = set(tf_variables.global_variables())
    else:
      # No snapshot is taken outside graph mode.
      existing_variables = []
      if regularizer is not None:
        raise RuntimeError('Variable regularization not supported in Eager '
                           'mode.')
    if dtype is None:
      dtype = self.dtype or dtypes.float32

    # Make sure the layer has a variable scope before opening it below.
    self._set_scope(None)
    # Reuse is requested once the layer has been built, or when `_reuse` was
    # given at construction.
    with vs.variable_scope(
        self._scope, reuse=(self.built or self._reuse)) as scope:
      with ops.name_scope(scope.original_name_scope):
        variable = vs.get_variable(name,
                                   shape=shape,
                                   initializer=initializer,
                                   dtype=dtypes.as_dtype(dtype),
                                   constraint=constraint,
                                   trainable=trainable and self.trainable,
                                   partitioner=partitioner)
        if (context.in_graph_mode() and trainable and self.trainable
            and variable not in tf_variables.trainable_variables()):
          # A custom getter / variable scope overrode the trainable flag.
          trainable = False
        if variable in existing_variables:
          # The variable was already present in the snapshot: return it
          # without adding a regularization loss or appending it to the
          # layer's weight lists.
          return variable
        if regularizer:
          # To match the behavior of tf.get_variable(), we only
          # apply regularization if the variable is newly created.
          if isinstance(variable, tf_variables.PartitionedVariable):
            # Regularize each partition separately.
            for v in variable:
              with ops.colocate_with(v.op):
                with ops.name_scope(name + '/Regularizer'):
                  regularization = regularizer(v)
              if regularization is not None:
                self.add_loss(regularization)
          else:
            with ops.colocate_with(variable.op):
              with ops.name_scope(name + '/Regularizer'):
                regularization = regularizer(variable)
            if regularization is not None:
              self.add_loss(regularization)
    # Track the variable in the matching weight list; `trainable` may have
    # been reset to False above.
    if trainable:
      self._trainable_weights.append(variable)
    else:
      self._non_trainable_weights.append(variable)
    return variable

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value);
        if, in graph mode, the inputs do not belong to the layer's graph; or
        if an unbuilt layer is called outside of graph mode while it has an
        activity regularizer or (when not in deferred mode) while one of its
        inputs carries `_keras_history`.
    """
    self._set_scope(kwargs.pop('scope', None))
    input_list = nest.flatten(inputs)

    in_graph_mode = context.in_graph_mode()
    in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    if in_graph_mode:
      try:
        ops._get_graph_from_inputs(input_list, graph=self.graph)  # pylint: disable=protected-access
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)
    if in_graph_mode or in_deferred_mode:
      # Snapshot the kwargs as passed by the caller, before `mask` / `scope`
      # defaults get injected below; this is what the inbound node records.
      user_kwargs = copy.copy(kwargs)

    # Handle Keras mask propagation from previous layer to current layer.
    previous_mask = None
    if (not hasattr(self, '_compute_previous_mask') or
        self._compute_previous_mask):
      previous_mask = _collect_previous_mask(inputs)
      if ('mask' in estimator_util.fn_args(self.call) and
          'mask' not in kwargs and
          not _is_all_none(previous_mask)):
        # The previous layer generated a mask, and mask was not explicitly
        # passed to __call__, hence we set previous_mask as the default value.
        kwargs['mask'] = previous_mask

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse)

    # Only computed eagerly when the layer has to be built; the deferred-mode
    # branch below fills it in lazily for layers that are already built.
    input_shapes = None
    with scope_context_manager as scope:
      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          if not in_graph_mode:
            # Activity regularization is currently unsupported in Eager mode.
            if self._activity_regularizer:
              raise ValueError('activity_regularizer currently unsupported in '
                               'Eager mode. Found an activity_regularizer in '
                               '%s(%s).' % (self.__class__.__name__, self))
          if not in_graph_mode and not in_deferred_mode:
            # TODO(agarwal): support _keras_history in Eager mode.
            for x in input_list:
              if hasattr(x, '_keras_history'):
                raise ValueError('_keras_history currently unsupported in '
                                 'Eager mode. Found _keras_history in %s while '
                                 'executing __call__ for %s(%s)' %
                                 (x, self.__class__.__name__, self))

          # Check input assumptions set before layer building, e.g. input rank.
          self._assert_input_compatibility(inputs)
          if input_list and self._dtype is None:
            try:
              self._dtype = input_list[0].dtype.name
            except AttributeError:
              pass
          input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
          self.build(input_shapes)
        try:
          # Note: not all sub-classes of Layer call Layer.__init__ (especially
          # the ones under tensorflow/python/keras). Hence we recompute this
          # attribute here if it is not set.
          # TODO(agarwal): Fix the sub-classes and avoid this complexity.
          call_has_scope_arg = self._call_has_scope_arg
        except AttributeError:
          call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
        if call_has_scope_arg:
          kwargs['scope'] = scope
        # Check input assumptions set after layer building, e.g. input shape.
        if in_graph_mode or in_deferred_mode:
          self._assert_input_compatibility(inputs)

        if not in_deferred_mode:
          outputs = self.call(inputs, *args, **kwargs)
          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a Tensor '
                             'or a list of Tensors, not None.')
        else:
          # Deferred mode behavior: use `_compute_output_shape` to
          # infer the number of outputs of the layer and their shapes.
          if input_shapes is None:
            # The layer was already built, so the shapes were not collected
            # in the build step above.
            input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
          output_shapes = self._compute_output_shape(input_shapes)
          output_shapes = nest.flatten(output_shapes)
          outputs = [
              # TODO(fchollet): name the deferred tensors?
              _DeferredTensor(shape=shape, dtype=self._dtype)
              for shape in output_shapes
          ]
          if len(outputs) == 1:
            outputs = outputs[0]

        if in_graph_mode:
          # Apply activity regularization.
          # Note that it should be applied every time the layer creates a new
          # output, since it is output-specific.
          if self._activity_regularizer:
            output_list = nest.flatten(outputs)
            for output in output_list:
              with ops.name_scope('ActivityRegularizer'):
                activity_regularization = self._activity_regularizer(output)
              self.add_loss(activity_regularization)

        if not in_deferred_mode:
          # TODO(fchollet): consider how masking will work with deferred mode.
          # Handle mask computation and propagation to the next layer.
          if hasattr(self, 'compute_mask'):
            output_mask = self.compute_mask(inputs, previous_mask)
            if isinstance(outputs, list):
              if output_mask is None:
                output_mask = [None for _ in range(len(outputs))]
              for x, m in zip(outputs, output_mask):
                x._keras_mask = m  # pylint: disable=protected-access
            else:
              outputs._keras_mask = output_mask  # pylint: disable=protected-access

    if in_graph_mode:
      # If all input tensors have history metadata,
      # we update the output tensors
      # with corresponding history metadata, thus eventually allowing to use
      # these tensors to instantiate a Network.
      if _have_all_keras_metadata(inputs):
        # If the layer returns tensors from its inputs, unmodified,
        # we copy them to avoid loss of tensor metadata.
        output_ls = nest.flatten(outputs)
        output_ls_copy = []
        for x in output_ls:
          if x in input_list:
            with ops.name_scope(scope.original_name_scope):
              x = array_ops.identity(x)
          output_ls_copy.append(x)
        if len(output_ls_copy) == 1:
          outputs = output_ls_copy[0]
        else:
          outputs = output_ls_copy

      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)

    if in_deferred_mode or in_graph_mode:
      if _have_all_keras_metadata(inputs):
        # Add an inbound node to the layer, so it can keep track of this call.
        # This updates the layer history of the output tensor(s).
        self._add_inbound_node(
            input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)

    self.built = True
    return outputs

  @property
  def graph(self):
    """Returns the layer's `_graph` attribute.

    Raises:
      RuntimeError: If accessed in Eager mode.
    """
    eager = context.in_eager_mode()
    if eager:
      raise RuntimeError('Layer.graph not supported in Eager mode.')
    return self._graph

  def __deepcopy__(self, memo):
    """Creates a copy of the layer for `copy.deepcopy`.

    Attributes are copied according to the following policy:
      - `_graph` is shared by reference;
      - `_scope` and `_always_reuse_variable_scope` are shallow-copied;
      - values for which `_is_tensor_or_tensor_list` is true are shared by
        reference;
      - everything else is deep-copied.

    Arguments:
      memo: the memo dictionary supplied by `copy.deepcopy`.

    Returns:
      A new instance of the same class as `self`.
    """
    by_reference = frozenset(['_graph'])
    by_shallow_copy = frozenset(['_scope', '_always_reuse_variable_scope'])

    klass = self.__class__
    clone = klass.__new__(klass)
    # Register the clone before copying attributes so that reference cycles
    # pointing back at this layer resolve to the clone.
    memo[id(self)] = clone

    for attr_name, attr_value in self.__dict__.items():
      if attr_name in by_shallow_copy:
        new_value = copy.copy(attr_value)
      elif attr_name in by_reference or _is_tensor_or_tensor_list(attr_value):
        new_value = attr_value
      else:
        new_value = copy.deepcopy(attr_value, memo)
      setattr(clone, attr_name, new_value)
    return clone

  def apply(self, inputs, *args, **kwargs):
    """Applies the layer to an input by delegating to `self.__call__`.

    Arguments:
      inputs: Input tensor(s).
      *args: additional positional arguments forwarded to `self.__call__`.
      **kwargs: additional keyword arguments forwarded to `self.__call__`.

    Returns:
      Whatever `self.__call__` returns, i.e. the output tensor(s).
    """
    outputs = self.__call__(inputs, *args, **kwargs)
    return outputs

  def _add_inbound_node(self,
                        input_tensors,
                        output_tensors,
                        arguments=None):
    """Internal method to create an inbound node for the layer.

    A `Node` is created from the (flattened) inputs and outputs, then each
    output tensor is tagged with `_keras_history` metadata.

    Arguments:
        input_tensors: input tensor(s); every one of them must carry a
            `_keras_history` attribute.
        output_tensors: output tensor(s).
        arguments: dictionary of keyword arguments that were passed to the
            `call` method of the layer at the call that created the node.
    """
    flat_inputs = nest.flatten(input_tensors)
    flat_outputs = nest.flatten(output_tensors)

    # Gather the (layer, node index, tensor index) coordinates of each input.
    origins = []
    for tensor in flat_inputs:
      assert hasattr(tensor, '_keras_history')
      origin_layer, origin_node, origin_tensor = tensor._keras_history  # pylint: disable=protected-access
      origins.append((origin_layer, origin_node, origin_tensor))

    # Instantiating the node registers it on the layers involved.
    Node(
        self,
        inbound_layers=[origin[0] for origin in origins],
        node_indices=[origin[1] for origin in origins],
        tensor_indices=[origin[2] for origin in origins],
        input_tensors=flat_inputs,
        output_tensors=flat_outputs,
        arguments=arguments)

    # Tag every output with 1) this layer instance, 2) the index of the node
    # that was just added to the layer, 3) the tensor's index within the node.
    # This allows layer reuse (multiple nodes per layer) and multi-output
    # or multi-input layers (e.g. a layer can return multiple tensors,
    # and each can be sent to a different layer).
    latest_node_index = len(self._inbound_nodes) - 1
    for position, tensor in enumerate(flat_outputs):
      tensor._keras_history = (self, latest_node_index, position)  # pylint: disable=protected-access

  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
    """Private utility to retrieve an attribute (e.g. inputs) from a node.

    This is used to implement the methods:
        - get_input_shape_at
        - get_output_shape_at
        - get_input_at
        etc...

    Arguments:
        node_index: Integer index of the node from which
            to retrieve the attribute.
        attr: Exact node attribute name.
        attr_name: Human-readable attribute name, for error messages.

    Returns:
        The layer's attribute `attr` at the node of index `node_index`.
        When that attribute holds exactly one element, the element itself
        is returned instead of the one-element container.

    Raises:
        RuntimeError: If the layer has no inbound nodes, or if called in Eager
        mode.
        ValueError: If the index provided does not match any node.
    """
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O` and would surface as AssertionError instead of the
    # documented RuntimeError.
    if not context.in_graph_mode():
      raise RuntimeError('Retrieving the ' + attr_name + ' of a layer node is '
                         'not supported in Eager mode.')
    if not self._inbound_nodes:
      raise RuntimeError('The layer has never been called '
                         'and thus has no defined ' + attr_name + '.')
    if not len(self._inbound_nodes) > node_index:
      raise ValueError('Asked to get ' + attr_name + ' at node ' +
                       str(node_index) + ', but the layer has only ' +
                       str(len(self._inbound_nodes)) + ' inbound nodes.')
    values = getattr(self._inbound_nodes[node_index], attr)
    if len(values) == 1:
      return values[0]
    return values

  def get_input_shape_at(self, node_index):
    """Retrieves the input shape(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        The `input_shapes` recorded on the node: a single shape,
        or a list of shapes if the node has multiple inputs.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      message = 'Layer.get_input_shape_at not supported in Eager mode.'
      raise RuntimeError(message)
    input_shapes = self._get_node_attribute_at_index(
        node_index, 'input_shapes', 'input shape')
    return input_shapes

  def get_output_shape_at(self, node_index):
    """Retrieves the output shape(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        The `output_shapes` recorded on the node: a single shape,
        or a list of shapes if the node has multiple outputs.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      message = 'Layer.get_output_shape_at not supported in Eager mode.'
      raise RuntimeError(message)
    output_shapes = self._get_node_attribute_at_index(
        node_index, 'output_shapes', 'output shape')
    return output_shapes

  def get_input_at(self, node_index):
    """Retrieves the input tensor(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        The `input_tensors` recorded on the node: a single tensor,
        or a list of tensors if the node has multiple inputs.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      message = 'Layer.get_input_at not supported in Eager mode.'
      raise RuntimeError(message)
    input_tensors = self._get_node_attribute_at_index(
        node_index, 'input_tensors', 'input')
    return input_tensors

  def get_output_at(self, node_index):
    """Retrieves the output tensor(s) of a layer at a given node.

    Arguments:
        node_index: Integer, index of the node
            from which to retrieve the attribute.
            E.g. `node_index=0` will correspond to the
            first time the layer was called.

    Returns:
        The `output_tensors` recorded on the node: a single tensor,
        or a list of tensors if the node has multiple outputs.

    Raises:
      RuntimeError: If called in Eager mode.
    """
    if context.in_eager_mode():
      message = 'Layer.get_output_at not supported in Eager mode.'
      raise RuntimeError(message)
    output_tensors = self._get_node_attribute_at_index(
        node_index, 'output_tensors', 'output')
    return output_tensors

  @property
  def input(self):
    """Retrieves the input tensor(s) of the layer's first inbound node.

    Returns:
        Input tensor or list of input tensors recorded at node index 0.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If the layer has no inbound nodes (i.e. it is not
        connected to anything yet).
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.input not supported in Eager mode.')
    connected = bool(self._inbound_nodes)
    if not connected:
      raise AttributeError('Layer ' + self.name +
                           ' is not connected, no input to return.')
    first_node_inputs = self._get_node_attribute_at_index(
        0, 'input_tensors', 'input')
    return first_node_inputs

  @property
  def output(self):
    """Retrieves the output tensor(s) of the layer's first inbound node.

    Returns:
      Output tensor or list of output tensors recorded at node index 0.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If the layer has no inbound nodes.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.output not supported in Eager mode.')
    connected = bool(self._inbound_nodes)
    if not connected:
      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
    first_node_outputs = self._get_node_attribute_at_index(
        0, 'output_tensors', 'output')
    return first_node_outputs

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer, or if all inputs
    have the same shape.

    Returns:
        Input shape, as an integer shape tuple
        (or list of shape tuples, one tuple per input tensor).

    Raises:
        AttributeError: if the layer has never been called, or if its
            inbound nodes recorded different input shapes.
        RuntimeError: if called in Eager mode.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.input_shape not supported in Eager mode.')
    if not self._inbound_nodes:
      raise AttributeError('The layer has never been called '
                           'and thus has no defined input shape.')
    # Shapes are compared through their string representation so that the
    # per-node shape lists can be collected in a set.
    all_input_shapes = set(
        [str(node.input_shapes) for node in self._inbound_nodes])
    if len(all_input_shapes) != 1:
      # The closing quote after the layer name was previously missing.
      raise AttributeError('The layer "' + str(self.name) +
                           '" has multiple inbound nodes, '
                           'with different input shapes. Hence '
                           'the notion of "input shape" is '
                           'ill-defined for the layer. '
                           'Use `get_input_shape_at(node_index)` '
                           'instead.')
    input_shapes = self._inbound_nodes[0].input_shapes
    if len(input_shapes) == 1:
      return tuple(tensor_shape.TensorShape(input_shapes[0]).as_list())
    return [
        tuple(tensor_shape.TensorShape(shape).as_list())
        for shape in input_shapes
    ]

  def count_params(self):
    """Counts the total number of scalars composing the weights.

    An unbuilt layer whose class is named `Sequential` is built first via
    `self.build()`; any other unbuilt layer is rejected.

    Returns:
        An integer count.

    Raises:
        ValueError: if the layer isn't yet built
          (in which case its weights aren't yet defined).
    """
    if not self.built:
      if self.__class__.__name__ != 'Sequential':
        raise ValueError('You tried to call `count_params` on ' + self.name +
                         ', but the layer isn\'t built. '
                         'You can build it manually via: `' + self.name +
                         '.build(batch_input_shape)`.')
      self.build()  # pylint: disable=no-value-for-parameter
    shapes = [weight.get_shape().as_list() for weight in self.weights]
    sizes = [np.prod(dims) for dims in shapes]
    total = sum(sizes)
    return int(total)

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only well-defined when every inbound node of the layer recorded the same
    output shapes.

    Returns:
        Output shape, as an integer shape tuple
        (or list of shape tuples, one tuple per output tensor).

    Raises:
        AttributeError: if the layer has never been called, or if its
            inbound nodes recorded different output shapes.
        RuntimeError: if called in Eager mode.
    """
    if context.in_eager_mode():
      raise RuntimeError('Layer.output_shape not supported in Eager mode.')
    nodes = self._inbound_nodes
    if not nodes:
      raise AttributeError('The layer has never been called '
                           'and thus has no defined output shape.')
    # Compare the per-node shapes through their string representation.
    distinct_shapes = {str(node.output_shapes) for node in nodes}
    if len(distinct_shapes) != 1:
      raise AttributeError('The layer "%s"'
                           ' has multiple inbound nodes, '
                           'with different output shapes. Hence '
                           'the notion of "output shape" is '
                           'ill-defined for the layer. '
                           'Use `get_output_shape_at(node_index)` '
                           'instead.' % self.name)
    shape_tuples = [
        tuple(tensor_shape.TensorShape(shape).as_list())
        for shape in self._inbound_nodes[0].output_shapes
    ]
    if len(shape_tuples) == 1:
      return shape_tuples[0]
    return shape_tuples

  @property
  def inbound_nodes(self):
    """Deprecated alias of `_inbound_nodes`; do NOT use.

    Kept only for compatibility with external Keras.
    """
    nodes = self._inbound_nodes
    return nodes

  @property
  def outbound_nodes(self):
    """Deprecated alias of `_outbound_nodes`; do NOT use.

    Kept only for compatibility with external Keras.
    """
    nodes = self._outbound_nodes
    return nodes

  def _assert_input_compatibility(self, inputs):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of the layer (if any). If not, a clear and actionable exception gets
    raised. Nothing is checked when `self.input_spec` is unset, and `None`
    entries in the spec list skip the corresponding input.

    Arguments:
        inputs: input tensor or list of input tensors.

    Raises:
        ValueError: in case of mismatch between
            the provided inputs and the expectations of the layer.
    """
    if not self.input_spec:
      return
    if not isinstance(self.input_spec, (list, tuple)):
      input_spec = nest.flatten(self.input_spec)
    else:
      input_spec = self.input_spec
    inputs = nest.flatten(inputs)
    if len(inputs) != len(input_spec):
      raise ValueError('Layer ' + self.name + ' expects ' +
                       str(len(input_spec)) + ' inputs, '
                       'but it received ' + str(len(inputs)) +
                       ' input tensors. Inputs received: ' + str(inputs))
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
      if spec is None:
        continue

      # Any rank constraint requires the input to have a defined rank.
      if (spec.ndim is not None or
          spec.min_ndim is not None or
          spec.max_ndim is not None):
        if x.get_shape().ndims is None:
          raise ValueError('Input ' + str(input_index) + ' of layer ' +
                           self.name + ' is incompatible with the layer: '
                           'its rank is undefined, but the layer requires a '
                           'defined rank.')

      # Check ndim.
      if spec.ndim is not None:
        ndim = x.get_shape().ndims
        if ndim != spec.ndim:
          raise ValueError('Input ' + str(input_index) + ' of layer ' +
                           self.name + ' is incompatible with the layer: '
                           'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                           str(ndim) + '. Full shape received: ' +
                           str(x.get_shape().as_list()))
      if spec.max_ndim is not None:
        ndim = x.get_shape().ndims
        if ndim is not None and ndim > spec.max_ndim:
          raise ValueError('Input ' + str(input_index) + ' of layer ' +
                           self.name + ' is incompatible with the layer: '
                           'expected max_ndim=' + str(spec.max_ndim) +
                           ', found ndim=' + str(ndim))
      if spec.min_ndim is not None:
        ndim = x.get_shape().ndims
        if ndim is not None and ndim < spec.min_ndim:
          # The message previously contained a doubled colon
          # ("... with the layer: : expected min_ndim=...").
          raise ValueError('Input ' + str(input_index) + ' of layer ' +
                           self.name + ' is incompatible with the layer: '
                           'expected min_ndim=' + str(spec.min_ndim) +
                           ', found ndim=' + str(ndim) +
                           '. Full shape received: ' +
                           str(x.get_shape().as_list()))
      # Check dtype.
      if spec.dtype is not None:
        if x.dtype != spec.dtype:
          raise ValueError('Input ' + str(input_index) + ' of layer ' +
                           self.name + ' is incompatible with the layer: '
                           'expected dtype=' + str(spec.dtype) +
                           ', found dtype=' + str(x.dtype))
      # Check specific shape axes.
      if spec.axes:
        shape = x.get_shape().as_list()
        if shape is not None:
          for axis, value in spec.axes.items():
            # Unwrap objects exposing a `.value` attribute to the raw value.
            if hasattr(value, 'value'):
              value = value.value
            # An unknown (None) input dimension is accepted for any value.
            if value is not None and shape[int(axis)] not in {value, None}:
              raise ValueError(
                  'Input ' + str(input_index) + ' of layer ' + self.name + ' is'
                  ' incompatible with the layer: expected axis ' + str(axis) +
                  ' of input shape to have value ' + str(value) +
                  ' but received input with shape ' + str(shape))
      # Check shape.
      if spec.shape is not None:
        shape = x.get_shape().as_list()
        if shape is not None:
          for spec_dim, dim in zip(spec.shape, shape):
            # Only dimensions known on both sides are compared.
            if spec_dim is not None and dim is not None:
              if spec_dim != dim:
                raise ValueError('Input ' + str(input_index) +
                                 ' is incompatible with layer ' + self.name +
                                 ': expected shape=' + str(spec.shape) +
                                 ', found shape=' + str(shape))


class InputSpec(object):
  """Describes the dtype, shape and rank constraints on one layer input.

  Every layer should expose (if appropriate) an `input_spec` attribute:
  a list of instances of InputSpec (one per input tensor).

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Arguments:
      dtype: Expected DataType of the input.
      shape: Shape tuple, expected shape of the input
          (may include None for unchecked axes). When given, `ndim` is
          derived from its length and the `ndim` argument is ignored.
      ndim: Integer, expected rank of the input.
      max_ndim: Integer, maximum rank of the input.
      min_ndim: Integer, minimum rank of the input.
      axes: Dictionary mapping integer axes to
          a specific dimension value. A falsy value is stored as `{}`.
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None):
    self.dtype = dtype
    self.shape = shape
    # An explicit shape takes precedence over the `ndim` argument.
    self.ndim = ndim if shape is None else len(shape)
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.axes = axes if axes else {}


class Node(object):
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Arguments:
      outbound_layer: the layer that takes
          `input_tensors` and turns them into `output_tensors`
          (the node gets created when the `call`
          method of the layer was called).
      inbound_layers: a list of layers, the same length as `input_tensors`,
          the layers from where `input_tensors` originate.
      node_indices: a list of integers, the same length as `inbound_layers`.
          `node_indices[i]` is the origin node of `input_tensors[i]`
          (necessary since each inbound layer might have several nodes,
          e.g. if the layer is being shared with a different data stream).
      tensor_indices: a list of integers,
          the same length as `inbound_layers`.
          `tensor_indices[i]` is the index of `input_tensors[i]` within the
          output of the inbound layer
          (necessary since each inbound layer might
          have multiple tensor outputs, with each one being
          independently manipulable).
      input_tensors: list of input tensors.
      output_tensors: list of output tensors.
      arguments: dictionary of keyword arguments that were passed to the
          `call` method of the layer at the call that created the node.

  `node_indices` and `tensor_indices` are basically fine-grained coordinates
  describing the origin of the `input_tensors`.

  A node from layer A to layer B is added to:
    - A._outbound_nodes
    - B._inbound_nodes

  Raises:
      ValueError: if `outbound_layer` is a list.
  """

  def __init__(self,
               outbound_layer,
               inbound_layers,
               node_indices,
               tensor_indices,
               input_tensors,
               output_tensors,
               arguments=None):
    # `outbound_layer` must be a single layer instance, never a list.
    if isinstance(outbound_layer, list):
      raise ValueError(
          '`outbound_layer` should be a layer instance, not a list.')

    # The layer that turns `input_tensors` into `output_tensors`; this node
    # ends up in its inbound nodes.
    self.outbound_layer = outbound_layer

    # Origin coordinates of the input tensors: which layers they come from
    # and, for each of those, which node and which output tensor of the node.
    # The three lists map 1:1 onto each other.
    self.inbound_layers = inbound_layers
    self.node_indices = node_indices
    self.tensor_indices = tensor_indices

    # Tensors consumed (1:1 with `inbound_layers`) and produced by
    # `outbound_layer`.
    self.input_tensors = input_tensors
    self.output_tensors = output_tensors

    # Static shapes of the tensors above, in the same order.
    self.input_shapes = [_static_shape(tensor) for tensor in input_tensors]
    self.output_shapes = [_static_shape(tensor) for tensor in output_tensors]

    # Optional keyword arguments to layer's `call`.
    self.arguments = arguments

    # Register this node on every layer involved. The deprecated accessors
    # are used on purpose, for compatibility with external Keras.
    for origin_layer in inbound_layers:
      if origin_layer is None:
        continue
      origin_layer.outbound_nodes.append(self)
    outbound_layer.inbound_nodes.append(self)

  def get_config(self):
    """Returns a dict naming the layers involved and the origin indices.

    Inbound layers that are falsy are reported as `None`.
    """
    inbound_names = [
        layer.name if layer else None for layer in self.inbound_layers
    ]
    config = {
        'outbound_layer': self.outbound_layer.name,
        'inbound_layers': inbound_names,
        'node_indices': self.node_indices,
        'tensor_indices': self.tensor_indices
    }
    return config


class _DeferredTensor(object):
  """Tensor-like object used to build graphs of layers in Eager mode.

  When calling a layer on a DeferredTensor, the layer will not perform any
  computation and will simply perform shape inference to return new
  DeferredTensors with appropriate shape information. Thus DeferredTensor
  behaves like a graph-mode Tensor when manipulated by layers.

  Arguments:
      shape: value passed to `tensor_shape.TensorShape` to build `self.shape`.
      dtype: value passed to `dtypes.as_dtype` to build `self.dtype`.
      name: optional name, stored as given.
  """

  def __init__(self, shape, dtype, name=None):
    normalized_shape = tensor_shape.TensorShape(shape)
    normalized_dtype = dtypes.as_dtype(dtype)
    self.shape = normalized_shape
    self.dtype = normalized_dtype
    self.name = name

  def get_shape(self):
    """Returns `self.shape` (mirrors the accessor layers call on tensors)."""
    return self.shape

  def __str__(self):
    fields = (self.name, self.get_shape(), self.dtype.name)
    return "DeferredTensor('%s', shape=%s, dtype=%s)" % fields

  def __repr__(self):
    fields = (self.name, self.get_shape(), self.dtype.name)
    return "<_DeferredTensor '%s' shape=%s dtype=%s>" % fields


class InputLayer(Layer):
  """Layer to be used as an entry point into a Network (a graph of layers).

  It can either wrap an existing tensor (pass an `input_tensor` argument)
  or create a new input tensor (pass arguments `input_shape` as well as
  `dtype`): a `_DeferredTensor` in Eager mode, a (possibly sparse)
  placeholder in graph mode.

  It is generally recommended to use the functional layer API via `Input`
  (which creates an `InputLayer`) without directly using `InputLayer`.

  Arguments:
      input_shape: Shape tuple (not including the batch axis), or `TensorShape`
        instance (not including the batch axis).
      batch_size: Optional input batch size (integer or None).
      dtype: Datatype of the input.
      input_tensor: Optional tensor to use as layer input
          instead of creating a new one.
      sparse: Boolean, whether the placeholder created in graph mode
          is meant to be sparse.
      name: Name of the layer (string).
  """

  def __init__(self,
               input_shape=None,
               batch_size=None,
               dtype=dtypes.float32,
               input_tensor=None,
               sparse=False,
               name=None):
    super(InputLayer, self).__init__(dtype=dtype, name=name)
    self.built = True
    self.sparse = sparse
    self.batch_size = batch_size

    if isinstance(input_shape, tensor_shape.TensorShape):
      input_shape = tuple(input_shape.as_list())

    if input_tensor is not None:
      # Wrapping an existing tensor. The attributes below exist for
      # compatibility with the Keras API.
      self.is_placeholder = False
      self._batch_input_shape = tuple(input_tensor.get_shape().as_list())
    else:
      batch_input_shape = None
      if input_shape is not None:
        batch_input_shape = (batch_size,) + tuple(input_shape)

      # Pick the constructor for the tensor the layer will be called on:
      # a temporary deferred tensor in eager mode, a graph placeholder
      # (sparse or dense) in graph mode.
      if context.in_eager_mode():
        make_input = _DeferredTensor
      elif sparse:
        make_input = array_ops.sparse_placeholder
      else:
        make_input = array_ops.placeholder
      input_tensor = make_input(
          shape=batch_input_shape,
          dtype=dtype,
          name=self.name)

      # For compatibility with Keras API.
      self.is_placeholder = True
      self._batch_input_shape = batch_input_shape

    # Tag the tensor with its history metadata, then create the input node
    # (which has no inbound layers) with the tensor as both input and output.
    input_tensor._keras_history = (self, 0, 0)  # pylint: disable=protected-access
    Node(
        self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=[input_tensor],
        output_tensors=[input_tensor])


def Input(  # pylint: disable=invalid-name
    shape=None,
    batch_size=None,
    name=None,
    dtype=dtypes.float32,
    sparse=False,
    tensor=None):
  """Instantiates an input tensor for use with a `Network`.

  The tensor is produced by a freshly created `InputLayer`, which attaches
  the history metadata a `Network` needs. Given tensors a, b and c created
  via `Input` (or derived from such tensors through layers), one can write:

  `network = Network(inputs=[a, b], outputs=c)`

  Example:

      ```python
      # This is a logistic regression
      x = tf.layers.Input(shape=(32,))
      y = tf.layers.Dense(16, activation='softmax')(x)
      network = tf.layers.Network(x, y)
      ```

  Arguments:
      shape: A shape tuple (integer), not including the batch size.
          For instance, `shape=(32,)` indicates that the expected input
          will be batches of 32-dimensional vectors.
      batch_size: Optional input batch size (integer or None).
      name: An optional name string for the layer.
          Should be unique in a model (do not reuse the same name twice).
      dtype: The data type expected by the input. Defaults to
          `dtypes.float32`.
      sparse: A boolean specifying whether the placeholder
          to be created is sparse.
      tensor: Optional existing tensor to wrap into the `Input` layer.
          If set, the layer will not create a placeholder tensor.

  Returns:
      The output tensor of the input layer's first node: either a newly
      created input tensor or `tensor` (if passed), carrying history
      metadata. If that node holds several output tensors, the whole list is
      returned.
  """
  layer = InputLayer(
      input_shape=shape,
      batch_size=batch_size,
      name=name,
      dtype=dtype,
      sparse=sparse,
      input_tensor=tensor)
  # The tensors stored on the layer's first inbound node carry the
  # `_keras_history` metadata.
  node_outputs = layer._inbound_nodes[0].output_tensors  # pylint: disable=protected-access
  return node_outputs[0] if len(node_outputs) == 1 else node_outputs


class Network(Layer):
  """A Network is a directed acyclic graph of layers.

  It is the topological form of a "model".
  A Model is simply a Network with added training/evaluation routines.

  A Network instance implements the full Layer API. In particular, a network
  can be called on new inputs.

  Example:

      ```python
      # This is a logistic regression
      x = tf.layers.Input(shape=(32,))
      y = tf.layers.Dense(16, activation='softmax')(x)
      network = tf.layers.Network(x, y)

      # It is then possible to call the network on compatible inputs:
      z = tf.layers.Input(shape=(32,))
      w = network(z)

      # It is possible to retrieve the same properties as a layer:
      weights = network.trainable_weights
      ```

  Arguments:
      inputs: Input tensor or list of input tensors.
        Must come from `tf.layers.Input`.
      output: Output tensor or list of output tensors. Must come from
        tf.layers Layers or Keras layers.
      name: Optional name of the model (string).

  Attributes:
    Network has the same attributes as Layer. On top of it, it also has:
      - layers: a list of the children layers of the network,
        a list of layer instances, ordered from "earlier in the graph"
        to "later in the graph".

  Methods:
    Network has the same methods as Layer. On top of it, it also has:
      - get_layer: retrieves a child layer by name or index in the graph.

  Raises:
    RuntimeError: If created in Eager mode.
  """

  def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
    """Builds the network topology connecting `inputs` to `outputs`.

    Walks the layer graph backwards from `outputs`, assigns a depth to every
    node and layer encountered, checks that every tensor needed can be
    computed from `inputs`, checks that layer names are unique, and finally
    creates the node that links the network's own inputs to its outputs.

    Arguments:
        inputs: Input tensor or list/tuple of input tensors. Each must carry
            `_keras_history` metadata.
        outputs: Output tensor or list/tuple of output tensors. Each must
            carry `_keras_history` metadata.
        name: Optional name of the network (string).

    Raises:
        ValueError: If an input is listed more than once, if an input or
            output lacks `_keras_history` metadata, if the layer graph
            contains a cycle, if an output cannot be computed from the
            inputs, or if two layers share the same name.
    """
    if context.in_eager_mode():
      # TODO(fchollet): check that all inputs and outputs are DeferredTensors.
      pass

    self._init_set_name(name)
    self._activity_regularizer = None
    with vs.variable_scope(
        None, default_name=self._base_name) as captured_scope:
      self._scope = captured_scope
    call_fn_args = estimator_util.fn_args(self.call)
    self._compute_previous_mask = ('mask' in call_fn_args or
                                   hasattr(self, 'compute_mask'))
    self._call_has_scope_arg = 'scope' in call_fn_args

    # This acts just like the `trainable` attribute of any layer instance.
    # It does not affect users of the underlying layers, only users of the
    # Network instance.
    self.trainable = True
    # A Network does not create weights of its own, thus it is already built.
    self.built = True
    # A Network does not create weights of its own, thus has no dtype.
    self._dtype = None
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec

    # Private attributes to implement compatibility with Layer.
    self._per_input_losses = {}
    self._per_input_updates = {}
    self._updates = []
    self._losses = []
    # NOTE(review): this overwrites the variable scope captured into
    # `self._scope` a few lines above -- confirm that resetting it to None
    # here is intentional.
    self._scope = None
    self._reuse = None
    self._graph = ops.get_default_graph()

    # Network-specific properties.
    if isinstance(inputs, (list, tuple)):
      self.inputs = list(inputs)  # Tensor or list of tensors.
    else:
      self.inputs = [inputs]
    if isinstance(outputs, (list, tuple)):
      self.outputs = list(outputs)
    else:
      self.outputs = [outputs]
    # All layers in order of horizontal graph traversal.
    # Entries are unique. Includes input and output layers.
    self.layers = []

    # Check for redundancy in inputs.
    if len(set(self.inputs)) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    # # List of initial layers (1 to 1 mapping with self.inputs,
    # # hence the same layer might appear twice)
    # self._input_layers = []
    # self._input_layers_node_indices = []
    # self._input_layers_tensor_indices = []
    # # list of layers (1 to 1 mapping with self.inputs,
    # # hence the same layer might appear twice)
    # self._output_layers = []
    # self._output_layers_node_indices = []
    # self._output_layers_tensor_indices = []

    # Layers producing the network inputs/outputs, together with the
    # (layer, node_index, tensor_index) coordinates of each tensor.
    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization
    # when calling the Network on new inputs.
    # Every time the Network is called on a set of input tensors,
    # we compute the output tensors,
    # output masks and output shapes in one pass,
    # then cache them here. When any of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # User-provided arguments validation.
    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.layers.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # Note that a violation is only logged as a warning, not raised.
      # pylint: disable=protected-access
      layer, node_index, tensor_index = x._keras_history
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' inputs must come from '
                        '`tf.layers.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.layers.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))
      # pylint: enable=protected-access
    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors to a ' + cls_name + ' must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Network_nodes: set of nodes included in the graph
    # (not all nodes included in the layers
    # are relevant to the current graph).
    network_nodes = set()  # ids of all nodes relevant to the Network
    nodes_depths = {}  # dict {node: depth value}
    layers_depths = {}  # dict {layer: depth value}
    layer_indices = {}  # dict {layer: index in traversal}
    nodes_in_decreasing_depth = []

    def build_map_of_graph(tensor,
                           finished_nodes,
                           nodes_in_progress,
                           layer,
                           node_index,
                           tensor_index):
      """Builds a map of the graph of layers.

      This recursively updates the map `layer_indices`,
      the list `nodes_in_decreasing_depth` and the set `network_nodes`.

      Arguments:
          tensor: Some tensor in a graph. Only used in the error message.
          finished_nodes: Set of nodes whose subgraphs have been traversed
              completely. Useful to prevent duplicated work.
          nodes_in_progress: Set of nodes that are currently active on the
              recursion stack. Useful to detect cycles.
          layer: Layer from which `tensor` comes from.
          node_index: Node index from which `tensor` comes from.
          tensor_index: Tensor_index from which `tensor` comes from.

      Raises:
          ValueError: if a cycle is detected.
      """
      node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

      # Prevent cycles.
      if node in nodes_in_progress:
        raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                         layer.name + '" is part of a cycle.')

      # Don't repeat work for shared subgraphs
      if node in finished_nodes:
        return

      node_key = _make_node_key(layer.name, node_index)
      # Update network_nodes.
      network_nodes.add(node_key)

      # Store the traversal order for layer sorting.
      if layer not in layer_indices:
        layer_indices[layer] = len(layer_indices)

      nodes_in_progress.add(node)

      # Propagate to all previous tensors connected to this node.
      # (The loop rebinds `layer`, `node_index` and `tensor_index`; they are
      # not read again after the loop.)
      for i in range(len(node.inbound_layers)):
        x = node.input_tensors[i]
        layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        tensor_index = node.tensor_indices[i]
        build_map_of_graph(x, finished_nodes, nodes_in_progress, layer,
                           node_index, tensor_index)

      # Post-order bookkeeping: a node is appended only after all of the
      # nodes it depends on have been appended.
      finished_nodes.add(node)
      nodes_in_progress.remove(node)
      nodes_in_decreasing_depth.append(node)

    finished_nodes = set()
    nodes_in_progress = set()
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      build_map_of_graph(x, finished_nodes, nodes_in_progress,
                         layer=layer,
                         node_index=node_index,
                         tensor_index=tensor_index)

    # Walk from the output side towards the inputs, so that a node's depth
    # is settled before it is propagated to its inbound nodes.
    for node in reversed(nodes_in_decreasing_depth):
      # If the depth is not set, the node has no outbound nodes (depth 0).
      depth = nodes_depths.setdefault(node, 0)

      # Update the depth of the corresponding layer
      previous_depth = layers_depths.get(node.outbound_layer, 0)
      # If we've seen this layer before at a higher depth,
      # we should use that depth instead of the node depth.
      # This is necessary for shared layers that have inputs at different
      # depth levels in the graph.
      depth = max(depth, previous_depth)
      layers_depths[node.outbound_layer] = depth
      nodes_depths[node] = depth

      # Update the depth of inbound nodes.
      # The "depth" of a node is the max of the depths
      # of all layers it is connected to.
      for i in range(len(node.inbound_layers)):
        inbound_layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        inbound_node = inbound_layer._inbound_nodes[node_index]  # pylint: disable=protected-access
        previous_depth = nodes_depths.get(inbound_node, 0)
        nodes_depths[inbound_node] = max(depth + 1, previous_depth)

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = {}
    for node, depth in nodes_depths.items():
      if depth not in nodes_by_depth:
        nodes_by_depth[depth] = []
      nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of layers with this depth}
    layers_by_depth = {}
    for layer, depth in layers_depths.items():
      if depth not in layers_by_depth:
        layers_by_depth[depth] = []
      layers_by_depth[depth].append(layer)

    # Get sorted list of layer depths.
    depth_keys = list(layers_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.layers and self._layers_by_depth.
    layers = []
    for depth in depth_keys:
      layers_for_depth = layers_by_depth[depth]
      # Network.layers needs to have a deterministic order:
      # here we order them by traversal order.
      layers_for_depth.sort(key=lambda x: layer_indices[x])
      layers.extend(layers_for_depth)
    self.layers = layers
    self._layers_by_depth = layers_by_depth

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = []
    for x in self.inputs:
      computable_tensors.append(x)

    layers_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
      for node in nodes_by_depth[depth]:
        layer = node.outbound_layer
        if layer:
          for x in node.input_tensors:
            if x not in computable_tensors:
              raise ValueError('Graph disconnected: '
                               'cannot obtain value for tensor ' + str(x) +
                               ' at layer "' + layer.name + '". '
                               'The following previous layers '
                               'were accessed without issue: ' +
                               str(layers_with_complete_input))
          for x in node.output_tensors:
            computable_tensors.append(x)
          layers_with_complete_input.append(layer.name)

    # Keep track of the network's nodes.
    self._network_nodes = network_nodes
    self._nodes_by_depth = nodes_by_depth

    # Ensure name uniqueness, which will be crucial for serialization
    # (since serialized nodes refer to layers by their name).
    all_names = [layer.name for layer in self.layers]
    for name in all_names:
      if all_names.count(name) != 1:
        raise ValueError('The name "' + name + '" is used ' +
                         str(all_names.count(name)) + ' times in the model. '
                         'All layer names should be unique.')

    # Layer parameters.
    # The new network starts with a single inbound node
    # for its inputs, and no outbound nodes.
    self._outbound_nodes = []  # Will be appended to by future calls to __call__
    self._inbound_nodes = [
    ]  # Will be appended to below, and by future calls to __call__
    # Create the node linking internal inputs to internal outputs.
    Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self.inputs,
        output_tensors=self.outputs)

  def get_layer(self, name=None, index=None):
    """Retrieves a layer based on either its name (unique) or index.

    Indices are based on order of horizontal graph traversal (bottom-up).

    Arguments:
        name: String, name of layer.
        index: Integer, index of layer.

    Returns:
        A layer instance.

    Raises:
        ValueError: In case of invalid layer name or index.
    """
    # TODO(fchollet): We could build a dictionary based on layer names
    # since they are constant, but we have not done that yet.
    if index is not None:
      if len(self.layers) <= index:
        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                         ' but model only has ' + str(len(self.layers)) +
                         ' layers.')
      else:
        return self.layers[index]
    else:
      if not name:
        raise ValueError('Provide either a layer name or layer index.')
    for layer in self.layers:
      if layer.name == name:
        return layer
    raise ValueError('No such layer: ' + name)

  @property
  def updates(self):
    """Retrieve the network's updates.

    Only layers exposing an `updates` attribute contribute. For each such
    layer, this gathers the updates conditional on the inputs of every node
    of the layer that belongs to this network, followed by the layer's
    unconditional updates.

    Returns:
        A list of update ops.
    """
    collected = []
    for layer in self.layers:
      if not hasattr(layer, 'updates'):
        continue
      # Updates that depend on inputs that are part of the model.
      for node_index, node in enumerate(layer._inbound_nodes):  # pylint: disable=protected-access
        if _make_node_key(layer.name, node_index) not in self._network_nodes:
          # This node of the layer is not owned by the model.
          continue
        collected += layer.get_updates_for(node.input_tensors)
      # Unconditional updates.
      collected += layer.get_updates_for(None)
    return collected

  @property
  def losses(self):
    """Retrieve the network's losses.

    Only layers exposing a `losses` attribute contribute. For each such
    layer, this gathers the losses conditional on the inputs of every node of
    the layer that belongs to this network, followed by the layer's
    unconditional losses. The network's own unconditional losses come last.

    Returns:
        A list of loss tensors.
    """
    collected = []
    for layer in self.layers:
      if not hasattr(layer, 'losses'):
        continue
      # Losses that depend on inputs that are part of the model.
      for node_index, node in enumerate(layer._inbound_nodes):  # pylint: disable=protected-access
        if _make_node_key(layer.name, node_index) not in self._network_nodes:
          # This node of the layer is not owned by the model.
          continue
        collected += layer.get_losses_for(node.input_tensors)
      # Unconditional losses of the layer.
      collected += layer.get_losses_for(None)
    # Any unconditional model-level loss.
    collected += self.get_losses_for(None)
    return collected

  @property
  def trainable_weights(self):
    if not self.trainable:
      return []
    weights = []
    for layer in self.layers:
      weights += layer.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for layer in self.layers:
      weights += layer.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for layer in self.layers:
        trainable_weights += layer.trainable_weights
      return trainable_weights + weights
    return weights

  @property
  def input_spec(self):
    """Gets the network's input specs.

    Returns:
        A list of `InputSpec` instances (one per input to the model)
            or a single instance if the model has only one input.
    """
    specs = []
    for layer in self._input_layers:
      if layer.input_spec is None:
        specs.append(None)
      else:
        if not isinstance(layer.input_spec, list):
          raise TypeError('Layer ' + layer.name +
                          ' has an input_spec attribute that '
                          'is not a list. We expect a list. '
                          'Found input_spec = ' + str(layer.input_spec))
        specs += layer.input_spec
    if len(specs) == 1:
      return specs[0]
    return specs

  def call(self, inputs, mask=None):
    """Call the model on new inputs.

    Re-applies all ops of the network's internal graph to the new inputs.
    In graph mode, a result previously computed for these exact input and
    mask objects is returned from the cache instead.

    Arguments:
        inputs: A tensor or list of tensors.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there are more than one outputs.
    """
    flat_inputs = nest.flatten(inputs)
    if mask is not None:
      flat_masks = nest.flatten(mask)
    else:
      flat_masks = [None] * len(flat_inputs)

    if context.in_graph_mode():
      # Cache keys are built from the identities of the inputs and masks.
      key = (_object_list_uid(flat_inputs) + '_' +
             _object_list_uid(flat_masks))
      if key in self._output_tensor_cache:
        return self._output_tensor_cache[key]

    # No cached result: actually apply the network graph to the new inputs.
    result, _ = self._run_internal_graph(flat_inputs, flat_masks)
    return result

  def _compute_output_shape(self, input_shape):
    """Computes the network's output shape(s) for the given input shape(s).

    Results are memoized in `self._output_shape_cache`. On a cache miss the
    shapes are propagated through the internal graph by calling each layer's
    `_compute_output_shape`, from the deepest nodes to the outputs.

    Arguments:
        input_shape: A shape (anything accepted by `TensorShape`) or `None`,
            or a list of those with one entry per network input.

    Returns:
        A `TensorShape` if the network has a single output, otherwise a list
        of `TensorShape`s (one per output).

    Raises:
        ValueError: If the number of shapes provided does not match the
            number of network inputs.
    """
    if isinstance(input_shape, list):
      input_shapes = []
      for shape in input_shape:
        if shape is not None:
          input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
        else:
          input_shapes.append(None)
    else:
      if input_shape is not None:
        input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
      else:
        input_shapes = [None]

    if len(input_shapes) != len(self._input_layers):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    cache_key = _object_list_uid(input_shapes)
    if cache_key in self._output_shape_cache:
      # Cache hit.
      output_shapes = self._output_shape_cache[cache_key]
    else:
      # Cache miss. We have to run the network graph manually (recursive calls
      # to `_compute_output_shape`).
      # Maps '<layer name>_<node index>_<tensor index>' to a shape tuple.
      layers_to_output_shapes = {}
      for input_layer, shape in zip(self._input_layers, input_shapes):
        # It's an input layer: then `_compute_output_shape` is identity,
        # and there is only one node and one tensor output.
        layers_to_output_shapes[input_layer.name + '_0_0'] = shape

      depth_keys = list(self._nodes_by_depth.keys())
      depth_keys.sort(reverse=True)
      # Iterate over nodes, by depth level. With a single depth level there
      # are only the nodes already seeded above, so nothing to propagate.
      if len(depth_keys) > 1:
        for depth in depth_keys:
          for node in self._nodes_by_depth[depth]:
            # This is always a single layer, never a list.
            layer = node.outbound_layer
            if layer in self._input_layers:
              # We've already covered the input layers
              # a few lines above.
              continue
            # Potentially redundant list,
            # same size as node.input_tensors.
            node_input_shapes = []
            for inbound_layer, node_index, tensor_index in zip(
                node.inbound_layers, node.node_indices, node.tensor_indices):
              shape_key = inbound_layer.name + '_%s_%s' % (node_index,
                                                           tensor_index)
              node_input_shapes.append(layers_to_output_shapes[shape_key])

            if len(node_input_shapes) == 1:
              output_shape = layer._compute_output_shape(node_input_shapes[0])  # pylint: disable=protected-access
            else:
              output_shape = layer._compute_output_shape(node_input_shapes)  # pylint: disable=protected-access
            if isinstance(output_shape, list):
              node_output_shapes = [
                  tuple(tensor_shape.TensorShape(shape).as_list())
                  for shape in output_shape
              ]
            else:
              node_output_shapes = [
                  tuple(tensor_shape.TensorShape(output_shape).as_list())
              ]

            node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
            for j, shape in enumerate(node_output_shapes):
              shape_key = layer.name + '_%s_%s' % (node_index, j)
              layers_to_output_shapes[shape_key] = shape

      # Read final output shapes from layers_to_output_shapes. This must
      # happen for every cache miss, including single-depth networks whose
      # outputs are the input tensors themselves.
      output_shapes = []
      for layer, node_index, tensor_index in self._output_coordinates:
        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
        output_shapes.append(layers_to_output_shapes[shape_key])

      # Store in cache.
      self._output_shape_cache[cache_key] = output_shapes

    # Cached entries may be a list of shapes or a single shape
    # (`_run_internal_graph` stores a single shape for single-output
    # networks), so both forms are handled here.
    if isinstance(output_shapes, list):
      if len(output_shapes) == 1:
        return tensor_shape.TensorShape(output_shapes[0])
      else:
        return [tensor_shape.TensorShape(shape) for shape in output_shapes]
    else:
      return tensor_shape.TensorShape(output_shapes)

  def _run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    # Note:
        - Expects `inputs` to be a list (potentially with 1 element).
        - Can be run on non-Keras tensors.

    In graph mode, the computed outputs, masks and static shapes are also
    stored in the network's output caches.

    Arguments:
        inputs: List of tensors
        masks: List of masks (tensors or None).

    Returns:
        A tuple `(output_tensors, output_masks)`. Each element is a single
        item if the network has exactly one output, otherwise a list with
        one entry per output.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having the fully reimplement the
    # network calling logic on the Keras side. We choose to incorporate it
    # in Network because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of Network.
    # If you don't use masking, it does not interfere with regular behavior
    # at all and you can ignore it.
    if masks is None:
      masks = [None for _ in range(len(inputs))]

    # Dictionary mapping reference tensors to tuples
    # (computed tensor, compute mask)
    # we assume a 1:1 mapping from tensor to mask
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # If all previous input tensors are available in tensor_map,
        # then call node.inbound_layer on them.
        computed_data = []  # List of tuples (input, mask).
        for x in reference_input_tensors:
          if str(id(x)) in tensor_map:
            computed_data.append(tensor_map[str(id(x))])

        if len(computed_data) == len(reference_input_tensors):
          # Call layer (reapplying ops to new inputs).
          with ops.name_scope(layer.name):
            # Work on a copy: a mask may be inserted below, and it must not
            # leak into the arguments stored on the node, where it would be
            # picked up again by later calls of the network.
            kwargs = dict(node.arguments) if node.arguments else {}
            if len(computed_data) == 1:
              computed_tensor, computed_mask = computed_data[0]
              # Ensure mask propagation if applicable.
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_mask

              output_tensors = nest.flatten(
                  layer.call(computed_tensor, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensor, computed_mask))
              else:
                output_masks = [None for _ in range(len(output_tensors))]
              computed_tensors = [computed_tensor]
              computed_masks = [computed_mask]
            else:
              computed_tensors = [x[0] for x in computed_data]
              computed_masks = [x[1] for x in computed_data]
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_masks
              output_tensors = nest.flatten(
                  layer.call(computed_tensors, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensors, computed_masks))
              else:
                output_masks = [None for _ in range(len(output_tensors))]

            # Apply activity regularizer if any. It regularizes the layer's
            # outputs (its activations); the resulting losses are recorded
            # as conditional on the inputs the layer was called on.
            if layer.activity_regularizer is not None:
              regularization_losses = [
                  layer.activity_regularizer(x) for x in output_tensors
              ]
              layer.add_loss(regularization_losses, computed_tensors)

          if context.in_graph_mode():
            # Update model updates and losses:
            # Keep track of updates that depend on the inputs
            # (e.g. BN updates).
            self.add_update(layer.get_updates_for(computed_tensors), inputs)
            # Keep track of unconditional updates (e.g. a counter).
            self.add_update(layer.get_updates_for(None), None)
            # Keep track of losses that depend on the inputs
            # (e.g. activity regularizers).
            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
            # Keep track of unconditional losses
            # (e.g. weight regularizers).
            self.add_loss(layer.get_losses_for(None), None)

          # Update tensor_map.
          for x, y, mask in zip(reference_output_tensors, output_tensors,
                                output_masks):
            tensor_map[str(id(x))] = (y, mask)

    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      output_shapes.append(_static_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    # Single-output networks return (and cache) bare items, not lists.
    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
      if output_shapes is not None:
        output_shapes = output_shapes[0]
      if output_masks is not None:
        output_masks = output_masks[0]

    if context.in_graph_mode():
      # Update cache;
      # keys are based on ids on input tensors and inputs masks.
      cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
      self._output_tensor_cache[cache_key] = output_tensors
      if output_masks is not None:
        self._output_mask_cache[cache_key] = output_masks
      if output_shapes is not None:
        input_shapes = [_static_shape(x) for x in inputs]
        cache_key = _object_list_uid(input_shapes)
        self._output_shape_cache[cache_key] = output_shapes

    return output_tensors, output_masks


def _is_tensor_or_tensor_list(v):
  """Tells whether `v` flattens to a non-empty list starting with a Tensor.

  Only the first flattened element is inspected.

  Arguments:
    v: A single object or a (possibly nested) structure of objects.

  Returns:
    `True` if the flattened structure is non-empty and its first element is
    an `ops.Tensor`, `False` otherwise.
  """
  flat_values = nest.flatten(v)
  return bool(flat_values) and isinstance(flat_values[0], ops.Tensor)


def _to_snake_case(name):
  """Converts a CamelCase name into a lower-case snake_case name.

  Arguments:
    name: String to convert (e.g. a class name).

  Returns:
    The snake_case string. If it would start with "_", it is prefixed with
    "private".
  """
  # First pass: put "_" in front of every capitalized word that follows
  # another character. Second pass: split remaining lower/upper boundaries.
  word_split = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  snake = re.sub('([a-z])([A-Z])', r'\1_\2', word_split).lower()
  # A private class name starts with "_", which is not secure for creating
  # scopes, so such names get a "private" prefix.
  if snake[0] == '_':
    return 'private' + snake
  return snake


def _to_list(x):
  """Normalizes the input into a list.

  A list or tuple is copied into a new list; any other value is wrapped in
  a one-element list.

  Arguments:
    x: list or tuple or single element.

  Returns:
    A list.
  """
  return list(x) if isinstance(x, (list, tuple)) else [x]


def _add_elements_to_collection(elements, collection_list):
  """Appends elements to graph collections, skipping ones already present.

  Arguments:
    elements: A single element or a (possibly nested) structure of elements
      to add.
    collection_list: A single collection name or a (possibly nested)
      structure of collection names.

  Raises:
    RuntimeError: If called in Eager mode.
  """
  if context.in_eager_mode():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    # Set view of the collection for fast membership tests. It is kept in
    # sync with every append so that an element occurring several times in
    # `elements` is still added only once.
    collection_set = set(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)
        collection_set.add(element)


def _object_list_uid(object_list):
  """Builds a string key from the `id()` of every flattened object.

  Arguments:
    object_list: A single object or a (possibly nested) structure of objects.

  Returns:
    The absolute `id()` values of the flattened objects, joined by ', '.
  """
  flat_objects = nest.flatten(object_list)
  uid_strings = [str(abs(id(obj))) for obj in flat_objects]
  return ', '.join(uid_strings)


def _make_node_key(layer_name, node_index):
  """Builds the string key '<layer_name>_ib-<node_index>'.

  Arguments:
    layer_name: String name of the layer.
    node_index: Index of the node; converted with `str()`.

  Returns:
    The key string.
  """
  index_suffix = '_ib-' + str(node_index)
  return layer_name + index_suffix


def _static_shape(x):
  """Returns the shape of `x` as a tuple, or None if it is unavailable.

  Arguments:
    x: An object exposing `get_shape().as_list()`, or None.

  Returns:
    `tuple(x.get_shape().as_list())`, or None when `x` is None or when
    fetching the shape list raises `ValueError`.
  """
  if x is not None:
    try:
      return tuple(x.get_shape().as_list())
    except ValueError:
      # Fall through to the shared "no static shape" result below.
      pass
  return None


def _is_all_none(iterable_or_element):
  """Checks whether the input is None or a list/tuple holding only None.

  Arguments:
    iterable_or_element: A list or tuple, or a single element (treated as a
      one-element list).

  Returns:
    `True` if every element is None (also for an empty list/tuple), `False`
    otherwise.
  """
  if isinstance(iterable_or_element, (list, tuple)):
    candidates = iterable_or_element
  else:
    candidates = [iterable_or_element]
  # Elements are compared by identity only and never evaluated for
  # truthiness, since the iterable may contain Tensors.
  return all(candidate is None for candidate in candidates)


def _have_all_keras_metadata(iterable_or_element):
  """Checks that every element has a `_keras_history` attribute.

  Arguments:
    iterable_or_element: A list or tuple, or a single element (treated as a
      one-element list).

  Returns:
    `True` if all elements have `_keras_history` (also for an empty
    list/tuple), `False` otherwise.
  """
  if isinstance(iterable_or_element, (list, tuple)):
    candidates = iterable_or_element
  else:
    candidates = [iterable_or_element]
  has_history = [hasattr(item, '_keras_history') for item in candidates]
  return all(has_history)


def _collect_previous_mask(input_tensors):
  """Gathers the `_keras_mask` attached to each input tensor.

  Arguments:
      input_tensors: A tensor or list of tensors.

  Returns:
      The mask (or None when the tensor carries no `_keras_mask`) for a
      single flattened input, otherwise a list with one such entry per
      flattened input.
  """
  collected = [
      getattr(tensor, '_keras_mask', None)
      for tensor in nest.flatten(input_tensors)
  ]
  # A single input yields its mask directly rather than a one-element list.
  return collected[0] if len(collected) == 1 else collected


# Global registry mapping each graph object to its own table of layer-name
# counters (see `_get_default_graph_uid_map`). Keeping the counters per graph
# makes it possible to autogenerate layer names that are unique within a
# graph. Keys are held weakly (`WeakKeyDictionary`), so an entry here does not
# keep its graph alive.
PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()


def _get_default_graph_uid_map():
  """Returns the layer-name counter map of the current default graph.

  The map is looked up in `PER_GRAPH_LAYER_NAME_UIDS`; if the graph has none
  yet, a fresh `defaultdict(int)` is registered for it and returned.

  Returns:
    The `defaultdict(int)` of name counters for the default graph.
  """
  default_graph = ops.get_default_graph()
  uid_map = PER_GRAPH_LAYER_NAME_UIDS.get(default_graph)
  if uid_map is not None:
    return uid_map
  uid_map = collections.defaultdict(int)
  PER_GRAPH_LAYER_NAME_UIDS[default_graph] = uid_map
  return uid_map


def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
  """Generates a unique name by appending an increasing counter to `name`.

  Arguments:
    name: String name to make unique.
    name_uid_map: An optional defaultdict(int) holding the per-name counters.
      If None (default), the per-Graph dictionary of the default graph is
      used.
    avoid_names: An optional set or dict with names which should not be used.
      If None (default), no names are avoided.

  Returns:
    Unique string name of the form `<name>_<counter>`.

  Example:

  ```python
  _unique_layer_name('dense')  # dense_1
  _unique_layer_name('dense')  # dense_2
  ```
  """
  if avoid_names is None:
    avoid_names = set()
  if name_uid_map is None:
    name_uid_map = _get_default_graph_uid_map()
  # Keep bumping the counter for `name` until the resulting candidate is not
  # one of the names to avoid; skipped counter values stay consumed.
  while True:
    uid = name_uid_map[name] + 1
    name_uid_map[name] = uid
    candidate = name + '_' + str(uid)
    if candidate not in avoid_names:
      return candidate