diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/pretty_printer.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/pretty_printer.py | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py new file mode 100644 index 0000000000..bacc1e4a77 --- /dev/null +++ b/tensorflow/python/autograph/pyct/pretty_printer.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Print an AST tree in a form more readable than ast.dump.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast +import termcolor + + +class PrettyPrinter(gast.NodeVisitor): + """Print AST nodes.""" + + def __init__(self, color): + self.indent_lvl = 0 + self.result = '' + self.color = color + + def _color(self, string, color, attrs=None): + if self.color: + return termcolor.colored(string, color, attrs=attrs) + return string + + def _type(self, node): + return self._color(node.__class__.__name__, None, ['bold']) + + def _field(self, name): + return self._color(name, 'blue') + + def _value(self, name): + return self._color(name, 'magenta') + + def _warning(self, name): + return self._color(name, 'red') + + def _indent(self): + return self._color('| ' * self.indent_lvl, None, ['dark']) + + def _print(self, s): + self.result += s + self.result += '\n' + + def generic_visit(self, node, name=None): + if node._fields: + cont = ':' + else: + cont = '()' + + if name: + self._print('%s%s=%s%s' % (self._indent(), self._field(name), + self._type(node), cont)) + else: + self._print('%s%s%s' % (self._indent(), self._type(node), cont)) + + self.indent_lvl += 1 + for f in node._fields: + if not hasattr(node, f): + self._print('%s%s' % (self._indent(), self._warning('%s=<unset>' % f))) + continue + v = getattr(node, f) + if isinstance(v, list): + if v: + self._print('%s%s=[' % (self._indent(), self._field(f))) + self.indent_lvl += 1 + for n in v: + self.generic_visit(n) + self.indent_lvl -= 1 + self._print('%s]' % (self._indent())) + else: + self._print('%s%s=[]' % (self._indent(), self._field(f))) + elif isinstance(v, tuple): + if v: + self._print('%s%s=(' % (self._indent(), self._field(f))) + self.indent_lvl += 1 + for n in v: + self.generic_visit(n) + self.indent_lvl -= 1 + self._print('%s)' % (self._indent())) + else: + self._print('%s%s=()' % (self._indent(), self._field(f))) + elif isinstance(v, gast.AST): + self.generic_visit(v, f) + elif isinstance(v, str): + self._print('%s%s=%s' % (self._indent(), self._field(f), + self._value('"%s"' % v))) + else: + self._print('%s%s=%s' % (self._indent(), self._field(f), + self._value(v))) + self.indent_lvl -= 1 + + +def fmt(node, color=True): + printer = PrettyPrinter(color) + if isinstance(node, (list, tuple)): + for n in node: + printer.visit(n) + else: + printer.visit(node) + return printer.result |