aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/function.py
blob: f1a63adce14e3a2b5672fa75a9dab39bf2bb141f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import re
import sys
import threading
import weakref

import numpy as np
import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

# This is to avoid a circular dependency with cond_v2_impl
# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl): this module
# object is attached to cond_v2_impl as `_function` at import time.
cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-access

# This is to avoid a circular dependency with gradients_impl; this module
# object is attached to it in the same way.
gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access

# Attribute names that `_parse_func_attrs` always accepts; both are listed
# verbatim in the whitelist below.
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"

# Patterns for the attribute names accepted by `_parse_func_attrs`. Each entry
# is applied with `re.match`, which anchors only at the start of the name, so
# every entry effectively whitelists a prefix.
# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [
    "experimental_.*",
    FORWARD_FUNCTION_ATTRIBUTE_NAME,
    BACKWARD_FUNCTION_ATTRIBUTE_NAME
]


def _create_substitute_placeholder(value, name=None, dtype=None):
  """Builds a graph placeholder that stands in for `value`.

  Args:
    value: The tensor being substituted. Its shape is given to the
      placeholder, and so is its dtype unless `dtype` is supplied.
    name: Optional name for the placeholder.
    dtype: Optional dtype used in place of `value.dtype`.

  Returns:
    The new placeholder, after handle data has been copied onto it from
    `value`.
  """
  # Clearing the control dependencies guarantees that capturing placeholders
  # are always created outside of any control flow context.
  with ops.control_dependencies(None):
    placeholder_dtype = dtype or value.dtype
    placeholder_shape = value.shape
    substitute = graph_placeholder(
        dtype=placeholder_dtype, shape=placeholder_shape, name=name)
  custom_gradient.copy_handle_data(value, substitute)
  return substitute


def _get_device_functions(ctx, graph):
  """Computes the device stack as a tuple of device functions.

  Args:
    ctx: Context object; asked whether execution is eager and, if so, for its
      device name.
    graph: Graph whose device functions are used when not executing eagerly.

  Returns:
    When executing eagerly, a one-element tuple holding the device function
    obtained by merging the context's device name. Otherwise, a tuple copy of
    the graph's `_device_functions_outer_to_inner`.
  """
  if not ctx.executing_eagerly():
    return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
  return (pydev.merge_device(ctx.device_name),)


