aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-12 06:22:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 06:24:54 -0700
commitcf542ae4174d954ad21ab255bc0fdb81326e4443 (patch)
tree4c4ddeb8ad1a096eaed3f2f9cb3e20df47fb6ed7
parente688642372893d9e51be4119342f787560d8e644 (diff)
Special-case the name scoping for operator methods. TensorFlow disallows top-level name scopes to begin with underscores. Also use the transformer scope information to get to the enclosing function name.
PiperOrigin-RevId: 192600256
-rw-r--r--tensorflow/contrib/autograph/converters/name_scopes.py38
-rw-r--r--tensorflow/contrib/autograph/converters/name_scopes_test.py55
2 files changed, 65 insertions, 28 deletions
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py
index 2a3f474360..280bc4c314 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes.py
+++ b/tensorflow/contrib/autograph/converters/name_scopes.py
@@ -28,22 +28,34 @@ from tensorflow.contrib.autograph.pyct import transformer
class FunctionNameScopeTransformer(transformer.Base):
"""Wrap a function body with a `name_scope` of the function name."""
- def __init__(self, context):
- super(FunctionNameScopeTransformer, self).__init__(context)
- self._function_level = 0
+ def _name_for_current_scope(self):
+ innermost = self.enclosing_entities[-1]
+ if len(self.enclosing_entities) > 1:
+ parent = self.enclosing_entities[-2]
+ if isinstance(parent, gast.ClassDef):
+ # Methods also take the name of their class.
+ name = '%s/%s' % (parent.name, innermost.name)
+ else:
+ name = innermost.name
+ else:
+ name = innermost.name
+
+ # Sanitize the name.
+ # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope
+ # TensorFlow doesn't like leading underscores at the top level.
+ while name[0] == '_':
+ name = name[1:]
+ return name
def visit_FunctionDef(self, node):
- self._function_level += 1
- try:
- self.generic_visit(node)
- finally:
- self._function_level -= 1
- scope_name = node.name
- if self._function_level == 0 and self.context.owner_type is not None:
- scope_name = '{}/{}'.format(self.context.owner_type.__name__, scope_name)
+ self.generic_visit(node)
+ template = """
+ with tf.name_scope(scope_name):
+ body
+ """
node.body = templates.replace(
- 'with tf.name_scope(scope_name): body',
- scope_name=gast.Str(scope_name),
+ template,
+ scope_name=gast.Str(self._name_for_current_scope()),
body=node.body)
return node
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py
index 61e5db2af8..2c2b6bbbec 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py
@@ -38,29 +38,29 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase):
node = name_scopes.transform(node, self.ctx)
with self.compiled(node, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant([1, 2, 3]))
+ result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
def test_nested_name(self):
def test_fn(l):
- def body(i):
- return i**2
+ def inner_fn(i):
+ return i ** 2
- l += [4]
- return body(l)
+ l += 4
+ return inner_fn(l)
node = self.parse_and_analyze(test_fn, {})
node = name_scopes.transform(node, self.ctx)
with self.compiled(node, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant([1, 2, 3]))
+ result_op = result.test_fn(constant_op.constant(1))
first_result_input_name = result_op.op.inputs[0].name
second_result_input_name = result_op.op.inputs[1].name
self.assertIn('test_fn/', first_result_input_name)
- self.assertNotIn('body/', first_result_input_name)
- self.assertIn('test_fn/body/', second_result_input_name)
+ self.assertNotIn('inner_fn', first_result_input_name)
+ self.assertIn('test_fn/inner_fn/', second_result_input_name)
def test_class_name(self):
@@ -68,11 +68,11 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase):
def test_fn(self, l):
- def body(i):
- return i**2
+ def inner_fn(i):
+ return i ** 2
- l += [4]
- return body(l)
+ l += 4
+ return inner_fn(l)
# Note that 'TestClass' was needed in the namespace here.
node = self.parse_and_analyze(
@@ -80,12 +80,37 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase):
node = name_scopes.transform(node, self.ctx)
with self.compiled(node, ops.name_scope) as result:
- result_op = result.TestClass().test_fn(constant_op.constant([1, 2, 3]))
+ result_op = result.TestClass().test_fn(constant_op.constant(1))
first_result_input_name = result_op.op.inputs[0].name
second_result_input_name = result_op.op.inputs[1].name
self.assertIn('TestClass/test_fn/', first_result_input_name)
- self.assertNotIn('body/', first_result_input_name)
- self.assertIn('TestClass/test_fn/body/', second_result_input_name)
+ self.assertNotIn('inner_fn', first_result_input_name)
+ self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name)
+
+ def test_special_name(self):
+
+ class TestClass(object):
+
+ def __call__(self, l):
+
+ def inner_fn(i):
+ return i ** 2
+
+ l += 4
+ return inner_fn(l)
+
+ # Note that 'TestClass' was needed in the namespace here.
+ node = self.parse_and_analyze(
+ TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass)
+ node = name_scopes.transform(node, self.ctx)
+
+ with self.compiled(node, ops.name_scope) as result:
+ result_op = result.__call__(TestClass(), constant_op.constant(1))
+ first_result_input_name = result_op.op.inputs[0].name
+ second_result_input_name = result_op.op.inputs[1].name
+ self.assertIn('call__/', first_result_input_name)
+ self.assertNotIn('inner_fn', first_result_input_name)
+ self.assertIn('call__/inner_fn/', second_result_input_name)
if __name__ == '__main__':