diff options
Diffstat (limited to 'tensorflow/python/autograph/converters/side_effect_guards.py')
-rw-r--r-- | tensorflow/python/autograph/converters/side_effect_guards.py | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py new file mode 100644 index 0000000000..6e48e57bde --- /dev/null +++ b/tensorflow/python/autograph/converters/side_effect_guards.py @@ -0,0 +1,183 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adds guards against function calls with side effects. + +Only standalone calls are guarded. + +WARNING: This mechanism is incomplete. Particularly, it only guards the +arguments passed to functions, and does not account for indirectly modified +state. + +Example: + y = tf.layers.dense(x) # Creates TF variable 'foo' + loss = loss(y) + opt.minimize(loss) # indirectly affects 'foo' + z = tf.get_variable('foo') # Indirectly affects `loss` and 'foo' + # Here, `loss` can be guarded. But `z` cannot. + +# TODO(mdan): We should probably define a safe mode where we guard everything. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import ast_util +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import templates +from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno + + +class SymbolNamer(object): + """Describes the interface for SideEffectGuardTransformer's namer.""" + + def new_symbol(self, name_root, reserved_locals): + """Generate a new unique function_name. + + Args: + name_root: String, used as stem in the new name. + reserved_locals: Set(string), additional local symbols that are reserved. + Returns: + String. + """ + raise NotImplementedError() + + +class SideEffectGuardTransformer(converter.Base): + """Adds control dependencies to functions with side effects.""" + + def _visit_and_reindent(self, nodes): + new_nodes = [] + current_dest = new_nodes + alias_map = {} + reindent_requested = False + for n in nodes: + n = self.visit(n) + # NOTE: the order in which these statements execute is important; in + # particular, watch out for ending up with cycles in the AST. + if alias_map: + n = ast_util.rename_symbols(n, alias_map) + if isinstance(n, (list, tuple)): + current_dest.extend(n) + else: + current_dest.append(n) + if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER): + reindent_requested = True + new_dest, new_alias_map = anno.getanno( + n, anno.Basic.INDENT_BLOCK_REMAINDER) + anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER) + new_alias_map.update(alias_map) + alias_map = new_alias_map + current_dest = new_dest + if reindent_requested and not current_dest: + # TODO(mdan): There may still be something that could be done. + raise ValueError('Unable to insert statement into the computation flow: ' + 'it is not followed by any computation which ' + 'the statement could gate.') + return new_nodes + + def visit_FunctionDef(self, node): + node.body = self._visit_and_reindent(node.body) + return node + + def visit_With(self, node): + node.body = self._visit_and_reindent(node.body) + return node + + def visit_If(self, node): + node.body = self._visit_and_reindent(node.body) + node.orelse = self._visit_and_reindent(node.orelse) + return node + + def visit_While(self, node): + node.body = self._visit_and_reindent(node.body) + node.orelse = self._visit_and_reindent(node.orelse) + return node + + def visit_Expr(self, node): + self.generic_visit(node) + if isinstance(node.value, gast.Call): + # Patterns of single function calls, like: + # opt.minimize(loss) + # or: + # tf.py_func(...) + + # First, attempt to gate future evaluation of args. If that's not + # possible, gate all remaining statements (and that may fail too, see + # _visit_and_reindent. + args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE) + # NOTE: We can't guard object attributes because they may not be writable. + # In addition, avoid renaming well-known names. + # TODO(mdan): Move these names into config. + unguarded_names = (qual_names.QN('self'), qual_names.QN('tf')) + guarded_args = tuple(s for s in args_scope.used + if not s.is_composite() and s not in unguarded_names) + + # TODO(mdan): Include all arguments which depended on guarded_args too. + # For example, the following will still cause a race: + # tf.assign(a, a + 1) + # b = a + 1 + # tf.assign(a, a + 1) # Control deps here should include `b` + # c = b + 1 + # Or maybe we should just raise an "unsafe assign" error? + + if guarded_args: + # The aliases may need new names to avoid incorrectly making them local. + # TODO(mdan): This is brutal. It will even rename modules - any fix? + need_alias = tuple( + s for s in guarded_args if s not in args_scope.parent.modified) + aliased_new_names = tuple( + qual_names.QN( + self.ctx.namer.new_symbol( + s.ssf(), args_scope.parent.referenced)) for s in need_alias) + alias_map = dict(zip(need_alias, aliased_new_names)) + if len(guarded_args) == 1: + s, = guarded_args + aliased_guarded_args = alias_map.get(s, s) + else: + aliased_guarded_args = gast.Tuple( + [alias_map.get(s, s).ast() for s in guarded_args], None) + + template = """ + with ag__.utils.control_dependency_on_returns(call): + aliased_guarded_args = ag__.utils.alias_tensors(guarded_args) + """ + control_deps_guard = templates.replace( + template, + call=node.value, + aliased_guarded_args=aliased_guarded_args, + guarded_args=guarded_args)[-1] + else: + alias_map = {} + + template = """ + with ag__.utils.control_dependency_on_returns(call): + pass + """ + control_deps_guard = templates.replace(template, call=node.value)[-1] + control_deps_guard.body = [] + + node = control_deps_guard + anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER, + (node.body, alias_map)) + return node + + +def transform(node, ctx): + return SideEffectGuardTransformer(ctx).visit(node) |