def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only support primitive types: bool, int, float and string (both
  bytes and text strings, which includes `unicode` under Python 2). Values
  that already are `AttrValue` protos are passed through unchanged.

  Args:
    attributes: the dictionary of attributes.
  Returns:
    A dict of attributes where the key is the name of attribute and the value
      is the AttrValue proto.
  Raises:
    ValueError: If the kwargs contains unwhitelisted name or unsupported value
      types.
  """
  attrs = {}
  for key, value in attributes.items():
    # `re.match` only anchors at the start of `key`, so every whitelist entry
    # acts as a prefix pattern.
    if not any(re.match(reg, key)
               for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX):
      raise ValueError("Attribute name is not whitelisted. "
                       "Whitelisted: prefix %s, got: %s" %
                       (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))

    if isinstance(value, attr_value_pb2.AttrValue):
      attrs[key] = value
    # bool type check has to happen before int since bool is a subclass of int.
    elif isinstance(value, bool):
      attrs[key] = attr_value_pb2.AttrValue(b=value)
    elif isinstance(value, int):
      attrs[key] = attr_value_pb2.AttrValue(i=value)
    elif isinstance(value, float):
      attrs[key] = attr_value_pb2.AttrValue(f=value)
    # `(str, bytes)` is `(str, str)` under Python 2 and would reject `unicode`
    # values; `compat.bytes_or_text_types` covers bytes and text on both
    # Python 2 and Python 3.
    elif isinstance(value, compat.bytes_or_text_types):
      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    else:
      raise ValueError("Unsupported attribute type for %s with type %s" %
                       (key, type(value)))
  return attrs


class FuncGraph(ops.Graph):
  """A graph that holds the body of a function.

  Attributes:
    name: Name of the function.
    inputs: Placeholder tensors, living in this FuncGraph, that act as the
      function's inputs. "Regular" inputs come first, followed by the captured
      inputs (i.e. the values of self.captures).
    outputs: Tensors in this FuncGraph that the function returns.
    structured_outputs: Possibly-nested python object returned by the
      function. Its Tensors are the same as those in self.outputs; the
      structure may also contain Python `None`s.
    variables: Variables to watch while the function executes.
    outer_graph: Graph in which this function is defined; either another
      FuncGraph or the global default Graph.
    captures: Ordered mapping of external tensor -> internal tensor (i.e. input
      placeholder), in the order the captures happened.
    seed: Graph-level random seed.
  """

  def __init__(self, name):
    """Creates an empty FuncGraph.

    The graph key, collections, seed, device stack, and distribution strategy
    stack are inherited from the current context or graph.

    Args:
      name: the name of the function.
    """
    super(FuncGraph, self).__init__()

    outer = ops.get_default_graph()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_outputs = None
    self._weak_variables = []
    self.outer_graph = outer
    self.captures = collections.OrderedDict()

    self._building_function = True
    # For each resource tensor name, remembers the last op (in program order)
    # that used the tensor, so that execution order can be forced to match
    # program order for resource tensors.
    self._last_op_using_resource_tensor = {}

    if not context.executing_eagerly():
      # Graph mode: take seed, XLA flag and the device/colocation stacks from
      # the outer graph.
      self.seed = outer.seed
      self._xla_compile = getattr(outer, "_xla_compile", False)
      self._device_function_stack = outer._device_function_stack.copy()  # pylint: disable=protected-access
      self._colocation_stack = outer._colocation_stack.copy()  # pylint: disable=protected-access
    else:
      # Eager mode: take seed and device from the eager context instead.
      self.seed = context.global_seed()
      eager_ctx = context.context()
      self._xla_compile = (eager_ctx.device_spec.device_type == "TPU")
      self._add_device_to_stack(eager_ctx.device_name)

    # pylint: disable=protected-access
    # TODO(b/112165328, b/112906995): summaries depend on inheriting collections
    # from the default graph even in eager mode. It'd be nice to not have a
    # default graph with eager execution, so hopefully this will go away when we
    # remove collections.
    self._collections = outer._collections
    # TODO(b/112906995): distribution strategy depends on inheriting this stack
    # from the default graph even in eager mode. Maybe it should be part of the
    # eager context?
    self._distribution_strategy_stack = outer._distribution_strategy_stack
    self._variable_creator_stack = outer._variable_creator_stack
    # The graph key is inherited because optimizers use it to match variables.
    self._graph_key = outer._graph_key
    # pylint: enable=protected-access

  @property
  def variables(self):
    """Iterates over the variables accessed by this FuncGraph.

    Functions hold only weak references to their variables, so calling the
    function after one of those variables has been deleted is an error.

    Yields:
      Strong references to variables accessed by this FuncGraph.

    Raises:
      AssertionError: when a weak reference no longer resolves to a variable.
    """
    for variable_ref in self._weak_variables:
      variable = variable_ref()
      if variable is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      yield variable

  @variables.setter
  def variables(self, var_list):
    # Only weak references are stored; see the getter.
    self._weak_variables = list(map(weakref.ref, var_list))

  def create_op(
      self,
      op_type,
      inputs,
      dtypes,
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_shapes=True,
      compute_device=True):
    """Graph.create_op variant that captures external input tensors.

    Any input tensor that does not belong to this graph (an eager tensor, or a
    tensor of an outer function graph when functions are nested) is replaced by
    a placeholder before the op is created. See `capture` for details. Note
    that the entries of `inputs` are replaced in place.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: A list of `DType` objects that will be the types of the tensors
        that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
        computed).
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    # Control flow contexts replace op inputs very late, which does not play
    # well with capturing: a context can end up asking for an Enter of an
    # Enter. Detect that case and hand back the existing Enter op, since the
    # extra one would confuse loop validation logic.
    if (op_type == "Enter" and inputs[0].op.type == "Enter" and
        inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s):
      return inputs[0].op
    # AddValue is called on the control flow context so that the backward
    # accumulators get created in the original graph before placeholders are
    # created to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # TPU Estimator defines a control flow context with no AddValue method.
    can_add_value = ctxt is not None and hasattr(ctxt, "AddValue")
    for index, tensor in enumerate(inputs):
      if can_add_value:
        tensor = ctxt.AddValue(tensor)
      inputs[index] = self.capture(tensor)
    return super(FuncGraph, self).create_op(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Returns a tensor of this graph corresponding to `tensor`.

    A `tensor` that already belongs to this graph is returned as is. Otherwise
    a placeholder for it is returned; `tensor` and the placeholder are recorded
    in self.captures and the placeholder is added to self.inputs. Capturing
    the same `tensor` repeatedly yields the same placeholder.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.
    """
    if isinstance(tensor, ops.EagerTensor):
      # Without an explicit name, a fresh uid is used for eager tensors.
      placeholder_name = str(ops.uid()) if name is None else name
      return self._capture_helper(tensor, placeholder_name)
    if tensor.graph is self:
      return tensor
    # Graph tensor from another graph: default to the name of its op.
    placeholder_name = tensor.op.name if name is None else name
    return self._capture_helper(tensor, placeholder_name)

  def _capture_helper(self, tensor, name):
    """Looks up or creates the placeholder for `tensor` and records the capture."""
    placeholder = self.captures.get(tensor)
    if placeholder is None:
      # First capture of this tensor: create the placeholder and register it
      # both as a capture and as a function input.
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype)
      self.captures[tensor] = placeholder
      self.inputs.append(placeholder)
    tape.record_operation("captured_value", [placeholder], [tensor],
                          lambda x: [x])
    return placeholder

  @property
  def external_captures(self):
    """The external tensors this function has captured, in capture order."""
    return list(self.captures)

  @property
  def internal_captures(self):
    """The placeholders in this function that stand for captured tensors."""
    return list(self.captures.values())


def _forward_name(n):
  """Returns a fresh name for the generated forward defun of `n`."""
  unique_id = ops.uid()
  return "__forward_%s_%s" % (n, unique_id)


def _backward_name(n):
  """Returns a fresh name for the generated backward defun of `n`."""
  unique_id = ops.uid()
  return "__backward_%s_%s" % (n, unique_id)


def _inference_name(n):
  """Returns a fresh name for the forward-but-no-gradient defun of `n`."""
  unique_id = ops.uid()
  return "__inference_%s_%s" % (n, unique_id)


def _register(fn):
  """Adds the function `fn` to the current context."""
  current_context = context.context()
  current_context.add_function(fn)


# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction.`

  `_EagerDefinedFunction` encapsulates a function definition and its properties,
  and it provides a method for calling the encapsulated function. Some Ops
  take functions as attributes, which have type `func`; an instance of this
  class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    # The ops producing `inputs` are left out of the function body. Collect
    # them once up front rather than rebuilding the set for every graph op.
    input_ops = set(arg.op for arg in inputs)
    operations = [
        op for op in graph.get_operations() if op not in input_ops
    ]
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        None,
        compat.as_str(""))

    # A distinct loop variable keeps the `name` parameter from being shadowed.
    for attr_name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(attr_name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    # In graph mode the function is added to a graph lazily instead; see
    # `add_to_graph`.
    if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    # The name is read back from the generated signature, as bytes.
    self.name = compat.as_bytes(function_def.signature.name)
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._func_graph_outputs = outputs
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self._graph = graph
    self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful)

  def add_to_graph(self, g):
    """Adds this function, and the functions of its own graph, to `g`.

    A function is only added when `g` has no function registered under the
    same name yet.

    Args:
      g: the graph to add the functions to.
    """
    # pylint: disable=protected-access
    if self.name not in g._functions:
      g._add_function(self)
    for f in self._graph._functions.values():
      if f.name not in g._functions:
        g._add_function(f)
    # pylint: enable=protected-access

  @property
  def stateful_ops(self):
    """Tuple of the function body's ops whose op_def is marked stateful."""
    return self._stateful_ops

  def call(self, ctx, args):
    """Calls this function with `args` as inputs.

    Function execution respects device annotations only if the function won't
    be compiled with xla.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect (this is only
        checked when the function is not compiled with xla).
    """

    executing_eagerly = ctx.executing_eagerly()

    if self._graph._xla_compile:  # pylint: disable=protected-access
      # XLA compilation relies upon a custom kernel creator to run functions.
      signature = self.signature
      if executing_eagerly:
        outputs = execute.execute(
            str(signature.name),
            num_outputs=self._num_outputs,
            inputs=args,
            attrs=None,
            ctx=ctx)
      else:
        g = ops.get_default_graph()
        self.add_to_graph(g)
        op = g.create_op(
            signature.name,
            [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
            tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
            op_def=signature,
            name="FunctionCall",
            compute_shapes=False)
        outputs = op.outputs
        # A function without outputs yields the op itself.
        if not outputs:
          return op
        outputs = [outputs] if isinstance(
            outputs, (ops.Tensor, type(None))) else list(outputs)
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      if len(args) != len(self.signature.input_arg):
        raise ValueError(
            "Arguments and signature arguments do not match: %s %s " %
            (len(args), len(list(self.signature.input_arg))))
      outputs = functional_ops.partitioned_call(
          args=args,
          f=self,
          tout=self._output_types,
          executing_eagerly=executing_eagerly)

    if executing_eagerly:
      return outputs
    else:
      # In graph mode, restore on the call outputs the static shapes and the
      # handle data recorded for the function graph's own outputs.
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      for i, func_graph_output in enumerate(self._func_graph_outputs):
        custom_gradient.copy_handle_data(func_graph_output, outputs[i])
      return outputs


def _flatten(sequence):
  """Flattens `sequence` with `nest.flatten`, expanding `IndexedSlices` too.

  Each `IndexedSlices` in the flattened sequence is replaced by its `values`
  and `indices`, followed by its `dense_shape` when that is not None. Every
  other item is kept as is.

  Args:
    sequence: the structure to flatten.

  Returns:
    A flat list of the resulting items.
  """
  # TODO(akshayka): Support `SparseTensor` in a similar fashion.
  flattened = []
  for element in nest.flatten(sequence):
    if not isinstance(element, ops.IndexedSlices):
      flattened.append(element)
      continue
    flattened.append(element.values)
    flattened.append(element.indices)
    if element.dense_shape is not None:
      flattened.append(element.dense_shape)
  return flattened


class Function(object):
  """Callable object encapsulating a function definition and its gradient.

  `Function` is a callable that encapsulates a function definition and
  is differentiable under `tf.GradientTape` objects.

  The wrapped `FuncGraph` is turned into an `_EagerDefinedFunction` that is
  used for plain inference calls. A forward/backward function pair used for
  differentiation is built lazily, the first time it is needed, by
  `_construct_backprop_function`.
  """

  def __init__(self, func_graph, attrs=None):
    """Initialize a Function.

    Args:
      func_graph: An instance of FuncGraph: the function body to wrap.
      attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.

    Raises:
      ValueError: If number of input_placeholders is not equal to the number
        of function inputs.
        NOTE(review): no such check is made directly in this method;
        presumably it is raised while building `_EagerDefinedFunction` --
        confirm.
    """
    self._func_graph = func_graph
    # External tensors captured by the graph (the keys of `captures`); they are
    # appended to the explicit arguments on every call.
    self._captured_inputs = list(self._func_graph.captures.keys())
    self._num_outputs = len(self._func_graph.outputs)
    self._output_shapes = tuple(
        output.shape for output in self._func_graph.outputs)
    self._attrs = _parse_func_attrs(attrs or {})
    # Snapshot of the graph's device functions; `__call__` compares the
    # call-site device stack against this.
    self._device_functions = tuple(
        self._func_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access

    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, self._attrs)
    # Built on demand by `_construct_backprop_function`, which also creates
    # `self._forward_function`.
    self._backward_graph_function = None

  def __call__(self, *args):
    """Executes the wrapped function.

    Args:
      *args: a list of Tensors or Variables.

    Returns:
      The result of applying the TF function to `args`.

    Raises:
      ValueError: If the current device stack does not match the device stack
        under which the function was created, or if `args` contains anything
        other than Tensors or Variables.
    """
    ctx = context.context()
    # Refuse to run under a device stack different from the one recorded when
    # this function was constructed.
    device_functions = _get_device_functions(ctx, ops.get_default_graph())
    if device_functions != self._device_functions:
      raise ValueError(
          "The current device stack does not match the device stack under "
          "which the TensorFlow function '%s' was created.\n"
          "Current device stack: %s\n%s device stack: %s" %
          (self._inference_function.name, device_functions,
           self._inference_function.name, self._device_functions))

    # Report the trainable variables recorded on the function graph to the
    # tape.
    for v in self._func_graph.variables:
      if v.trainable:
        tape.variable_accessed(v)

    # Lower the flattened arguments to tensors: resource variables are passed
    # by handle, tensors are passed as-is, anything else is rejected.
    tensor_inputs = []
    for i, arg in enumerate(nest.flatten(args)):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        if arg.trainable:
          tape.variable_accessed(arg)
        tensor_inputs.append(arg.handle)
      elif isinstance(arg, ops.Tensor):
        tensor_inputs.append(arg)
      else:
        raise ValueError("All inputs to `Function`s must be Tensors; "
                         "on invocation of %s, the %d-th input (%s) was not a "
                         "Tensor." % (self._func_graph.name, i, str(arg)))
    # Explicit inputs first, captured inputs last.
    args = tensor_inputs + self._captured_inputs

    # If the tape needs to record any input, take the forward/backward path.
    if (tape.should_record(tensor_inputs) or
        tape.should_record(self._captured_inputs)):
      return self._backprop_call(args)

    # Only need to override the gradient in graph mode and when we have outputs.
    if context.executing_eagerly() or not self.outputs:
      outputs = self._inference_function.call(ctx, args)
    else:
      # A unique gradient name per call, so the override below applies only to
      # the op created inside the `gradient_override_map` scope.
      name = "PartitionedCall-%s" % ops.uid()

      @ops.RegisterGradient(name)
      def grad_fn(op, *doutputs):  # pylint: disable=unused-variable
        """Gradients of this function.

        Args:
          op: the call op whose gradient is requested.
          *doutputs: gradients with respect to the outputs of `op`.

        Returns:
          The result of `_SymGrad` when `op` is not in the current default
          graph; otherwise the result of calling the backward graph function
          on `doutputs` followed by the side outputs appended to `op`.
        """
        if op.graph is not ops.get_default_graph():
          # TODO(apassos) this will still emit SymbolicGradient ops when
          # nested defuns are being differentiated. We need to somehow figure
          # out a way to update the FunctionDef corresponding to the calling
          # function when mutating a call to the forward pass.
          return gradients_impl._SymGrad(op, list(doutputs))  # pylint: disable=protected-access
        if self._backward_graph_function is None:
          self._construct_backprop_function()
        self._forward_function.add_to_graph(op.graph)
        # Rewrite the existing op so that it calls the forward function
        # (attribute "f") with the forward function's output types ("Tout").
        func = attr_value_pb2.AttrValue(
            func=attr_value_pb2.NameAttrList(
                name=self._forward_function.name))
        # pylint: disable=protected-access
        op._set_attr("f", func)
        types = attr_value_pb2.AttrValue.ListValue(
            type=self._forward_function._output_types)
        op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
        # `outputs` is the variable of the enclosing `__call__`; it is assigned
        # after this function is defined (late-bound closure). Append one new
        # output tensor to `op` for every forward-function output beyond the
        # original ones.
        for i in range(
            len(outputs), len(self._forward_function._output_types)):
          t = ops.Tensor(op, i, self._forward_function._output_types[i])
          t.set_shape(self._forward_function._output_shapes[i])
          func_graph_output = self._forward_function._func_graph_outputs[i]
          custom_gradient.copy_handle_data(func_graph_output, t)
          op._outputs.append(t)
        # pylint: enable=protected-access
        side_outputs = op.outputs[len(outputs):]
        return self._backward_graph_function(
            *(list(doutputs) + list(side_outputs)))

      # Create the call op while "PartitionedCall" gradients are mapped to the
      # gradient function registered above.
      with ops.get_default_graph().gradient_override_map(
          {"PartitionedCall": name}):
        outputs = self._inference_function.call(ctx, args)

    return self._build_call_outputs(outputs)

  @property
  def name(self):
    """Function name."""
    return self._inference_function.name

  @property
  def graph(self):
    """Returns the graph from which this function was constructed."""
    return self._func_graph

  @property
  def inputs(self):
    """Returns tensors in `self.graph` corresponding to arguments."""
    return self._func_graph.inputs

  @property
  def outputs(self):
    """Returns tensors in `self.graph` corresponding to return values."""
    return self._func_graph.outputs

  @property
  def captured_inputs(self):
    """Returns external Tensors captured by this function.

    self.__call__(*args) passes `args + self.captured_inputs` to the function.
    """
    return self._captured_inputs

  @property
  def function_def(self):
    """Returns a `FunctionDef` object representing this function."""
    return self._inference_function.definition

  @property
  def output_shapes(self):
    """The function's output shapes.

    Returns:
      A structure shaped like the function graph's `structured_outputs`, in
      which every non-`None` leaf is replaced by a shape from
      `self._output_shapes`. For an `IndexedSlices` leaf, the shape of its
      `values` component is used.
    """
    # TODO(ebrevdo): Should we only keep the output shapes associated
    # with len(self._python_returns) outputs?
    # TODO(akshayka): Consider removing this.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    # `j` indexes the flat `self._output_shapes`; it advances faster than `i`
    # whenever an `IndexedSlices` leaf accounts for several flat outputs.
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Extract the shape of the `IndexedSlices` object's `values` field.
          outputs_list[i] = self._output_shapes[j]  # the `values` shape
          if o.dense_shape is not None:
            j += 3  # skip over shapes for `values`, `indices`, `dense_shape`
          else:
            j += 2  # skip over shapes for `values`, `indices`
        else:
          outputs_list[i] = self._output_shapes[j]
          j += 1
    return nest.pack_sequence_as(self._func_graph.structured_outputs,
                                 outputs_list)

  @property
  def output_dtypes(self):
    """The dtypes of the structured outputs; `None` leaves stay `None`."""
    # TODO(akshayka): Consider removing this.
    return nest.map_structure(lambda x: x.dtype if x is not None else None,
                              self._func_graph.structured_outputs)

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    return self._inference_function.add_to_graph(g)

  def _construct_backprop_function(self):
    """Constructs the backprop function object for this function.

    Sets `self._backward_graph_function` (a `Function` wrapping the gradient
    graph) and `self._forward_function` (an `_EagerDefinedFunction` whose
    outputs are the original outputs followed by the tensors captured by the
    gradient graph).
    """
    backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
    forward_function_name = _forward_name(self._func_graph.name)
    with backwards_graph.as_default():
      # One placeholder per forward output, standing in for the incoming
      # gradient of that output.
      gradients_wrt_outputs = [
          graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
      ]
      gradients_wrt_inputs = gradients_impl._GradientsHelper(  # pylint: disable=protected-access
          self._func_graph.outputs,
          self._func_graph.inputs,
          grad_ys=gradients_wrt_outputs,
          src_graph=self._func_graph)

    # Tensors captured by the backward graph; they are appended to the forward
    # function's outputs below.
    backwards_graph_captures = list(backwards_graph.captures.keys())

    backward_function_attr = _parse_func_attrs(
        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
    backward_function_attr.update(self._attrs)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `self._backward_graph_function` correspond to outputs of
    # `self._forward_function`.
    backwards_graph.inputs = gradients_wrt_outputs + list(
        backwards_graph.captures.values())
    # Clear captures, since we pass them in as inputs.
    backwards_graph.captures = {}
    # `None` gradients are dropped from the flat outputs but remain in
    # `structured_outputs`.
    backwards_graph.outputs.extend(
        grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs
    self._backward_graph_function = Function(
        backwards_graph, attrs=backward_function_attr)

    forward_function_attr = _parse_func_attrs({
        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
            self._backward_graph_function._inference_function.name})  # pylint: disable=protected-access
    forward_function_attr.update(self._attrs)
    self._forward_function = _EagerDefinedFunction(
        forward_function_name, self._func_graph, self._func_graph.inputs,
        self._func_graph.outputs + backwards_graph_captures,
        forward_function_attr)

  def _backprop_call(self, args):
    """Calls the forward function and records the result on a tape.

    (Only records results on a tape if the function has outputs)

    Args:
      args: All inputs to the function, including resolved captured inputs

    Returns:
      The call output.
    """
    if self._backward_graph_function is None:
      self._construct_backprop_function()

    ctx = context.context()
    outputs = self._forward_function.call(ctx, args)
    # Nothing to record when the call produced an op or nothing at all rather
    # than a list of outputs.
    if isinstance(outputs, ops.Operation) or outputs is None:
      return outputs

    # `real_outputs` are the actual outputs of the inference graph function;
    # `side_outputs` are the intermediate Tensors that were added as outputs to
    # the forward graph function so that we can compute its gradient.
    real_outputs = outputs[:self._num_outputs]
    side_outputs = outputs[self._num_outputs:]

    # Note: this `args` parameter shadows the enclosing method's `args`; the
    # `record_operation` call below uses the enclosing one.
    def backward_function(*args):
      return self._backward_graph_function(*(list(args) + side_outputs))  # pylint: disable=not-callable

    tape.record_operation(self._forward_function.signature.name, real_outputs,
                          args, backward_function)
    return self._build_call_outputs(real_outputs)

  def _build_call_outputs(self, result):
    """Maps the fdef output list to actual output structure.

    Args:
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output: `result` itself when the function graph has no
      `structured_outputs`; otherwise a structure shaped like
      `structured_outputs` whose non-`None` leaves are taken, in order, from
      `result` (with `IndexedSlices` rebuilt from their components).
    """
    if self._func_graph.structured_outputs is None:
      return result

    # Use `nest.flatten` instead of `_flatten` in order to preserve any
    # IndexedSlices in `self._func_graph.structured_outputs`.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    # `j` indexes the flat `result`; an `IndexedSlices` leaf consumes two or
    # three entries of it.
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Repack Tensors for IndexedSlices.
          if o.dense_shape is not None:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j],
                indices=result[j + 1],
                dense_shape=result[j + 2])
            j += 3
          else:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j], indices=result[j + 1])
            j += 2
        else:
          outputs_list[i] = result[j]
          j += 1
    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
                                outputs_list)
    return ret


