aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/name_scopes_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/name_scopes_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/name_scopes_test.py90
1 files changed, 26 insertions, 64 deletions
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py
index 444d0bcd46..a329b0db70 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py
@@ -31,17 +31,13 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
def test_fn(l):
"""This should stay here."""
- a = 5
+ a = 1
l += a
return l
- node = self.parse_and_analyze(test_fn, {})
- node = name_scopes.transform(node, self.ctx)
-
- with self.compiled(node, ops.name_scope) as result:
+ with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
-
self.assertEqual('This should stay here.', result.test_fn.__doc__)
def test_long_docstring(self):
@@ -54,13 +50,12 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
Returns:
l
"""
- return l
-
- node = self.parse_and_analyze(test_fn, {})
- node = name_scopes.transform(node, self.ctx)
+ return l + 1
- with self.compiled(node, ops.name_scope) as result:
- self.assertIn('Multi-line', result.test_fn.__doc__)
+ with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ result_op = result.test_fn(constant_op.constant(1))
+ self.assertIn('test_fn/', result_op.op.name)
+ self.assertIn('Multi-line docstring.', result.test_fn.__doc__)
self.assertIn('Returns:', result.test_fn.__doc__)
def test_nested_functions(self):
@@ -68,21 +63,16 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
def test_fn(l):
def inner_fn(i):
- return i ** 2
-
- l += 4
- return inner_fn(l)
+ return i + 1
- node = self.parse_and_analyze(test_fn, {})
- node = name_scopes.transform(node, self.ctx)
+ l += 1
+ return l, inner_fn(l)
- with self.compiled(node, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant(1))
- first_result_input_name = result_op.op.inputs[0].name
- second_result_input_name = result_op.op.inputs[1].name
- self.assertIn('test_fn/', first_result_input_name)
- self.assertNotIn('inner_fn', first_result_input_name)
- self.assertIn('test_fn/inner_fn/', second_result_input_name)
+ with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ first, second = result.test_fn(constant_op.constant(1))
+ self.assertIn('test_fn/', first.op.name)
+ self.assertNotIn('inner_fn', first.op.name)
+ self.assertIn('test_fn/inner_fn/', second.op.name)
def test_method(self):
@@ -91,48 +81,20 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
def test_fn(self, l):
def inner_fn(i):
- return i ** 2
-
- l += 4
- return inner_fn(l)
+ return i + 1
- # Note that 'TestClass' was needed in the namespace here.
- node = self.parse_and_analyze(
- TestClass, {'TestClass': TestClass}, owner_type=TestClass)
- node = name_scopes.transform(node, self.ctx)
+ l += 1
+ return l, inner_fn(l)
- with self.compiled(node, ops.name_scope) as result:
- result_op = result.TestClass().test_fn(constant_op.constant(1))
- first_result_input_name = result_op.op.inputs[0].name
- second_result_input_name = result_op.op.inputs[1].name
- self.assertIn('TestClass/test_fn/', first_result_input_name)
- self.assertNotIn('inner_fn', first_result_input_name)
- self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name)
+ ns = {'TestClass': TestClass}
+ node, ctx = self.prepare(TestClass, ns, owner_type=TestClass)
+ node = name_scopes.transform(node, ctx)
- def test_operator(self):
-
- class TestClass(object):
-
- def __call__(self, l):
-
- def inner_fn(i):
- return i ** 2
-
- l += 4
- return inner_fn(l)
-
- # Note that 'TestClass' was needed in the namespace here.
- node = self.parse_and_analyze(
- TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass)
- node = name_scopes.transform(node, self.ctx)
-
- with self.compiled(node, ops.name_scope) as result:
- result_op = result.__call__(TestClass(), constant_op.constant(1))
- first_result_input_name = result_op.op.inputs[0].name
- second_result_input_name = result_op.op.inputs[1].name
- self.assertIn('call__/', first_result_input_name)
- self.assertNotIn('inner_fn', first_result_input_name)
- self.assertIn('call__/inner_fn/', second_result_input_name)
+ with self.compiled(node, {}, ops.name_scope) as result:
+ first, second = result.TestClass().test_fn(constant_op.constant(1))
+ self.assertIn('TestClass/test_fn/', first.op.name)
+ self.assertNotIn('inner_fn', first.op.name)
+ self.assertIn('TestClass/test_fn/inner_fn/', second.op.name)
if __name__ == '__main__':