aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/converters/side_effect_guards.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/converters/side_effect_guards.py')
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards.py183
1 files changed, 183 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
new file mode 100644
index 0000000000..6e48e57bde
--- /dev/null
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adds guards against function calls with side effects.
+
+Only standalone calls are guarded.
+
+WARNING: This mechanism is incomplete. Particularly, it only guards the
+arguments passed to functions, and does not account for indirectly modified
+state.
+
+Example:
+ y = tf.layers.dense(x) # Creates TF variable 'foo'
+ loss = loss(y)
+ opt.minimize(loss) # indirectly affects 'foo'
+ z = tf.get_variable('foo') # Indirectly affects `loss` and 'foo'
+ # Here, `loss` can be guarded. But `z` cannot.
+
+# TODO(mdan): We should probably define a safe mode where we guard everything.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+class SymbolNamer(object):
+ """Describes the interface for SideEffectGuardTransformer's namer."""
+
+ def new_symbol(self, name_root, reserved_locals):
+ """Generate a new unique function_name.
+
+ Args:
+ name_root: String, used as stem in the new name.
+ reserved_locals: Set(string), additional local symbols that are reserved.
+ Returns:
+ String.
+ """
+ raise NotImplementedError()
+
+
+class SideEffectGuardTransformer(converter.Base):
+ """Adds control dependencies to functions with side effects."""
+
+ def _visit_and_reindent(self, nodes):
+ new_nodes = []
+ current_dest = new_nodes
+ alias_map = {}
+ reindent_requested = False
+ for n in nodes:
+ n = self.visit(n)
+ # NOTE: the order in which these statements execute is important; in
+ # particular, watch out for ending up with cycles in the AST.
+ if alias_map:
+ n = ast_util.rename_symbols(n, alias_map)
+ if isinstance(n, (list, tuple)):
+ current_dest.extend(n)
+ else:
+ current_dest.append(n)
+ if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
+ reindent_requested = True
+ new_dest, new_alias_map = anno.getanno(
+ n, anno.Basic.INDENT_BLOCK_REMAINDER)
+ anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
+ new_alias_map.update(alias_map)
+ alias_map = new_alias_map
+ current_dest = new_dest
+ if reindent_requested and not current_dest:
+ # TODO(mdan): There may still be something that could be done.
+ raise ValueError('Unable to insert statement into the computation flow: '
+ 'it is not followed by any computation which '
+ 'the statement could gate.')
+ return new_nodes
+
+ def visit_FunctionDef(self, node):
+ node.body = self._visit_and_reindent(node.body)
+ return node
+
+ def visit_With(self, node):
+ node.body = self._visit_and_reindent(node.body)
+ return node
+
+ def visit_If(self, node):
+ node.body = self._visit_and_reindent(node.body)
+ node.orelse = self._visit_and_reindent(node.orelse)
+ return node
+
+ def visit_While(self, node):
+ node.body = self._visit_and_reindent(node.body)
+ node.orelse = self._visit_and_reindent(node.orelse)
+ return node
+
+ def visit_Expr(self, node):
+ self.generic_visit(node)
+ if isinstance(node.value, gast.Call):
+ # Patterns of single function calls, like:
+ # opt.minimize(loss)
+ # or:
+ # tf.py_func(...)
+
+ # First, attempt to gate future evaluation of args. If that's not
+ # possible, gate all remaining statements (and that may fail too, see
+ # _visit_and_reindent.
+ args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
+ # NOTE: We can't guard object attributes because they may not be writable.
+ # In addition, avoid renaming well-known names.
+ # TODO(mdan): Move these names into config.
+ unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
+ guarded_args = tuple(s for s in args_scope.used
+ if not s.is_composite() and s not in unguarded_names)
+
+ # TODO(mdan): Include all arguments which depended on guarded_args too.
+ # For example, the following will still cause a race:
+ # tf.assign(a, a + 1)
+ # b = a + 1
+ # tf.assign(a, a + 1) # Control deps here should include `b`
+ # c = b + 1
+ # Or maybe we should just raise an "unsafe assign" error?
+
+ if guarded_args:
+ # The aliases may need new names to avoid incorrectly making them local.
+ # TODO(mdan): This is brutal. It will even rename modules - any fix?
+ need_alias = tuple(
+ s for s in guarded_args if s not in args_scope.parent.modified)
+ aliased_new_names = tuple(
+ qual_names.QN(
+ self.ctx.namer.new_symbol(
+ s.ssf(), args_scope.parent.referenced)) for s in need_alias)
+ alias_map = dict(zip(need_alias, aliased_new_names))
+ if len(guarded_args) == 1:
+ s, = guarded_args
+ aliased_guarded_args = alias_map.get(s, s)
+ else:
+ aliased_guarded_args = gast.Tuple(
+ [alias_map.get(s, s).ast() for s in guarded_args], None)
+
+ template = """
+ with ag__.utils.control_dependency_on_returns(call):
+ aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
+ """
+ control_deps_guard = templates.replace(
+ template,
+ call=node.value,
+ aliased_guarded_args=aliased_guarded_args,
+ guarded_args=guarded_args)[-1]
+ else:
+ alias_map = {}
+
+ template = """
+ with ag__.utils.control_dependency_on_returns(call):
+ pass
+ """
+ control_deps_guard = templates.replace(template, call=node.value)[-1]
+ control_deps_guard.body = []
+
+ node = control_deps_guard
+ anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
+ (node.body, alias_map))
+ return node
+
+
+def transform(node, ctx):
+ return SideEffectGuardTransformer(ctx).visit(node)