diff options
author | Dan Moldovan <mdan@google.com> | 2018-10-01 10:36:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 10:43:59 -0700 |
commit | a6478312ef296ba9684931135851e9c7bb460444 (patch) | |
tree | f6f99111a57121f9baa3369640773af06c93d8d2 /tensorflow/python/autograph | |
parent | a5fc8b064884b926ade9f7973dc096c0677a14e0 (diff) |
Replace the tf.name_scope call with an internal context manager that can contain additional boilerplate later on. Unfortunately it could not be extended to include the error handling.
PiperOrigin-RevId: 215238369
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/converters/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/python/autograph/converters/function_scopes.py (renamed from tensorflow/python/autograph/converters/name_scopes.py) | 32 | ||||
-rw-r--r-- | tensorflow/python/autograph/converters/function_scopes_test.py (renamed from tensorflow/python/autograph/converters/name_scopes_test.py) | 40 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/converter_testing.py | 2 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/function_wrapping.py | 30 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/function_wrapping_test.py | 34 | ||||
-rw-r--r-- | tensorflow/python/autograph/impl/conversion.py | 6 |
8 files changed, 122 insertions, 40 deletions
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 7b029de8ed..f06dc78f0e 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -27,10 +27,10 @@ py_library( "decorators.py", "directives.py", "error_handlers.py", + "function_scopes.py", "list_comprehensions.py", "lists.py", "logical_expressions.py", - "name_scopes.py", "return_statements.py", "side_effect_guards.py", "slices.py", @@ -157,8 +157,8 @@ py_test( ) py_test( - name = "name_scopes_test", - srcs = ["name_scopes_test.py"], + name = "function_scopes_test", + srcs = ["function_scopes_test.py"], deps = [ ":converters", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py index a9c55ccff0..284b5b3519 100644 --- a/tensorflow/python/autograph/converters/name_scopes.py +++ b/tensorflow/python/autograph/converters/function_scopes.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wraps a function body with a `name_scope` of the function name.""" +"""Wraps the body of a converted function with auxiliary constructs.""" from __future__ import absolute_import from __future__ import division @@ -24,8 +24,8 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.pyct import templates -class FunctionNameScopeTransformer(converter.Base): - """Wrap a function body with a `name_scope` of the function name.""" +class FunctionBodyTransformer(converter.Base): + """Wraps function bodies around autograph-specific boilerplate.""" def _name_for_current_scope(self): innermost = self.enclosing_entities[-1] @@ -49,26 +49,28 @@ class FunctionNameScopeTransformer(converter.Base): def visit_FunctionDef(self, node): node = self.generic_visit(node) - unscoped_body = [] - scoped_body = node.body - if scoped_body: - first = scoped_body[0] - if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str): - # Skip any docstring. - unscoped_body = scoped_body[:1] - scoped_body = scoped_body[1:] + final_body = [] + indented_body = node.body + if node.body: + first_statement = node.body[0] + # Skip the docstring, if any. + if (isinstance(first_statement, gast.Expr) and + isinstance(first_statement.value, gast.Str)): + indented_body = indented_body[1:] + final_body.append(first_statement) template = """ - with tf.name_scope(scope_name): + with ag__.function_scope(scope_name): body """ scoped_body = templates.replace( template, scope_name=gast.Str(self._name_for_current_scope()), - body=scoped_body) - node.body = unscoped_body + scoped_body + body=indented_body) + final_body.extend(scoped_body) + node.body = final_body return node def transform(node, ctx): - return FunctionNameScopeTransformer(ctx).visit(node) + return FunctionBodyTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py index 73933c1c4f..e5ce03a109 100644 --- a/tensorflow/python/autograph/converters/name_scopes_test.py +++ b/tensorflow/python/autograph/converters/function_scopes_test.py @@ -12,51 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for for_canonicalization module.""" +"""Tests for function_scopes module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.autograph.converters import name_scopes +from tensorflow.python.autograph.converters import function_scopes from tensorflow.python.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test -class FunctionNameScopeTransformer(converter_testing.TestCase): +class FunctionBodyTransformerTest(converter_testing.TestCase): def test_basic(self): def test_fn(l): - """This should stay here.""" + """Docstring.""" a = 1 l += a return l - with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + with self.converted(test_fn, function_scopes, {}) as result: result_op = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', result_op.op.name) - self.assertEqual('This should stay here.', result.test_fn.__doc__) + self.assertEqual('Docstring.', result.test_fn.__doc__) - def test_long_docstring(self): + def test_multiline_docstring(self): - def test_fn(l): - """Multi-line docstring. + tf = None + + def test_fn(): + """First sentence. - Args: - l: A thing. - Returns: - l + Second sentence. """ - return l + 1 + return tf.constant(1) - with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: - result_op = result.test_fn(constant_op.constant(1)) + with self.converted(test_fn, function_scopes, {}, + constant_op.constant) as result: + result_op = result.test_fn() self.assertIn('test_fn/', result_op.op.name) - self.assertIn('Multi-line docstring.', result.test_fn.__doc__) - self.assertIn('Returns:', result.test_fn.__doc__) + self.assertIn('First sentence.', result.test_fn.__doc__) + self.assertIn('Second sentence.', result.test_fn.__doc__) def test_nested_functions(self): @@ -68,7 +68,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): l += 1 return l, inner_fn(l) - with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result: first, second = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', first.op.name) self.assertNotIn('inner_fn', first.op.name) @@ -88,7 +88,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): ns = {'TestClass': TestClass} node, ctx = self.prepare(TestClass, ns, owner_type=TestClass) - node = name_scopes.transform(node, ctx) + node = function_scopes.transform(node, ctx) with self.compiled(node, {}, ops.name_scope) as result: first, second = result.TestClass().test_fn(constant_op.constant(1)) diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index 85fecf084d..843e381f31 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -20,11 +20,13 @@ py_library( "config.py", "converter.py", "errors.py", + "function_wrapping.py", "naming.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/python:framework_ops", "//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct/static_analysis", "//tensorflow/python/autograph/utils", @@ -47,6 +49,16 @@ py_test( ) py_test( + name = "function_wrapping_test", + srcs = ["function_wrapping_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "naming_test", srcs = ["naming_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 7ce1b7c4c5..dc2d419d34 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -29,6 +29,7 @@ from tensorflow.python.autograph import utils from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import errors +from tensorflow.python.autograph.core import function_wrapping from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer @@ -112,6 +113,7 @@ class TestCase(test.TestCase): fake_ag.__dict__['utils'] = utils fake_ag.__dict__['rewrite_graph_construction_error'] = ( errors.rewrite_graph_construction_error) + fake_ag.__dict__['function_scope'] = function_wrapping.function_scope result.__dict__['ag__'] = fake_ag for k, v in namespace.items(): result.__dict__[k] = v diff --git a/tensorflow/python/autograph/core/function_wrapping.py b/tensorflow/python/autograph/core/function_wrapping.py new file mode 100644 index 0000000000..21b66eff02 --- /dev/null +++ b/tensorflow/python/autograph/core/function_wrapping.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for wrapping converted functions bodies with auxiliary logic.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.framework import ops + + +@contextlib.contextmanager +def function_scope(function_name): + """Returns a context manager for the converted body of a function.""" + with ops.name_scope(function_name): + yield diff --git a/tensorflow/python/autograph/core/function_wrapping_test.py b/tensorflow/python/autograph/core/function_wrapping_test.py new file mode 100644 index 0000000000..5e217055c7 --- /dev/null +++ b/tensorflow/python/autograph/core/function_wrapping_test.py @@ -0,0 +1,34 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for function_wrapping module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.core import function_wrapping +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class FunctionWrappingTest(test.TestCase): + + def test_function_scope_name(self): + with function_wrapping.function_scope('test_name'): + t = constant_op.constant(1) + self.assertIn('test_name', t.name) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index a0d13c82a8..52abd40626 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -34,15 +34,16 @@ from tensorflow.python.autograph.converters import control_flow from tensorflow.python.autograph.converters import decorators from tensorflow.python.autograph.converters import directives from tensorflow.python.autograph.converters import error_handlers +from tensorflow.python.autograph.converters import function_scopes from tensorflow.python.autograph.converters import lists from tensorflow.python.autograph.converters import logical_expressions -from tensorflow.python.autograph.converters import name_scopes from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.converters import side_effect_guards from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import errors +from tensorflow.python.autograph.core import function_wrapping from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import origin_info @@ -257,6 +258,7 @@ def _add_self_references(namespace, autograph_module): ag_internal.converted_call = autograph_module.converted_call ag_internal.ConversionOptions = autograph_module.ConversionOptions ag_internal.utils = utils + ag_internal.function_scope = function_wrapping.function_scope ag_internal.rewrite_graph_construction_error = ( errors.rewrite_graph_construction_error) # TODO(mdan): Add safeguards against name clashes. @@ -346,7 +348,7 @@ def node_to_graph(node, context, rewrite_errors=True): node = converter.apply_(node, context, conditional_expressions) node = converter.apply_(node, context, logical_expressions) node = converter.apply_(node, context, side_effect_guards) - node = converter.apply_(node, context, name_scopes) + node = converter.apply_(node, context, function_scopes) if rewrite_errors: node = converter.apply_(node, context, error_handlers) return node |