aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-01 10:36:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 10:43:59 -0700
commita6478312ef296ba9684931135851e9c7bb460444 (patch)
treef6f99111a57121f9baa3369640773af06c93d8d2 /tensorflow/python/autograph
parenta5fc8b064884b926ade9f7973dc096c0677a14e0 (diff)
Replace the tf.name_scope call with an internal context manager that can contain additional boilerplate later on. Unfortunately it could not be extended to include the error handling.
PiperOrigin-RevId: 215238369
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/converters/BUILD6
-rw-r--r--tensorflow/python/autograph/converters/function_scopes.py (renamed from tensorflow/python/autograph/converters/name_scopes.py)32
-rw-r--r--tensorflow/python/autograph/converters/function_scopes_test.py (renamed from tensorflow/python/autograph/converters/name_scopes_test.py)40
-rw-r--r--tensorflow/python/autograph/core/BUILD12
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py2
-rw-r--r--tensorflow/python/autograph/core/function_wrapping.py30
-rw-r--r--tensorflow/python/autograph/core/function_wrapping_test.py34
-rw-r--r--tensorflow/python/autograph/impl/conversion.py6
8 files changed, 122 insertions, 40 deletions
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 7b029de8ed..f06dc78f0e 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -27,10 +27,10 @@ py_library(
"decorators.py",
"directives.py",
"error_handlers.py",
+ "function_scopes.py",
"list_comprehensions.py",
"lists.py",
"logical_expressions.py",
- "name_scopes.py",
"return_statements.py",
"side_effect_guards.py",
"slices.py",
@@ -157,8 +157,8 @@ py_test(
)
py_test(
- name = "name_scopes_test",
- srcs = ["name_scopes_test.py"],
+ name = "function_scopes_test",
+ srcs = ["function_scopes_test.py"],
deps = [
":converters",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py
index a9c55ccff0..284b5b3519 100644
--- a/tensorflow/python/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/function_scopes.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Wraps a function body with a `name_scope` of the function name."""
+"""Wraps the body of a converted function with auxiliary constructs."""
from __future__ import absolute_import
from __future__ import division
@@ -24,8 +24,8 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import templates
-class FunctionNameScopeTransformer(converter.Base):
- """Wrap a function body with a `name_scope` of the function name."""
+class FunctionBodyTransformer(converter.Base):
+ """Wraps function bodies around autograph-specific boilerplate."""
def _name_for_current_scope(self):
innermost = self.enclosing_entities[-1]
@@ -49,26 +49,28 @@ class FunctionNameScopeTransformer(converter.Base):
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- unscoped_body = []
- scoped_body = node.body
- if scoped_body:
- first = scoped_body[0]
- if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str):
- # Skip any docstring.
- unscoped_body = scoped_body[:1]
- scoped_body = scoped_body[1:]
+ final_body = []
+ indented_body = node.body
+ if node.body:
+ first_statement = node.body[0]
+ # Skip the docstring, if any.
+ if (isinstance(first_statement, gast.Expr) and
+ isinstance(first_statement.value, gast.Str)):
+ indented_body = indented_body[1:]
+ final_body.append(first_statement)
template = """
- with tf.name_scope(scope_name):
+ with ag__.function_scope(scope_name):
body
"""
scoped_body = templates.replace(
template,
scope_name=gast.Str(self._name_for_current_scope()),
- body=scoped_body)
- node.body = unscoped_body + scoped_body
+ body=indented_body)
+ final_body.extend(scoped_body)
+ node.body = final_body
return node
def transform(node, ctx):
- return FunctionNameScopeTransformer(ctx).visit(node)
+ return FunctionBodyTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py
index 73933c1c4f..e5ce03a109 100644
--- a/tensorflow/python/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/function_scopes_test.py
@@ -12,51 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for for_canonicalization module."""
+"""Tests for function_scopes module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class FunctionNameScopeTransformer(converter_testing.TestCase):
+class FunctionBodyTransformerTest(converter_testing.TestCase):
def test_basic(self):
def test_fn(l):
- """This should stay here."""
+ """Docstring."""
a = 1
l += a
return l
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
- self.assertEqual('This should stay here.', result.test_fn.__doc__)
+ self.assertEqual('Docstring.', result.test_fn.__doc__)
- def test_long_docstring(self):
+ def test_multiline_docstring(self):
- def test_fn(l):
- """Multi-line docstring.
+ tf = None
+
+ def test_fn():
+ """First sentence.
- Args:
- l: A thing.
- Returns:
- l
+ Second sentence.
"""
- return l + 1
+ return tf.constant(1)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant(1))
+ with self.converted(test_fn, function_scopes, {},
+ constant_op.constant) as result:
+ result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
- self.assertIn('Multi-line docstring.', result.test_fn.__doc__)
- self.assertIn('Returns:', result.test_fn.__doc__)
+ self.assertIn('First sentence.', result.test_fn.__doc__)
+ self.assertIn('Second sentence.', result.test_fn.__doc__)
def test_nested_functions(self):
@@ -68,7 +68,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
@@ -88,7 +88,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns, owner_type=TestClass)
- node = name_scopes.transform(node, ctx)
+ node = function_scopes.transform(node, ctx)
with self.compiled(node, {}, ops.name_scope) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 85fecf084d..843e381f31 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -20,11 +20,13 @@ py_library(
"config.py",
"converter.py",
"errors.py",
+ "function_wrapping.py",
"naming.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils",
@@ -47,6 +49,16 @@ py_test(
)
py_test(
+ name = "function_wrapping_test",
+ srcs = ["function_wrapping_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "naming_test",
srcs = ["naming_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 7ce1b7c4c5..dc2d419d34 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -29,6 +29,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
@@ -112,6 +113,7 @@ class TestCase(test.TestCase):
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
errors.rewrite_graph_construction_error)
+ fake_ag.__dict__['function_scope'] = function_wrapping.function_scope
result.__dict__['ag__'] = fake_ag
for k, v in namespace.items():
result.__dict__[k] = v
diff --git a/tensorflow/python/autograph/core/function_wrapping.py b/tensorflow/python/autograph/core/function_wrapping.py
new file mode 100644
index 0000000000..21b66eff02
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for wrapping converted functions bodies with auxiliary logic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.framework import ops
+
+
+@contextlib.contextmanager
+def function_scope(function_name):
+ """Returns a context manager for the converted body of a function."""
+ with ops.name_scope(function_name):
+ yield
diff --git a/tensorflow/python/autograph/core/function_wrapping_test.py b/tensorflow/python/autograph/core/function_wrapping_test.py
new file mode 100644
index 0000000000..5e217055c7
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping_test.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for function_wrapping module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import function_wrapping
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
+class FunctionWrappingTest(test.TestCase):
+
+ def test_function_scope_name(self):
+ with function_wrapping.function_scope('test_name'):
+ t = constant_op.constant(1)
+ self.assertIn('test_name', t.name)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index a0d13c82a8..52abd40626 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -34,15 +34,16 @@ from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import decorators
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
-from tensorflow.python.autograph.converters import name_scopes
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import side_effect_guards
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
@@ -257,6 +258,7 @@ def _add_self_references(namespace, autograph_module):
ag_internal.converted_call = autograph_module.converted_call
ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
+ ag_internal.function_scope = function_wrapping.function_scope
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)
# TODO(mdan): Add safeguards against name clashes.
@@ -346,7 +348,7 @@ def node_to_graph(node, context, rewrite_errors=True):
node = converter.apply_(node, context, conditional_expressions)
node = converter.apply_(node, context, logical_expressions)
node = converter.apply_(node, context, side_effect_guards)
- node = converter.apply_(node, context, name_scopes)
+ node = converter.apply_(node, context, function_scopes)
if rewrite_errors:
node = converter.apply_(node, context, error_handlers)
return node