def _get_defun_inputs_from_args(args):
  """Builds the inputs used to trace a Python function from its call args.

  Args:
    args: a possibly nested structure of call arguments.

  Returns:
    A structure with the same nesting as `args` in which every `Tensor` or
    `TensorSpec` leaf has been replaced by a graph placeholder of the same
    dtype and shape; all other leaves are passed through untouched.
  """
  placeholder_types = (ops.Tensor, tensor_spec.TensorSpec)
  traced_inputs = []
  for leaf in nest.flatten(args):
    if isinstance(leaf, placeholder_types):
      leaf = graph_placeholder(leaf.dtype, leaf.shape)
    traced_inputs.append(leaf)
  return nest.pack_sequence_as(args, traced_inputs)


def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            experimental_autograph=False):
  """Returns a `FuncGraph` generated from `python_func`.

  `python_func` is called once, inside the `FuncGraph`, with placeholders
  substituted for its Tensor/TensorSpec arguments; the resulting graph's
  `inputs`, `outputs`, `structured_outputs` and `variables` are then filled in.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    experimental_autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.

  Returns:
    A FuncGraph: `func_graph` if one was provided, otherwise a newly created
    one.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If `python_func` changed the structure of its arguments, or
      replaced any of their leaves, while being traced.
  """
  if func_graph is None:
    func_graph = FuncGraph(name)
  assert isinstance(func_graph, FuncGraph)
  with func_graph.as_default(), AutomaticControlDependencies() as a:
    variable_scope.get_variable_scope().set_use_resource(True)

    # A signature, when given, replaces the actual call arguments entirely.
    if signature is not None:
      args = signature
      kwargs = {}

    func_args = _get_defun_inputs_from_args(args)
    func_kwargs = _get_defun_inputs_from_args(kwargs)

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, nest.flatten(func_kwargs))

    def convert(x):
      """Converts a return value to a Tensor or IndexedSlices.

      `None` is passed through. Any other value is converted and then
      registered as a return value with the enclosing
      `AutomaticControlDependencies` context.

      Raises:
        TypeError: if the conversion fails with `ValueError` or `TypeError`.
      """
      if x is None:
        return None
      try:
        x = ops.convert_to_tensor_or_indexed_slices(x)
      except (ValueError, TypeError):
        raise TypeError(
            "To be compatible with tf.contrib.eager.defun, Python functions "
            "must return zero or more Tensors; in compilation of %s, found "
            "return value of type %s, which is not a Tensor." %
            (str(python_func), type(x)))
      x = a.mark_as_return(x)
      return x

    # Trace under a fresh tape; the variables it watched are read back below.
    # The tape is popped in the `finally` clause even if tracing fails.
    this_tape = tape.push_new_tape()
    try:
      if experimental_autograph:
        func_outputs = autograph.converted_call(
            python_func,
            autograph.ConversionOptions(
                verbose=True,
                recursive=True,
                force_conversion=False,
                strip_decorators=(defun,),
                arg_types={}), *func_args, **func_kwargs)
      else:
        func_outputs = python_func(*func_args, **func_kwargs)
      # invariant: `func_outputs` contains only Tensors and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs)

      def check_mutation(n1, n2):
        """Raises ValueError unless `n1` and `n2` match exactly.

        The two structures must have the same nesting and their flattened
        leaves must be the very same objects (identity, not equality).
        """
        errmsg = ("Function to be traced should not modify structure of input "
                  "arguments. Check if your function has list and dictionary "
                  "operations that alter input arguments, "
                  "such as `list.pop`, `list.append`")
        try:
          nest.assert_same_structure(n1, n2)
        except ValueError:
          raise ValueError(errmsg)

        for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
          if arg1 is not arg2:
            raise ValueError(errmsg)

      # Compare the pre-call structural copies with the (possibly mutated)
      # argument structures that were handed to `python_func`.
      check_mutation(func_args_before, func_args)
      check_mutation(func_kwargs_before, func_kwargs)
    finally:
      tape.pop_tape(this_tape)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    # Only ResourceVariables and Tensors among the arguments become graph
    # inputs; other argument values are not added to `inputs`.
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        try:
          # Move the variable's handle out of the captures so that it is an
          # explicit input instead.
          resource_placeholder = func_graph.captures.pop(arg.handle)
          arg_variables.add(arg)
        except KeyError:
          # This case occurs if a Variable among the inputs is not actually
          # used by the function; we still add an explicit input for it
          # because the user should presumably pass the Variable as an input
          # to the corresponding graph function.
          resource_placeholder = _create_substitute_placeholder(arg.handle)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    # Variables watched by the tape that were not passed in as arguments.
    variables = [v for v in tape_variables if v not in arg_variables]
    # Explicit inputs first, then the placeholders of the remaining captures.
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in _flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  # Register any other functions defined in the graph.
  if context.executing_eagerly():
    for f in func_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func.func)  # pylint: disable=protected-access

  return func_graph


