diff options
author | 2018-04-12 06:22:30 -0700 | |
---|---|---|
committer | 2018-04-12 06:24:54 -0700 | |
commit | cf542ae4174d954ad21ab255bc0fdb81326e4443 (patch) | |
tree | 4c4ddeb8ad1a096eaed3f2f9cb3e20df47fb6ed7 | |
parent | e688642372893d9e51be4119342f787560d8e644 (diff) |
Special-case the name scoping for operator methods. TensorFlow disallows top-level name scopes to begin with underscores. Also use the transformer scope information to get to the enclosing function name.
PiperOrigin-RevId: 192600256
-rw-r--r-- | tensorflow/contrib/autograph/converters/name_scopes.py | 38 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/converters/name_scopes_test.py | 55 |
2 files changed, 65 insertions, 28 deletions
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py index 2a3f474360..280bc4c314 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -28,22 +28,34 @@ from tensorflow.contrib.autograph.pyct import transformer class FunctionNameScopeTransformer(transformer.Base): """Wrap a function body with a `name_scope` of the function name.""" - def __init__(self, context): - super(FunctionNameScopeTransformer, self).__init__(context) - self._function_level = 0 + def _name_for_current_scope(self): + innermost = self.enclosing_entities[-1] + if len(self.enclosing_entities) > 1: + parent = self.enclosing_entities[-2] + if isinstance(parent, gast.ClassDef): + # Methods also take the name of their class. + name = '%s/%s' % (parent.name, innermost.name) + else: + name = innermost.name + else: + name = innermost.name + + # Sanitize the name. + # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope + # TensorFlow doesn't like leading underscores at the top level. + while name[0] == '_': + name = name[1:] + return name def visit_FunctionDef(self, node): - self._function_level += 1 - try: - self.generic_visit(node) - finally: - self._function_level -= 1 - scope_name = node.name - if self._function_level == 0 and self.context.owner_type is not None: - scope_name = '{}/{}'.format(self.context.owner_type.__name__, scope_name) + self.generic_visit(node) + template = """ + with tf.name_scope(scope_name): + body + """ node.body = templates.replace( - 'with tf.name_scope(scope_name): body', - scope_name=gast.Str(scope_name), + template, + scope_name=gast.Str(self._name_for_current_scope()), body=node.body) return node diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 61e5db2af8..2c2b6bbbec 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -38,29 +38,29 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): node = name_scopes.transform(node, self.ctx) with self.compiled(node, ops.name_scope) as result: - result_op = result.test_fn(constant_op.constant([1, 2, 3])) + result_op = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', result_op.op.name) def test_nested_name(self): def test_fn(l): - def body(i): - return i**2 + def inner_fn(i): + return i ** 2 - l += [4] - return body(l) + l += 4 + return inner_fn(l) node = self.parse_and_analyze(test_fn, {}) node = name_scopes.transform(node, self.ctx) with self.compiled(node, ops.name_scope) as result: - result_op = result.test_fn(constant_op.constant([1, 2, 3])) + result_op = result.test_fn(constant_op.constant(1)) first_result_input_name = result_op.op.inputs[0].name second_result_input_name = result_op.op.inputs[1].name self.assertIn('test_fn/', first_result_input_name) - self.assertNotIn('body/', first_result_input_name) - self.assertIn('test_fn/body/', second_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('test_fn/inner_fn/', second_result_input_name) def test_class_name(self): @@ -68,11 +68,11 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): def test_fn(self, l): - def body(i): - return i**2 + def inner_fn(i): + return i ** 2 - l += [4] - return body(l) + l += 4 + return inner_fn(l) # Note that 'TestClass' was needed in the namespace here. node = self.parse_and_analyze( @@ -80,12 +80,37 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): node = name_scopes.transform(node, self.ctx) with self.compiled(node, ops.name_scope) as result: - result_op = result.TestClass().test_fn(constant_op.constant([1, 2, 3])) + result_op = result.TestClass().test_fn(constant_op.constant(1)) first_result_input_name = result_op.op.inputs[0].name second_result_input_name = result_op.op.inputs[1].name self.assertIn('TestClass/test_fn/', first_result_input_name) - self.assertNotIn('body/', first_result_input_name) - self.assertIn('TestClass/test_fn/body/', second_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name) + + def test_special_name(self): + + class TestClass(object): + + def __call__(self, l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.__call__(TestClass(), constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('call__/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('call__/inner_fn/', second_result_input_name) if __name__ == '__main__': |