aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/converters/side_effect_guards.py
blob: 6e48e57bde0fffab96db40efe840bf067bf11300 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adds guards against function calls with side effects.

Only standalone calls are guarded.

WARNING: This mechanism is incomplete. In particular, it only guards the
arguments passed to functions, and does not account for indirectly modified
state.

Example:
  y = tf.layers.dense(x)       # Creates TF variable 'foo'
  loss = loss(y)
  opt.minimize(loss)           # indirectly affects 'foo'
  z = tf.get_variable('foo')   # Indirectly affects `loss` and 'foo'
  # Here, `loss` can be guarded. But `z` cannot.

# TODO(mdan): We should probably define a safe mode where we guard everything.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class SymbolNamer(object):
  """Interface expected of the namer used by SideEffectGuardTransformer."""

  def new_symbol(self, name_root, reserved_locals):
    """Produces a fresh, unique symbol name.

    Args:
      name_root: String, the stem from which the new name is derived.
      reserved_locals: Set(string), extra local symbols that are reserved.
    Returns:
      String.
    Raises:
      NotImplementedError: always; this base method is only a placeholder.
    """
    raise NotImplementedError


class SideEffectGuardTransformer(converter.Base):
  """Adds control dependencies to functions with side effects.

  A standalone call statement (an `Expr` node whose value is a `Call`) is
  replaced by a `with ag__.utils.control_dependency_on_returns(<call>):`
  statement. The statements that followed the call in the same block are then
  appended to the body of that `with` statement ("reindented"), so that they
  are nested under the guard.
  """

  def _visit_and_reindent(self, nodes):
    """Visits a block of statements, nesting them under any guards created.

    Nodes are visited in order. Whenever a visited node carries the
    `anno.Basic.INDENT_BLOCK_REMAINDER` annotation (set by `visit_Expr`), the
    annotation is consumed and all following nodes are appended to the list it
    supplies instead of the current one. The alias map supplied with the
    annotation is merged with the aliases collected so far and applied to the
    following nodes.

    Args:
      nodes: The statement nodes of a single block, in order.
    Returns:
      List, the new top-level contents of the block. Statements that were
      reindented are reachable through the guard nodes in this list.
    Raises:
      ValueError: if a reindent was requested and the last destination list is
        still empty after all nodes were processed.
    """
    new_nodes = []
    # The list currently receiving statements. Starts as the block itself and
    # is switched to a guard's body each time a guard is encountered.
    current_dest = new_nodes
    # Maps original symbols to their aliases for the statements that follow.
    alias_map = {}
    reindent_requested = False
    for n in nodes:
      n = self.visit(n)
      # NOTE: the order in which these statements execute is important; in
      # particular, watch out for ending up with cycles in the AST.
      if alias_map:
        n = ast_util.rename_symbols(n, alias_map)
      if isinstance(n, (list, tuple)):
        current_dest.extend(n)
      else:
        current_dest.append(n)
      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
        reindent_requested = True
        new_dest, new_alias_map = anno.getanno(
            n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Consume the annotation so that it is not processed a second time.
        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Aliases collected from earlier guards overwrite entries with the
        # same key in the new map.
        new_alias_map.update(alias_map)
        alias_map = new_alias_map
        current_dest = new_dest
    # NOTE(review): a guard built with guarded args presumably already has the
    # alias assignment in its body, so this would only trigger for guards
    # created without args - confirm.
    if reindent_requested and not current_dest:
      # TODO(mdan): There may still be something that could be done.
      raise ValueError('Unable to insert statement into the computation flow: '
                       'it is not followed by any computation which '
                       'the statement could gate.')
    return new_nodes

  def visit_FunctionDef(self, node):
    """Reindents the function body; no other field is visited here."""
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_With(self, node):
    """Reindents the body of the with statement; no other field is visited."""
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_If(self, node):
    """Reindents both branches independently; the test is not visited here."""
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_While(self, node):
    """Reindents the loop body and the else block independently."""
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_Expr(self, node):
    """Replaces a standalone call statement with a control dependency guard.

    Expression statements whose value is not a call are returned after a
    generic visit, without further changes.

    Args:
      node: The expression statement node. When its value is a call, the call
        is expected to carry the `NodeAnno.ARGS_SCOPE` annotation.
    Returns:
      Either the visited `node`, or the guard statement that replaces it. The
      guard is annotated with `anno.Basic.INDENT_BLOCK_REMAINDER`, holding the
      tuple (guard body, alias map) that `_visit_and_reindent` consumes.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      # Patterns of single function calls, like:
      #   opt.minimize(loss)
      # or:
      #   tf.py_func(...)

      # First, attempt to gate future evaluation of args. If that's not
      # possible, gate all remaining statements (and that may fail too, see
      # _visit_and_reindent).
      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
      # NOTE: We can't guard object attributes because they may not be writable.
      # In addition, avoid renaming well-known names.
      # TODO(mdan): Move these names into config.
      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
      guarded_args = tuple(s for s in args_scope.used
                           if not s.is_composite() and s not in unguarded_names)

      # TODO(mdan): Include all arguments which depended on guarded_args too.
      # For example, the following will still cause a race:
      #   tf.assign(a, a + 1)
      #   b = a + 1
      #   tf.assign(a, a + 1)  # Control deps here should include `b`
      #   c = b + 1
      # Or maybe we should just raise an "unsafe assign" error?

      if guarded_args:
        # The aliases may need new names to avoid incorrectly making them local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        # Only symbols that the parent scope does not modify get a new name;
        # the remaining ones are reassigned under their own name.
        need_alias = tuple(
            s for s in guarded_args if s not in args_scope.parent.modified)
        aliased_new_names = tuple(
            qual_names.QN(
                self.ctx.namer.new_symbol(
                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
        alias_map = dict(zip(need_alias, aliased_new_names))
        # A single symbol is passed to the template as a qualified name, while
        # several symbols are packed into a tuple node of their AST forms.
        if len(guarded_args) == 1:
          s, = guarded_args
          aliased_guarded_args = alias_map.get(s, s)
        else:
          aliased_guarded_args = gast.Tuple(
              [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
        # Only the last node produced from the template is used.
        control_deps_guard = templates.replace(
            template,
            call=node.value,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]
      else:
        alias_map = {}

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
        control_deps_guard = templates.replace(template, call=node.value)[-1]
        # Empty the body produced from the template (the `pass` placeholder);
        # _visit_and_reindent appends the following statements to it.
        control_deps_guard.body = []

      node = control_deps_guard
      # Ask _visit_and_reindent to move the rest of the block into the guard.
      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                   (node.body, alias_map))
    return node


def transform(node, ctx):
  """Applies SideEffectGuardTransformer to `node`.

  Args:
    node: The AST node to convert.
    ctx: The context object handed to the transformer's constructor.
  Returns:
    The result of visiting `node` with the transformer.
  """
  transformer = SideEffectGuardTransformer(ctx)
  return transformer.visit(node)