# Register the `ops.Tensor` and `ops.IndexedSlices` classes with
# `pywrap_tensorflow` under the names "Tensor" and "IndexedSlices". This runs
# once, at import time.
# NOTE(review): presumably this lets the extension module (e.g.
# `TFE_Py_EncodeArg`, used by `PolymorphicFunction._cache_key`) recognize
# these Python types -- confirm.
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)


def _deterministic_dict_values(dictionary):
  return tuple(dictionary[key] for key in sorted(dictionary))


class PolymorphicFunction(object):
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `defun` for more information on the semantics of
  defined functions.

  Each distinct cache key (see `_cache_key`) gets its own traced `Function`,
  stored in an ordered dict and created lazily on first use.

  PolymorphicFunction class is thread-compatible meaning that minimal
  usage of defuns (defining and calling) is thread-safe, but if users call other
  methods or invoke the base `python_function` themselves, external
  synchronization is necessary.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               experimental_autograph=False):
    """Initializes a polymorphic function.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attribute
        of the function.
      experimental_autograph: whether to use autograph to compile
        `python_function`. See https://www.tensorflow.org/guide/autograph for
        more information.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
      TypeError: if `input_signature` is not None and is neither a tuple nor a
        list.
    """

    # Unwrap `functools.partial` objects: trace the underlying function and
    # remember the bound positional/keyword arguments so that they can be
    # re-applied in `_canonicalize_function_inputs`.
    if isinstance(python_function, functools.partial):
      self._python_function = python_function.func
      self._args_to_prepend = python_function.args or tuple()
      self._kwargs_to_include = python_function.keywords or {}
    else:
      self._python_function = python_function
      self._args_to_prepend = tuple()
      self._kwargs_to_include = {}
    self._name = name
    self._experimental_autograph = experimental_autograph
    # Maps cache keys (see `_cache_key`) to traced `Function` objects.
    self._function_cache = collections.OrderedDict()
    self._function_attributes = attributes or {}

    # Held while looking up and populating `self._function_cache`.
    self._lock = threading.Lock()

    fullargspec = tf_inspect.getfullargspec(self._python_function)
    if tf_inspect.ismethod(self._python_function):
      # Remove `self`: default arguments shouldn't be matched to it.
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    # A cache mapping from arg index to default value, for canonicalization.
    # `offset` is the index of the first positional argument with a default.
    offset = len(args) - len(fullargspec.defaults or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(fullargspec.defaults or [])
    }
    self._default_values = fullargspec.defaults
    self._default_values_start_index = offset
    if input_signature is None:
      # Note: `self._flat_input_signature` is not set on this path.
      self._input_signature = None
    else:
      if fullargspec.varkw is not None or fullargspec.kwonlyargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")

      if not isinstance(input_signature, (tuple, list)):
        raise TypeError("input_signature must be either a tuple or a "
                        "list, received " + str(type(input_signature)))

      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature))

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs."""
    graph_function, inputs = self._maybe_define_function(args, kwargs)
    return graph_function(*inputs)

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `Function` object specialized to inputs and execution context.

    `args` and `kwargs` are ignored if this `PolymorphicFunction` was created
    with an `input_signature`.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      The `Function` for these inputs, traced first if it is not cached yet.
    """
    # NOTE(review): this is a truthiness test, so an empty input signature
    # (`()`) does not take this branch, whereas `_maybe_define_function`
    # tests `is None` -- confirm that the difference is intended.
    if self._input_signature:
      args, kwargs = None, None
    graph_function, _ = self._maybe_define_function(args, kwargs)
    return graph_function

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `PolymorphicFunction` was
    # accessed through; e.g., for
    #
    #   class Foo(object):
    #
    #     @function.defun
    #     def bar(self):
    #       ...
    #
    #   foo = Foo()
    #   foo.bar()  # `foo.bar` is a `PolymorphicFunction` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`).
    return functools.partial(self.__call__, instance)

  def _cache_key(self, args, kwargs):
    """Computes the cache key given inputs and execution context.

    Args:
      args: canonicalized positional inputs; unused when an input signature
        was provided.
      kwargs: canonicalized keyword inputs; unused when an input signature was
        provided.

    Returns:
      A tuple `(cache_key, execution_context, device_functions,
      colocation_stack)`.
    """
    if self._input_signature is None:
      inputs = (args, kwargs) if kwargs else args
      cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
    else:
      # With a fixed signature, the actual inputs do not influence the key.
      del args, kwargs
      cache_key = self._flat_input_signature

    ctx = context.context()
    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      executing_eagerly = ctx.executing_eagerly()
      # `True` when executing eagerly, otherwise the default graph as seen
      # inside `init_scope`.
      execution_context = executing_eagerly or ops.get_default_graph()

    if executing_eagerly:
      device_functions = (pydev.merge_device(ctx.device_name),)
      colocation_stack = ()
    else:
      default_graph = ops.get_default_graph()
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())  # pylint: disable=protected-access

    return (cache_key, execution_context, device_functions, colocation_stack)

  def _canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using its fullargspec. In
    particular, we parse the varargs and kwargs that this
    `PolymorphicFunction` was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs, as a pair `(inputs, kwargs)`.
      When an input signature was provided, the returned `kwargs` is always an
      empty dict.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    # Re-apply the arguments bound by a wrapped `functools.partial`, if any.
    args = self._args_to_prepend + args
    kwargs = dict(kwargs, **self._kwargs_to_include)
    if not kwargs:
      if self._default_values:
        # Fill in the defaults for the trailing positional arguments that were
        # not supplied.
        inputs = args + self._default_values[len(args) -
                                             self._default_values_start_index:]
      else:
        inputs = args
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= len(args)
      }
      consumed_args = []
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          arg_indices_to_values[index] = value
          consumed_args.append(arg)
        elif self._input_signature is not None:
          raise ValueError("Cannot define a TensorFlow function from a Python "
                           "function with keyword arguments when "
                           "input_signature is provided.")
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain true keyword arguments, as
        # opposed to named arguments called in a keyword-like fashion.
        kwargs.pop(arg)
      # Values are appended in ascending order of argument index.
      inputs = args + _deterministic_dict_values(arg_indices_to_values)
    flat_inputs = nest.flatten(inputs)

    # Check for NumPy arrays in arguments and convert them to Tensors.
    # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
    # finding a way to store them directly in the cache key (currently not
    # possible since ndarrays are not hashable).
    need_packing = False
    for index, value in enumerate(flat_inputs):
      # Exact type comparison: instances of `np.ndarray` subclasses are not
      # converted here. NOTE(review): confirm this is intentional.
      if type(value) == np.ndarray:
        flat_inputs[index] = constant_op.constant(value)
        need_packing = True
    if need_packing:
      inputs = nest.pack_sequence_as(structure=inputs,
                                     flat_sequence=flat_inputs)
    if self._input_signature is None:
      return inputs, kwargs
    else:
      assert not kwargs
      try:
        nest.assert_same_structure(self._input_signature, inputs)
      except (ValueError, TypeError):
        raise ValueError("Structure of Python function inputs does not match "
                         "input_signature.")
      if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be Tensors.")
      # Check every input's spec against the corresponding signature entry.
      tensor_specs = [
          tensor_spec.TensorSpec.from_tensor(tensor) for tensor in flat_inputs
      ]
      if any(not spec.is_compatible_with(other)
             for spec, other in zip(self._flat_input_signature, tensor_specs)):
        raise ValueError("Python inputs incompatible with input_signature: "
                         "inputs (%s), input_signature (%s)" %
                         (str(inputs), str(self._input_signature)))
      return inputs, {}

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    `args` and `kwargs` can be None if this `PolymorphicFunction` was created
    with an `input_signature`.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A graph function corresponding to the input signature implied by args and
      kwargs, as well as the inputs that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects
    """
    # Canonicalization is skipped only when an input signature exists and both
    # `args` and `kwargs` are None.
    if self._input_signature is None or args is not None or kwargs is not None:
      args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
    cache_key = self._cache_key(args, kwargs)
    # The lookup and, on a miss, the tracing itself both happen while holding
    # the lock.
    with self._lock:
      try:
        graph_function = self._function_cache.get(cache_key, None)
      except TypeError:
        raise TypeError("Arguments supplied to `defun`-generated functions "
                        "must be hashable.")

      if graph_function is None:
        graph_function = Function(
            func_graph_from_py_func(
                self._name,
                self._python_function,
                args,
                kwargs,
                self._input_signature,
                experimental_autograph=self._experimental_autograph),
            self._function_attributes)
        self._function_cache[cache_key] = graph_function
      # Only Tensors and ResourceVariables are passed on to the graph function.
      return graph_function, [
          t for t in nest.flatten((args, kwargs))
          if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
      ]


def register(func, *args, **kwargs):
  """Registers a defun-generated function in the current default graph.

  The function is not executed; only its definitions are added to the graph.
  Registering the same `func` with different inputs adds one set of
  definitions per specialization.

  `args` and `kwargs` are ignored if `func` was created with an
  `input_signature`.

  Args:
    func: the `PolymorphicFunction` instance produced by `@defun`.
    *args: input arguments for the Python function.
    **kwargs: input keyword arguments for the Python function.

  Returns:
    a `Function` object specialized to inputs and execution context.

  Raises:
    ValueError: When the input function is not a defun wrapped python function.
  """
  if not isinstance(func, PolymorphicFunction):
    raise ValueError("Only defun function is allowed to be registered. "
                     "Got type: %s" % type(func))
  specialized = func.get_concrete_function(*args, **kwargs)
  target_graph = ops.get_default_graph()

  # A defun call takes one of two paths:
  # 1. When no input is a resource variable or watched by a tape, the
  #    `_inference_function` runs the forward pass and gradients come from the
  #    standard mechanism.
  # 2. Otherwise a separate forward function is used, and the backward pass is
  #    produced through the tape.
  # Registration adds the definitions for both paths to the graph.
  # pylint: disable=protected-access
  specialized._inference_function.add_to_graph(target_graph)

  if specialized._backward_graph_function is None:
    specialized._construct_backprop_function()
  gradient_pair = (
      specialized._forward_function,
      specialized._backward_graph_function._inference_function,
  )
  for definition in gradient_pair:
    definition.add_to_graph(target_graph)
  # pylint: enable=protected-access

  return specialized


def _validate_signature(signature):
  """Checks that `signature` is made up exclusively of `TensorSpec` objects.

  Args:
    signature: a possibly nested structure. It is flattened with
      `nest.flatten` and every resulting element must be an instance of
      `tensor_spec.TensorSpec`.

  Raises:
    TypeError: if any flattened element of `signature` is not a `TensorSpec`.
  """
  if any(not isinstance(arg, tensor_spec.TensorSpec)
         for arg in nest.flatten(signature)):
    # The signature is wrapped in a 1-tuple so that a tuple-valued signature is
    # rendered as one value instead of being unpacked as the `%` argument list.
    raise TypeError("Invalid input_signature %s; input_signature must be "
                    "a possibly nested sequence of TensorSpec objects." %
                    (signature,))


def defun(func=None, input_signature=None, experimental_autograph=False):
  """Compiles a Python function into a callable TensorFlow graph.

  `defun` (short for "define function") trace-compiles a Python function
  composed of TensorFlow operations into a callable that executes a `tf.Graph`
  containing those operations. The callable produced by `defun` contains only
  the subgraph of TensorFlow operations that were executed when the Python
  function was called with a particular input signature, defined as a list
  of the shapes and dtypes of the Python function's Tensor-valued arguments and
  the values of its non-Tensor Python objects. In particular, `defun` is _not_ a
  compiler for arbitrary Python code.

  When eager execution is enabled, the ability to create graphs from Python
  functions makes it possible to incrementally trade off debuggability and
  interactivity for performance.  Functions compiled with `defun` cannot be
  inspected with `pdb` and `print` statements; however, executing a graph
  generated by `defun` sometimes takes less time and memory than eagerly
  executing the corresponding Python function, since specifying computations as
  graphs allows for optimizations like automatic buffer reuse and
  parallelization among ops. Note that executing a `defun`-compiled function
  incurs a small constant overhead, so eagerly executing sufficiently small
  Python functions might take less time than executing their corresponding
  `defun`-generated graphs.

  For a Python function to be compatible with `defun`, all of its arguments must
  be hashable Python objects or lists thereof. The function itself may not
  modify the list/map structure of its arguments. Additionally, it must return
  zero or more `tf.Tensor` objects. If the Python function returns
  a `tf.Variable`, its compiled version will return the value of that variable
  as a `tf.Tensor`.

  Executing a graph generated by `defun` respects device annotations (i.e.,
  all `with tf.device` directives present in a Python function will also be
  present in its corresponding graph), but it is not yet possible to execute the
  generated graphs across multiple machines.

  _Example Usage_

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  # A simple example.
  def f(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  g = tf.contrib.eager.defun(f)

  x = tf.constant([[2.0, 3.0]])
  y = tf.constant([[3.0, -2.0]])

  # `f` and `g` will return the same value, but `g` will be executed as a
  # TensorFlow graph.
  assert f(x, y).numpy() == g(x, y).numpy()

  # `defun` is capable of compiling Python functions that close over Python
  # objects, including Tensors and Variables.
  @tf.contrib.eager.defun
  def h():
    return f(x, y)

  assert (h().numpy() == f(x, y).numpy()).all()

  # `defun` automatically lifts variables out of the graphs it creates,
  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
  # `tf.keras.Model` objects.
  class MyModel(tf.keras.Model):

    def __init__(self, keep_probability=0.2):
      super(MyModel, self).__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
      self.keep_probability = keep_probability

    @tf.contrib.eager.defun
    def call(self, inputs, training=True):
      x = self.dense2(self.dense1(inputs))
      if training:
        return tf.nn.dropout(x, self.keep_probability)
      else:
        return x

  model = MyModel()
  model(x, training=True)  # executes a graph, with dropout
  model(x, training=False) # executes a graph, without dropout

  # `defun`-compiled functions are differentiable.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  with tf.GradientTape() as tape:
    outputs = model(x)
  gradient = tape.gradient(outputs, model.trainable_variables)
  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
                            model.trainable_variables))
  ```

  When using `defun`, there are subtleties regarding inputs, Python control
  flow, and variable creation that one should be aware of. For concreteness, let
  `f` be a Python function that returns zero or more `tf.Tensor` objects and
  let `F = defun(f)`. `F` builds a graph for each unique input signature it
  sees, Python control flow is baked into graphs, and operations related to
  variable initialization are automatically lifted out of the graphs that `F`
  generates and placed in the eager context if executing eagerly or into an
  outer graph otherwise.

  _Input Signatures_
  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
  for every unique sequence of the shapes and dtypes of Tensor arguments and
  the values of Python objects it is invoked with. For example, calling
  `F(tf.random_uniform([2]))` will execute a different graph than
  `F(tf.random_uniform([3]))` because the two inputs have different shapes.
  The first time that `F(*args, **kwargs)` is called with a particular sequence
  of Tensor shapes and dtypes and Python values, it constructs a graph by
  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
  input signature inferred from `(*args, **kwargs)` and cached for future reuse.

  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
  before being passed to `f`, and are treated as Tensors for caching. This
  allows a function to be called multiple times with NumPy arrays having
  different values but the same shape and dtype without re-tracing each time.

  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
  define TensorFlow functions without explicitly specifying their signatures.
  However, this policy is conservative and potentially expensive; for example,
  when different invocations of your function have differently-shaped Tensor
  inputs, this policy might generate more graph functions than necessary. To
  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
  optional `input_signature` argument specifying the shapes and dtypes of the
  inputs. In particular, the shapes may be partially unspecified, with `None`s
  in the unknown dimensions.  When an input signature is provided,
  `tf.contrib.eager.defun` will only instantiate a single graph for the
  decorated Python function. The following is an example:

  ```python
  import tensorflow as tf

  # The first `TensorSpec` below describes the shape and dtype of `words`,
  # and the second describes the shape and dtype of `another_tensor`. Note that
  # the last dimension of the `words` `TensorSpec` is left unspecified.
  @tf.contrib.eager.defun(input_signature=[
    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
  ])
  def my_sequence_model(words, another_tensor):
    ...

  # Note how the third dimension of the first input can vary freely.
  words = tf.random_uniform([50, 300, 10])
  second_input = tf.random_uniform([300, 100])
  my_sequence_model(words, second_input)

  words = tf.random_uniform([50, 300, 20])
  my_sequence_model(words, second_input)

  # Passing an input with an incompatible shape will raise an error.
  words = tf.random_uniform([50, 100, 20])
  my_sequence_model(words, second_input)  # <---- This will raise an error.

  ```

  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

  _Tracing_
  Be aware that because `F` only logs TensorFlow operations, all the other
  Python code that `f` executes will only shape the _construction_ of the graphs
  that `F` executes: the Python code won't be executed when the graphs
  themselves are executed, though it will be executed every time the Python
  function is traced (and a given Python function might be traced multiple
  times, once for each input signature it is invoked with). For example, whereas
  the Python function

  ```python
  import tensorflow as tf
  import numpy as np

  tf.enable_eager_execution()

  def add_noise():
    return tf.eye(5) + np.random.randn(5, 5)
  ```

  will return a different output every time it is invoked, the compiled function
  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
  every time it is called, since a particular random offset generated by NumPy
  will be inserted into the graph as a TensorFlow constant. The solution is to
  replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.

  _Python Side-Effects_
  A corollary of the previous discussion on tracing is the following: If a
  Python function `f` has Python side-effects, then executing `f` multiple times
  will not necessarily be semantically equivalent to executing `F =
  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
  that `defun` only captures the subgraph of TensorFlow operations that is
  constructed when `f` is called in a graph-building context.

  _Python Control Flow_
  The structure of many machine learning computations depends upon whether one
  is training or validating, and it is common to nest specialized logic under
  `if training:` blocks. By mapping each input signature to a unique graph,
  `defun` lets users transparently compile such code, as the following code
  snippet demonstrates:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  @tf.contrib.eager.defun
  def lossy_matmul(W, x, training=True):
    outputs = tf.matmul(W, x)
    if training:
      outputs = tf.nn.dropout(outputs, keep_prob=0.2)
    return outputs

  W = tf.random_normal((3, 5))
  x = tf.random_normal((5, 1))

  # Executes a graph that applies dropout.
  lossy_outputs = lossy_matmul(W, x, training=True)

  # Executes a graph that does not apply dropout.
  exact_outputs = lossy_matmul(W, x, training=False)
  ```

  On the other hand, because `defun` generates graphs by tracing and not by
  source code analysis, it fully unrolls Python `for` and `while` loops,
  potentially creating large graphs. If your Python function has native loops
  that run for many iterations, consider replacing them with `tf.while_loop`
  operations.

  When constructing graphs, `tf.Tensor` objects cannot be used as Python
  `bool` objects. This means, for example, that you should replace code in `f`
  resembling

  ```python

  if tensor < 10:
    true_fn()
  else:
    false_fn()
  ```

  with `tf.cond(tensor < 10, true_fn, false_fn)`.

  _Variables_
  TensorFlow operations related to variable creation and initialization are
  automatically lifted out of the graphs generated by `defun`. In practice, this
  implies that variable creation and initialization only happen the first time
  `F` is called, and that variables are reused every time thereafter. Many
  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
  first time they are called and reuse them thereafter. Automatic variable
  lifting makes it possible to compile these APIs without extra effort, at the
  cost of introducing a discrepancy between the semantics of executing Python
  functions and their corresponding compiled functions. For example:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  def fn():
    x = tf.Variable(0.0)
    x.assign_add(1.0)
    return x.read_value()

  # `fn` is a Python function, so x is created, initialized, and destroyed upon
  # every invocation
  assert fn().numpy() == fn().numpy() == 1.0

  compiled = tf.contrib.eager.defun(fn)

  # Compiling `fn` with `defun` hoists all variables outside of the generated
  # graph, so initialization happens exactly once.
  assert compiled().numpy() == 1.0
  assert compiled().numpy() == 2.0
  ```

  Finally, because each input signature is bound to a unique graph, if your
  Python function constructs `tf.Variable` objects, then each graph constructed
  for that Python function will reference a unique set of variables. To
  circumvent this problem, we recommend against compiling Python functions that
  create `tf.Variable` objects. Instead, Python functions should either
  lexically close over `tf.Variable` objects or accept them as arguments,
  preferably encapsulated in an object-oriented container. If you must create
  variables inside your Python function and you want each graph generated for it
  to reference the same set of variables, add logic to your Python function that
  ensures that variables are only created the first time it is called and are
  reused for every subsequent invocation; note that this is precisely what
  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
  variable-bearing computations whenever possible.

  Args:
    func: function to be compiled. If `func` is None, returns a
      decorator that can be invoked with a single argument - `func`. The
      end result is equivalent to providing all the arguments up front.
      In other words, defun(input_signature=...)(func) is equivalent to
      defun(func, input_signature=...). The former allows
      the following use case:
        @tf.contrib.eager.defun(input_signature=...)
        def foo(...):
          ...

    input_signature: A possibly nested sequence of
      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
      the Tensors that will be supplied to this function. If `None`, a separate
      function is instantiated for each inferred input signature.  If a
      signature is specified, every input to `func` must be a `Tensor`, and
      `func` cannot accept `**kwargs`.
    experimental_autograph: Whether `func` should be compiled before
      constructing the graph. See https://www.tensorflow.org/guide/autograph
      for more information.


  Returns:
     If `func` is not None, returns a callable that will execute the compiled
     function (and return zero or more `tf.Tensor` objects).
     If `func` is None, returns a decorator that, when invoked with a single
     `func` argument, returns a callable equivalent to the case above.

  Raises:
    TypeError: If `input_signature` is neither `None` nor a sequence of
      `tf.contrib.eager.TensorSpec` objects.
  """
  # `defun` is the public entry point; it delegates to `defun_with_attributes`
  # without passing any function attributes.
  return defun_with_attributes(
      func=func,
      input_signature=input_signature,
      experimental_autograph=experimental_autograph)


def defun_with_attributes(func=None,
                          input_signature=None,
                          attributes=None,
                          experimental_autograph=False):
  """Compiles a Python function into a graph callable carrying attributes.

  Behaves like `defun()` (see its documentation for details) while also
  accepting extra function attributes. It is not part of the public API for
  now, since users are not expected to set attributes directly and an attribute
  does nothing by itself. This assumption might change in future.

  Args:
    func: function to be compiled. When `None`, a decorator is returned
      instead of a compiled function.
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to function def as
      attributes. Currently only support primitive types as value, and only
      whitelisted attribute name is allowed. Unwhitelisted attribute name or
      unsupported value will result into ValueError.
    experimental_autograph: same as defun()'s experimental_autograph.

  Returns:
    Same as the return value of defun, with attributes added to the function in
    graph.
  """
  if input_signature is not None:
    _validate_signature(input_signature)

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(function):
    # Callables that lack a `__name__` attribute get a generic name.
    name = getattr(function, "__name__", "function")
    polymorphic = PolymorphicFunction(
        function,
        name,
        input_signature=input_signature,
        attributes=attributes,
        experimental_autograph=experimental_autograph)
    return tf_decorator.make_decorator(function, polymorphic)

  # With `func` given this is the `foo = tfe.defun(foo, ...)` use case and the
  # compiled function is returned directly. Without it, this is the
  #
  # @tfe.defun(...)
  # def foo(...):
  #    ...
  #
  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`, so the
  # decorator itself is returned.
  return decorated if func is None else decorated(func)


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute
    2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  When executing eagerly, entering and exiting the context do nothing.

  NOT THREAD SAFE
  """

  def __init__(self):
    """Creates the manager with an empty set of returned tensors."""
    # Tensors produced by `mark_as_return`; on `__exit__` their ops receive
    # control dependencies on the ops which must run.
    self._returned_tensors = set()

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
       ...
       t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` (or `IndexedSlices`) to be marked

    Returns:
      a copy of the `Tensor`. For an `IndexedSlices` input, a new
      `IndexedSlices` whose values and indices are identity copies and whose
      `dense_shape` is that of the input.
    """
    if isinstance(tensor, ops.IndexedSlices):
      # Both components are copied and tracked; dense_shape is passed through
      # unchanged and is not tracked.
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    """Records the default graph and how many operations it currently holds.

    When executing eagerly nothing is recorded.

    Returns:
      This context manager.
    """
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    # Operations at or beyond this index are treated by `__exit__` as having
    # been created inside the scope.
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_op_using_resource_tensor, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor get created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    All three collection arguments are mutated in place.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_op_using_resource_tensor: map from resource tensor to last op using
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    inp = switch_op.inputs[0]
    # Handle chained switches from the outside in.
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run,
                           last_op_using_resource_tensor, merge_for_resource)
    # An entry for the switch's output means this switch was already processed.
    if switch_op.outputs[0] in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(switch_op.outputs,
                                       name="artificial_merge")
    # Place the artificial merge in the context enclosing the switch.
    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if inp in last_op_using_resource_tensor:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    # Ensure the next op outside the cond happens after the merge.
    last_op_using_resource_tensor[inp] = new_merge[0].op
    if inp in merge_for_resource:
      # The input itself comes from a processed switch: its merge must wait for
      # this new merge.
      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[o] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Adds control dependencies among the ops created inside the scope.

    Does nothing when executing eagerly. Otherwise walks, in graph-construction
    order, the operations added to the graph since `__enter__`, adds control
    inputs between them as described in the comments below, and finally makes
    the ops of the tensors passed to `mark_as_return` depend on the ops which
    must run.

    Args:
      unused_type: ignored.
      unused_value: ignored.
      unused_traceback: ignored.

    Raises:
      RuntimeError: if the default graph is no longer the graph recorded by
        `__enter__`.
    """
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # map from resource tensor to the last op which used it
    last_op_using_resource_tensor = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    # Only the operations created since `__enter__` are examined.
    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_op_using_resource_tensor).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if isinstance(op._control_flow_context, control_flow_ops.WhileContext):  # pylint: disable=protected-access
        continue
      control_inputs = set()
      # Ensure stateful ops run
      # Op types missing from the graph's registered ops are also treated as
      # ops which must run.
      if (op.type not in self._graph._registered_ops  # pylint: disable=protected-access
          or self._graph._registered_ops[op.type].is_stateful):  # pylint: disable=protected-access
        ops_which_must_run.add(op)
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)  # pylint: disable=protected-access
          # Later users of the resources touched by `o` must now follow the
          # merge.
          for inp in o.inputs:
            if inp in last_op_using_resource_tensor:
              last_op_using_resource_tensor[inp] = op
        # The merge now stands in for everything it depends on.
        ops_which_must_run = set([op])
        continue
      found_resource = False
      for inp in op.inputs:
        if inp.dtype == dtypes_module.resource:
          found_resource = True
          # Deal with switches, finally.
          if inp.op.type == "Switch":
            self._process_switch(inp.op, ops_which_must_run,
                                 last_op_using_resource_tensor,
                                 merge_for_resource)
          # Ensure uses of resources are serialized
          if inp in last_op_using_resource_tensor:
            if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
                is op._control_flow_context):  # pylint: disable=protected-access
              control_inputs.add(last_op_using_resource_tensor[inp])
          # Ensure merges happen after the closing of a cond block
          if inp in merge_for_resource:
            merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
          last_op_using_resource_tensor[inp] = op
      # Stateful ops with no resource input and no control flow context are
      # chained to one another, using `None` as the key in the map.
      if (op.op_def.is_stateful and not found_resource
          and op._control_flow_context is None):  # pylint: disable=protected-access
        if None in last_op_using_resource_tensor:
          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
        last_op_using_resource_tensor[None] = op
      # Only keep control inputs living in the same control flow context as
      # `op`.
      control_inputs = [c for c in control_inputs
                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

    # Ensure all ops which must run do run
    for r in self._returned_tensors:
      if ops_which_must_run:
        # Only ops in the same control flow context as the returned tensor's op
        # are added as control inputs.
        r.op._add_control_inputs(  # pylint: disable=protected-access
            [o for o in ops_which_must_run
             if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access


def automatic_control_dependencies(f):
  """Decorates `f` so that control dependencies are inserted automatically.

  The wrapped function runs `f` inside an `AutomaticControlDependencies`
  scope and marks every flattened element of its result as a return value.
  The inserted dependencies ensure that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as deps:
      outputs = f(*args, **kwargs)
      marked = []
      for tensor in nest.flatten(outputs):
        marked.append(deps.mark_as_return(tensor))
      # Rebuild the original structure around the marked copies while still
      # inside the scope, so the scope exits only once the result is ready.
      return nest.pack_sequence_as(outputs, marked)

  return tf_decorator.make_decorator(f, wrapper)