1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import textwrap
import gast
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
class CompilerTest(test.TestCase):
def test_parser_compile_idempotent(self):
def test_fn(x):
a = True
b = ''
if a:
b = x + 1
return b
self.assertEqual(
textwrap.dedent(tf_inspect.getsource(test_fn)),
tf_inspect.getsource(
compiler.ast_to_object(
parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
def test_ast_to_source(self):
node = gast.If(
test=gast.Num(1),
body=[
gast.Assign(
targets=[gast.Name('a', gast.Store(), None)],
value=gast.Name('b', gast.Load(), None))
],
orelse=[
gast.Assign(
targets=[gast.Name('a', gast.Store(), None)],
value=gast.Str('c'))
])
source = compiler.ast_to_source(node, indentation=' ')
self.assertEqual(
textwrap.dedent("""
if 1:
a = b
else:
a = 'c'
""").strip(), source.strip())
def test_ast_to_object(self):
node = gast.FunctionDef(
name='f',
args=gast.arguments(
args=[gast.Name('a', gast.Param(), None)],
vararg=None,
kwonlyargs=[],
kwarg=None,
defaults=[],
kw_defaults=[]),
body=[
gast.Return(
gast.BinOp(
op=gast.Add(),
left=gast.Name('a', gast.Load(), None),
right=gast.Num(1)))
],
decorator_list=[],
returns=None)
module, source = compiler.ast_to_object(node)
expected_source = """
def f(a):
return a + 1
"""
self.assertEqual(
textwrap.dedent(expected_source).strip(),
source.strip())
self.assertEqual(2, module.f(1))
with open(module.__file__, 'r') as temp_output:
self.assertEqual(
textwrap.dedent(expected_source).strip(),
temp_output.read().strip())
if __name__ == '__main__':
test.main()
|