aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2017-06-12 21:38:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-12 21:41:36 -0700
commit4cc0caadf6b6a1b3b589595d6a752802f3d23356 (patch)
tree5240351f45ad7d5a2b8516b729e2d4c8c621ec4f /tensorflow/tools/compatibility
parentb044a5d5e8296a487b48e65304eb3c601500a487 (diff)
Pull generic logic out of tf_upgrade script.
PiperOrigin-RevId: 158806203
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r--tensorflow/tools/compatibility/BUILD5
-rw-r--r--tensorflow/tools/compatibility/ast_edits.py497
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py462
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_test.py5
4 files changed, 512 insertions, 457 deletions
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 2ca71e7ab8..fb40cf0833 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -10,7 +10,10 @@ load(
py_binary(
name = "tf_upgrade",
- srcs = ["tf_upgrade.py"],
+ srcs = [
+ "ast_edits.py",
+ "tf_upgrade.py",
+ ],
srcs_version = "PY2AND3",
)
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
new file mode 100644
index 0000000000..e7e4c91692
--- /dev/null
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -0,0 +1,497 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Upgrader for Python scripts according to an API change specification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import collections
+import os
+import shutil
+import sys
+import tempfile
+import traceback
+
+
+class APIChangeSpec(object):
+ """This class defines the transformations that need to happen.
+
+ This class must provide the following fields:
+
+ * `function_keyword_renames`: maps function names to a map of old -> new
+ argument names
+ * `function_renames`: maps function names to new function names
+ * `change_to_function`: a set of function names that have changed (for
+ notifications)
+ * `function_reorders`: maps functions whose argument order has changed to the
+ list of arguments in the new order
+ * `function_handle`: maps function names to custom handlers for the function
+
+ For an example, see `TFAPIChangeSpec`.
+ """
+
+
+class _FileEditTuple(collections.namedtuple(
+ "_FileEditTuple", ["comment", "line", "start", "old", "new"])):
+ """Each edit that is recorded by a _FileEditRecorder.
+
+ Fields:
+ comment: A description of the edit and why it was made.
+ line: The line number in the file where the edit occurs (1-indexed).
+ start: The line number in the file where the edit occurs (0-indexed).
+ old: text string to remove (this must match what was in file).
+ new: text string to add in place of `old`.
+ """
+
+ __slots__ = ()
+
+
+class _FileEditRecorder(object):
+ """Record changes that need to be done to the file."""
+
+ def __init__(self, filename):
+ # all edits are lists of chars
+ self._filename = filename
+
+ self._line_to_edit = collections.defaultdict(list)
+ self._errors = []
+
+ def process(self, text):
+ """Process a list of strings, each corresponding to the recorded changes.
+
+ Args:
+ text: A list of lines of text (assumed to contain newlines)
+ Returns:
+ A tuple of the modified text and a textual description of what is done.
+ Raises:
+ ValueError: if substitution source location does not have expected text.
+ """
+
+ change_report = ""
+
+ # Iterate of each line
+ for line, edits in self._line_to_edit.items():
+ offset = 0
+ # sort by column so that edits are processed in order in order to make
+ # indexing adjustments cumulative for changes that change the string
+ # length
+ edits.sort(key=lambda x: x.start)
+
+ # Extract each line to a list of characters, because mutable lists
+ # are editable, unlike immutable strings.
+ char_array = list(text[line - 1])
+
+ # Record a description of the change
+ change_report += "%r Line %d\n" % (self._filename, line)
+ change_report += "-" * 80 + "\n\n"
+ for e in edits:
+ change_report += "%s\n" % e.comment
+ change_report += "\n Old: %s" % (text[line - 1])
+
+ # Make underscore buffers for underlining where in the line the edit was
+ change_list = [" "] * len(text[line - 1])
+ change_list_new = [" "] * len(text[line - 1])
+
+ # Iterate for each edit
+ for e in edits:
+ # Create effective start, end by accounting for change in length due
+ # to previous edits
+ start_eff = e.start + offset
+ end_eff = start_eff + len(e.old)
+
+ # Make sure the edit is changing what it should be changing
+ old_actual = "".join(char_array[start_eff:end_eff])
+ if old_actual != e.old:
+ raise ValueError("Expected text %r but got %r" %
+ ("".join(e.old), "".join(old_actual)))
+ # Make the edit
+ char_array[start_eff:end_eff] = list(e.new)
+
+ # Create the underline highlighting of the before and after
+ change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
+ change_list_new[start_eff:end_eff] = "~" * len(e.new)
+
+ # Keep track of how to generate effective ranges
+ offset += len(e.new) - len(e.old)
+
+ # Finish the report comment
+ change_report += " %s\n" % "".join(change_list)
+ text[line - 1] = "".join(char_array)
+ change_report += " New: %s" % (text[line - 1])
+ change_report += " %s\n\n" % "".join(change_list_new)
+ return "".join(text), change_report, self._errors
+
+ def add(self, comment, line, start, old, new, error=None):
+ """Add a new change that is needed.
+
+ Args:
+ comment: A description of what was changed
+ line: Line number (1 indexed)
+ start: Column offset (0 indexed)
+ old: old text
+ new: new text
+ error: this "edit" is something that cannot be fixed automatically
+ Returns:
+ None
+ """
+
+ self._line_to_edit[line].append(
+ _FileEditTuple(comment, line, start, old, new))
+ if error:
+ self._errors.append("%s:%d: %s" % (self._filename, line, error))
+
+
+class _ASTCallVisitor(ast.NodeVisitor):
+ """AST Visitor that processes function calls.
+
+ Updates function calls from old API version to new API version using a given
+ change spec.
+ """
+
+ def __init__(self, filename, lines, api_change_spec):
+ self._filename = filename
+ self._file_edit = _FileEditRecorder(filename)
+ self._lines = lines
+ self._api_change_spec = api_change_spec
+
+ def process(self, lines):
+ return self._file_edit.process(lines)
+
+ def generic_visit(self, node):
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def _rename_functions(self, node, full_name):
+ function_renames = self._api_change_spec.function_renames
+ try:
+ new_name = function_renames[full_name]
+ self._file_edit.add("Renamed function %r to %r" % (full_name,
+ new_name),
+ node.lineno, node.col_offset, full_name, new_name)
+ except KeyError:
+ pass
+
+ def _get_attribute_full_path(self, node):
+ """Traverse an attribute to generate a full name e.g. tf.foo.bar.
+
+ Args:
+ node: A Node of type Attribute.
+
+ Returns:
+ a '.'-delimited full-name or None if the tree was not a simple form.
+ i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
+ """
+ curr = node
+ items = []
+ while not isinstance(curr, ast.Name):
+ if not isinstance(curr, ast.Attribute):
+ return None
+ items.append(curr.attr)
+ curr = curr.value
+ items.append(curr.id)
+ return ".".join(reversed(items))
+
+ def _find_true_position(self, node):
+ """Return correct line number and column offset for a given node.
+
+ This is necessary mainly because ListComp's location reporting reports
+ the next token after the list comprehension list opening.
+
+ Args:
+ node: Node for which we wish to know the lineno and col_offset
+ """
+ import re
+ find_open = re.compile("^\s*(\\[).*$")
+ find_string_chars = re.compile("['\"]")
+
+ if isinstance(node, ast.ListComp):
+ # Strangely, ast.ListComp returns the col_offset of the first token
+ # after the '[' token which appears to be a bug. Workaround by
+ # explicitly finding the real start of the list comprehension.
+ line = node.lineno
+ col = node.col_offset
+ # loop over lines
+ while 1:
+ # Reverse the text to and regular expression search for whitespace
+ text = self._lines[line-1]
+ reversed_preceding_text = text[:col][::-1]
+ # First find if a [ can be found with only whitespace between it and
+ # col.
+ m = find_open.match(reversed_preceding_text)
+ if m:
+ new_col_offset = col - m.start(1) - 1
+ return line, new_col_offset
+ else:
+ if (reversed_preceding_text=="" or
+ reversed_preceding_text.isspace()):
+ line = line - 1
+ prev_line = self._lines[line - 1]
+ # TODO(aselle):
+ # this is poor comment detection, but it is good enough for
+ # cases where the comment does not contain string literal starting/
+ # ending characters. If ast gave us start and end locations of the
+ # ast nodes rather than just start, we could use string literal
+ # node ranges to filter out spurious #'s that appear in string
+ # literals.
+ comment_start = prev_line.find("#")
+ if comment_start == -1:
+ col = len(prev_line) -1
+ elif find_string_chars.search(prev_line[comment_start:]) is None:
+ col = comment_start
+ else:
+ return None, None
+ else:
+ return None, None
+ # Most other nodes return proper locations (with notably does not), but
+ # it is not possible to use that in an argument.
+ return node.lineno, node.col_offset
+
+
+ def visit_Call(self, node): # pylint: disable=invalid-name
+ """Handle visiting a call node in the AST.
+
+ Args:
+ node: Current Node
+ """
+
+
+ # Find a simple attribute name path e.g. "tf.foo.bar"
+ full_name = self._get_attribute_full_path(node.func)
+
+ # Make sure the func is marked as being part of a call
+ node.func.is_function_for_call = True
+
+ if full_name:
+ # Call special handlers
+ function_handles = self._api_change_spec.function_handle
+ if full_name in function_handles:
+ function_handles[full_name](self._file_edit, node)
+
+ # Examine any non-keyword argument and make it into a keyword argument
+ # if reordering required.
+ function_reorders = self._api_change_spec.function_reorders
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
+
+ if full_name in function_reorders:
+ reordered = function_reorders[full_name]
+ for idx, arg in enumerate(node.args):
+ lineno, col_offset = self._find_true_position(arg)
+ if lineno is None or col_offset is None:
+ self._file_edit.add(
+ "Failed to add keyword %r to reordered function %r"
+ % (reordered[idx], full_name), arg.lineno, arg.col_offset,
+ "", "",
+ error="A necessary keyword argument failed to be inserted.")
+ else:
+ keyword_arg = reordered[idx]
+ if (full_name in function_keyword_renames and
+ keyword_arg in function_keyword_renames[full_name]):
+ keyword_arg = function_keyword_renames[full_name][keyword_arg]
+ self._file_edit.add("Added keyword %r to reordered function %r"
+ % (reordered[idx], full_name), lineno,
+ col_offset, "", keyword_arg + "=")
+
+ # Examine each keyword argument and convert it to the final renamed form
+ renamed_keywords = ({} if full_name not in function_keyword_renames else
+ function_keyword_renames[full_name])
+ for keyword in node.keywords:
+ argkey = keyword.arg
+ argval = keyword.value
+
+ if argkey in renamed_keywords:
+ argval_lineno, argval_col_offset = self._find_true_position(argval)
+ if argval_lineno is not None and argval_col_offset is not None:
+ # TODO(aselle): We should scan backward to find the start of the
+ # keyword key. Unfortunately ast does not give you the location of
+ # keyword keys, so we are forced to infer it from the keyword arg
+ # value.
+ key_start = argval_col_offset - len(argkey) - 1
+ key_end = key_start + len(argkey) + 1
+ if (self._lines[argval_lineno - 1][key_start:key_end] ==
+ argkey + "="):
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval_lineno,
+ argval_col_offset - len(argkey) - 1,
+ argkey + "=", renamed_keywords[argkey] + "=")
+ continue
+ self._file_edit.add(
+ "Failed to rename keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ "", "",
+ error="Failed to find keyword lexographically. Fix manually.")
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def visit_Attribute(self, node): # pylint: disable=invalid-name
+ """Handle bare Attributes i.e. [tf.foo, tf.bar].
+
+ Args:
+ node: Node that is of type ast.Attribute
+ """
+ full_name = self._get_attribute_full_path(node)
+ if full_name:
+ self._rename_functions(node, full_name)
+ if full_name in self._api_change_spec.change_to_function:
+ if not hasattr(node, "is_function_for_call"):
+ new_text = full_name + "()"
+ self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+ node.lineno, node.col_offset, full_name, new_text)
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+
+class ASTCodeUpgrader(object):
+ """Handles upgrading a set of Python files using a given API change spec."""
+
+ def __init__(self, api_change_spec):
+ if not isinstance(api_change_spec, APIChangeSpec):
+ raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
+ type(api_change_spec))
+ self._api_change_spec = api_change_spec
+
+ def process_file(self, in_filename, out_filename):
+ """Process the given python file for incompatible changes.
+
+ Args:
+ in_filename: filename to parse
+ out_filename: output file to write to
+ Returns:
+ A tuple representing number of files processed, log of actions, errors
+ """
+
+ # Write to a temporary file, just in case we are doing an implace modify.
+ with open(in_filename, "r") as in_file, \
+ tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
+ ret = self.process_opened_file(
+ in_filename, in_file, out_filename, temp_file)
+
+ shutil.move(temp_file.name, out_filename)
+ return ret
+
+ # Broad exceptions are required here because ast throws whatever it wants.
+ # pylint: disable=broad-except
+ def process_opened_file(self, in_filename, in_file, out_filename, out_file):
+ """Process the given python file for incompatible changes.
+
+ This function is split out to facilitate StringIO testing from
+ tf_upgrade_test.py.
+
+ Args:
+ in_filename: filename to parse
+ in_file: opened file (or StringIO)
+ out_filename: output file to write to
+ out_file: opened file (or StringIO)
+ Returns:
+ A tuple representing number of files processed, log of actions, errors
+ """
+ process_errors = []
+ text = "-" * 80 + "\n"
+ text += "Processing file %r\n outputting to %r\n" % (in_filename,
+ out_filename)
+ text += "-" * 80 + "\n\n"
+
+ parsed_ast = None
+ lines = in_file.readlines()
+ try:
+ parsed_ast = ast.parse("".join(lines))
+ except Exception:
+ text += "Failed to parse %r\n\n" % in_filename
+ text += traceback.format_exc()
+ if parsed_ast:
+ visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
+ visitor.visit(parsed_ast)
+ out_text, new_text, process_errors = visitor.process(lines)
+ text += new_text
+ if out_file:
+ out_file.write(out_text)
+ text += "\n"
+ return 1, text, process_errors
+ # pylint: enable=broad-except
+
+ def process_tree(self, root_directory, output_root_directory,
+ copy_other_files):
+ """Processes upgrades on an entire tree of python files in place.
+
+ Note that only Python files. If you have custom code in other languages,
+ you will need to manually upgrade those.
+
+ Args:
+ root_directory: Directory to walk and process.
+ output_root_directory: Directory to use as base.
+ copy_other_files: Copy files that are not touched by this converter.
+
+ Returns:
+ A tuple of files processed, the report string ofr all files, and errors
+ """
+
+ # make sure output directory doesn't exist
+ if output_root_directory and os.path.exists(output_root_directory):
+ print("Output directory %r must not already exist." % (
+ output_root_directory))
+ sys.exit(1)
+
+ # make sure output directory does not overlap with root_directory
+ norm_root = os.path.split(os.path.normpath(root_directory))
+ norm_output = os.path.split(os.path.normpath(output_root_directory))
+ if norm_root == norm_output:
+ print("Output directory %r same as input directory %r" % (
+ root_directory, output_root_directory))
+ sys.exit(1)
+
+ # Collect list of files to process (we do this to correctly handle if the
+ # user puts the output directory in some sub directory of the input dir)
+ files_to_process = []
+ files_to_copy = []
+ for dir_name, _, file_list in os.walk(root_directory):
+ py_files = [f for f in file_list if f.endswith(".py")]
+ copy_files = [f for f in file_list if not f.endswith(".py")]
+ for filename in py_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(
+ output_root_directory, os.path.relpath(fullpath, root_directory))
+ files_to_process.append((fullpath, fullpath_output))
+ if copy_other_files:
+ for filename in copy_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(
+ output_root_directory, os.path.relpath(fullpath, root_directory))
+ files_to_copy.append((fullpath, fullpath_output))
+
+ file_count = 0
+ tree_errors = []
+ report = ""
+ report += ("=" * 80) + "\n"
+ report += "Input tree: %r\n" % root_directory
+ report += ("=" * 80) + "\n"
+
+ for input_path, output_path in files_to_process:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ file_count += 1
+ _, l_report, l_errors = self.process_file(input_path, output_path)
+ tree_errors += l_errors
+ report += l_report
+ for input_path, output_path in files_to_copy:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ shutil.copy(input_path, output_path)
+ return file_count, report, tree_errors
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index 9a4a8ff71d..72fe4a48cd 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -17,17 +17,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import argparse
-import ast
-import collections
-import os
-import shutil
-import sys
-import tempfile
-import traceback
+
+from tensorflow.tools.compatibility import ast_edits
-class APIChangeSpec(object):
+class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"""List of maps that describe what changed in the API."""
def __init__(self):
@@ -179,7 +175,9 @@ class APIChangeSpec(object):
}
# Specially handled functions.
- self.function_handle = {"tf.reverse": self._reverse_handler}
+ self.function_handle = {
+ "tf.reverse": self._reverse_handler
+ }
@staticmethod
def _reverse_handler(file_edit_recorder, node):
@@ -196,450 +194,6 @@ class APIChangeSpec(object):
error="tf.reverse requires manual check.")
-class FileEditTuple(collections.namedtuple(
- "FileEditTuple", ["comment", "line", "start", "old", "new"])):
- """Each edit that is recorded by a FileEditRecorder.
-
- Fields:
- comment: A description of the edit and why it was made.
- line: The line number in the file where the edit occurs (1-indexed).
- start: The line number in the file where the edit occurs (0-indexed).
- old: text string to remove (this must match what was in file).
- new: text string to add in place of `old`.
- """
-
- __slots__ = ()
-
-
-class FileEditRecorder(object):
- """Record changes that need to be done to the file."""
-
- def __init__(self, filename):
- # all edits are lists of chars
- self._filename = filename
-
- self._line_to_edit = collections.defaultdict(list)
- self._errors = []
-
- def process(self, text):
- """Process a list of strings, each corresponding to the recorded changes.
-
- Args:
- text: A list of lines of text (assumed to contain newlines)
- Returns:
- A tuple of the modified text and a textual description of what is done.
- Raises:
- ValueError: if substitution source location does not have expected text.
- """
-
- change_report = ""
-
- # Iterate of each line
- for line, edits in self._line_to_edit.items():
- offset = 0
- # sort by column so that edits are processed in order in order to make
- # indexing adjustments cumulative for changes that change the string
- # length
- edits.sort(key=lambda x: x.start)
-
- # Extract each line to a list of characters, because mutable lists
- # are editable, unlike immutable strings.
- char_array = list(text[line - 1])
-
- # Record a description of the change
- change_report += "%r Line %d\n" % (self._filename, line)
- change_report += "-" * 80 + "\n\n"
- for e in edits:
- change_report += "%s\n" % e.comment
- change_report += "\n Old: %s" % (text[line - 1])
-
- # Make underscore buffers for underlining where in the line the edit was
- change_list = [" "] * len(text[line - 1])
- change_list_new = [" "] * len(text[line - 1])
-
- # Iterate for each edit
- for e in edits:
- # Create effective start, end by accounting for change in length due
- # to previous edits
- start_eff = e.start + offset
- end_eff = start_eff + len(e.old)
-
- # Make sure the edit is changing what it should be changing
- old_actual = "".join(char_array[start_eff:end_eff])
- if old_actual != e.old:
- raise ValueError("Expected text %r but got %r" %
- ("".join(e.old), "".join(old_actual)))
- # Make the edit
- char_array[start_eff:end_eff] = list(e.new)
-
- # Create the underline highlighting of the before and after
- change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
- change_list_new[start_eff:end_eff] = "~" * len(e.new)
-
- # Keep track of how to generate effective ranges
- offset += len(e.new) - len(e.old)
-
- # Finish the report comment
- change_report += " %s\n" % "".join(change_list)
- text[line - 1] = "".join(char_array)
- change_report += " New: %s" % (text[line - 1])
- change_report += " %s\n\n" % "".join(change_list_new)
- return "".join(text), change_report, self._errors
-
- def add(self, comment, line, start, old, new, error=None):
- """Add a new change that is needed.
-
- Args:
- comment: A description of what was changed
- line: Line number (1 indexed)
- start: Column offset (0 indexed)
- old: old text
- new: new text
- error: this "edit" is something that cannot be fixed automatically
- Returns:
- None
- """
-
- self._line_to_edit[line].append(
- FileEditTuple(comment, line, start, old, new))
- if error:
- self._errors.append("%s:%d: %s" % (self._filename, line, error))
-
-
-class TensorFlowCallVisitor(ast.NodeVisitor):
- """AST Visitor that finds TensorFlow Function calls.
-
- Updates function calls from old API version to new API version.
- """
-
- def __init__(self, filename, lines):
- self._filename = filename
- self._file_edit = FileEditRecorder(filename)
- self._lines = lines
- self._api_change_spec = APIChangeSpec()
-
- def process(self, lines):
- return self._file_edit.process(lines)
-
- def generic_visit(self, node):
- ast.NodeVisitor.generic_visit(self, node)
-
- def _rename_functions(self, node, full_name):
- function_renames = self._api_change_spec.function_renames
- try:
- new_name = function_renames[full_name]
- self._file_edit.add("Renamed function %r to %r" % (full_name,
- new_name),
- node.lineno, node.col_offset, full_name, new_name)
- except KeyError:
- pass
-
- def _get_attribute_full_path(self, node):
- """Traverse an attribute to generate a full name e.g. tf.foo.bar.
-
- Args:
- node: A Node of type Attribute.
-
- Returns:
- a '.'-delimited full-name or None if the tree was not a simple form.
- i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
- """
- curr = node
- items = []
- while not isinstance(curr, ast.Name):
- if not isinstance(curr, ast.Attribute):
- return None
- items.append(curr.attr)
- curr = curr.value
- items.append(curr.id)
- return ".".join(reversed(items))
-
- def _find_true_position(self, node):
- """Return correct line number and column offset for a given node.
-
- This is necessary mainly because ListComp's location reporting reports
- the next token after the list comprehension list opening.
-
- Args:
- node: Node for which we wish to know the lineno and col_offset
- """
- import re
- find_open = re.compile("^\s*(\\[).*$")
- find_string_chars = re.compile("['\"]")
-
- if isinstance(node, ast.ListComp):
- # Strangely, ast.ListComp returns the col_offset of the first token
- # after the '[' token which appears to be a bug. Workaround by
- # explicitly finding the real start of the list comprehension.
- line = node.lineno
- col = node.col_offset
- # loop over lines
- while 1:
- # Reverse the text to and regular expression search for whitespace
- text = self._lines[line-1]
- reversed_preceding_text = text[:col][::-1]
- # First find if a [ can be found with only whitespace between it and
- # col.
- m = find_open.match(reversed_preceding_text)
- if m:
- new_col_offset = col - m.start(1) - 1
- return line, new_col_offset
- else:
- if (reversed_preceding_text=="" or
- reversed_preceding_text.isspace()):
- line = line - 1
- prev_line = self._lines[line - 1]
- # TODO(aselle):
- # this is poor comment detection, but it is good enough for
- # cases where the comment does not contain string literal starting/
- # ending characters. If ast gave us start and end locations of the
- # ast nodes rather than just start, we could use string literal
- # node ranges to filter out spurious #'s that appear in string
- # literals.
- comment_start = prev_line.find("#")
- if comment_start == -1:
- col = len(prev_line) -1
- elif find_string_chars.search(prev_line[comment_start:]) is None:
- col = comment_start
- else:
- return None, None
- else:
- return None, None
- # Most other nodes return proper locations (with notably does not), but
- # it is not possible to use that in an argument.
- return node.lineno, node.col_offset
-
-
- def visit_Call(self, node): # pylint: disable=invalid-name
- """Handle visiting a call node in the AST.
-
- Args:
- node: Current Node
- """
-
-
- # Find a simple attribute name path e.g. "tf.foo.bar"
- full_name = self._get_attribute_full_path(node.func)
-
- # Make sure the func is marked as being part of a call
- node.func.is_function_for_call = True
-
- if full_name and full_name.startswith("tf."):
- # Call special handlers
- function_handles = self._api_change_spec.function_handle
- if full_name in function_handles:
- function_handles[full_name](self._file_edit, node)
-
- # Examine any non-keyword argument and make it into a keyword argument
- # if reordering required.
- function_reorders = self._api_change_spec.function_reorders
- function_keyword_renames = (
- self._api_change_spec.function_keyword_renames)
-
- if full_name in function_reorders:
- reordered = function_reorders[full_name]
- for idx, arg in enumerate(node.args):
- lineno, col_offset = self._find_true_position(arg)
- if lineno is None or col_offset is None:
- self._file_edit.add(
- "Failed to add keyword %r to reordered function %r"
- % (reordered[idx], full_name), arg.lineno, arg.col_offset,
- "", "",
- error="A necessary keyword argument failed to be inserted.")
- else:
- keyword_arg = reordered[idx]
- if (full_name in function_keyword_renames and
- keyword_arg in function_keyword_renames[full_name]):
- keyword_arg = function_keyword_renames[full_name][keyword_arg]
- self._file_edit.add("Added keyword %r to reordered function %r"
- % (reordered[idx], full_name), lineno,
- col_offset, "", keyword_arg + "=")
-
- # Examine each keyword argument and convert it to the final renamed form
- renamed_keywords = ({} if full_name not in function_keyword_renames else
- function_keyword_renames[full_name])
- for keyword in node.keywords:
- argkey = keyword.arg
- argval = keyword.value
-
- if argkey in renamed_keywords:
- argval_lineno, argval_col_offset = self._find_true_position(argval)
- if (argval_lineno is not None and argval_col_offset is not None):
- # TODO(aselle): We should scan backward to find the start of the
- # keyword key. Unfortunately ast does not give you the location of
- # keyword keys, so we are forced to infer it from the keyword arg
- # value.
- key_start = argval_col_offset - len(argkey) - 1
- key_end = key_start + len(argkey) + 1
- if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
- self._file_edit.add("Renamed keyword argument from %r to %r" %
- (argkey, renamed_keywords[argkey]),
- argval_lineno,
- argval_col_offset - len(argkey) - 1,
- argkey + "=", renamed_keywords[argkey] + "=")
- continue
- self._file_edit.add(
- "Failed to rename keyword argument from %r to %r" %
- (argkey, renamed_keywords[argkey]),
- argval.lineno,
- argval.col_offset - len(argkey) - 1,
- "", "",
- error="Failed to find keyword lexographically. Fix manually.")
-
- ast.NodeVisitor.generic_visit(self, node)
-
- def visit_Attribute(self, node): # pylint: disable=invalid-name
- """Handle bare Attributes i.e. [tf.foo, tf.bar].
-
- Args:
- node: Node that is of type ast.Attribute
- """
- full_name = self._get_attribute_full_path(node)
- if full_name and full_name.startswith("tf."):
- self._rename_functions(node, full_name)
- if full_name in self._api_change_spec.change_to_function:
- if not hasattr(node, "is_function_for_call"):
- new_text = full_name + "()"
- self._file_edit.add("Changed %r to %r"%(full_name, new_text),
- node.lineno, node.col_offset, full_name, new_text)
-
- ast.NodeVisitor.generic_visit(self, node)
-
-
-class TensorFlowCodeUpgrader(object):
- """Class that handles upgrading a set of Python files to TensorFlow 1.0."""
-
- def __init__(self):
- pass
-
- def process_file(self, in_filename, out_filename):
- """Process the given python file for incompatible changes.
-
- Args:
- in_filename: filename to parse
- out_filename: output file to write to
- Returns:
- A tuple representing number of files processed, log of actions, errors
- """
-
- # Write to a temporary file, just in case we are doing an implace modify.
- with open(in_filename, "r") as in_file, \
- tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
- ret = self.process_opened_file(
- in_filename, in_file, out_filename, temp_file)
-
- shutil.move(temp_file.name, out_filename)
- return ret
-
- # Broad exceptions are required here because ast throws whatever it wants.
- # pylint: disable=broad-except
- def process_opened_file(self, in_filename, in_file, out_filename, out_file):
- """Process the given python file for incompatible changes.
-
- This function is split out to facilitate StringIO testing from
- tf_upgrade_test.py.
-
- Args:
- in_filename: filename to parse
- in_file: opened file (or StringIO)
- out_filename: output file to write to
- out_file: opened file (or StringIO)
- Returns:
- A tuple representing number of files processed, log of actions, errors
- """
- process_errors = []
- text = "-" * 80 + "\n"
- text += "Processing file %r\n outputting to %r\n" % (in_filename,
- out_filename)
- text += "-" * 80 + "\n\n"
-
- parsed_ast = None
- lines = in_file.readlines()
- try:
- parsed_ast = ast.parse("".join(lines))
- except Exception:
- text += "Failed to parse %r\n\n" % in_filename
- text += traceback.format_exc()
- if parsed_ast:
- visitor = TensorFlowCallVisitor(in_filename, lines)
- visitor.visit(parsed_ast)
- out_text, new_text, process_errors = visitor.process(lines)
- text += new_text
- if out_file:
- out_file.write(out_text)
- text += "\n"
- return 1, text, process_errors
- # pylint: enable=broad-except
-
- def process_tree(self, root_directory, output_root_directory, copy_other_files):
- """Processes upgrades on an entire tree of python files in place.
-
- Note that only Python files. If you have custom code in other languages,
- you will need to manually upgrade those.
-
- Args:
- root_directory: Directory to walk and process.
- output_root_directory: Directory to use as base
- Returns:
- A tuple of files processed, the report string ofr all files, and errors
- """
-
- # make sure output directory doesn't exist
- if output_root_directory and os.path.exists(output_root_directory):
- print("Output directory %r must not already exist." % (
- output_root_directory))
- sys.exit(1)
-
- # make sure output directory does not overlap with root_directory
- norm_root = os.path.split(os.path.normpath(root_directory))
- norm_output = os.path.split(os.path.normpath(output_root_directory))
- if norm_root == norm_output:
- print("Output directory %r same as input directory %r" % (
- root_directory, output_root_directory))
- sys.exit(1)
-
- # Collect list of files to process (we do this to correctly handle if the
- # user puts the output directory in some sub directory of the input dir)
- files_to_process = []
- files_to_copy = []
- for dir_name, _, file_list in os.walk(root_directory):
- py_files = [f for f in file_list if f.endswith(".py")]
- copy_files = [f for f in file_list if not f.endswith(".py")]
- for filename in py_files:
- fullpath = os.path.join(dir_name, filename)
- fullpath_output = os.path.join(
- output_root_directory, os.path.relpath(fullpath, root_directory))
- files_to_process.append((fullpath, fullpath_output))
- if copy_other_files:
- for filename in copy_files:
- fullpath = os.path.join(dir_name, filename)
- fullpath_output = os.path.join(
- output_root_directory, os.path.relpath(fullpath, root_directory))
- files_to_copy.append((fullpath, fullpath_output))
-
- file_count = 0
- tree_errors = []
- report = ""
- report += ("=" * 80) + "\n"
- report += "Input tree: %r\n" % root_directory
- report += ("=" * 80) + "\n"
-
- for input_path, output_path in files_to_process:
- output_directory = os.path.dirname(output_path)
- if not os.path.isdir(output_directory):
- os.makedirs(output_directory)
- file_count += 1
- _, l_report, l_errors = self.process_file(input_path, output_path)
- tree_errors += l_errors
- report += l_report
- for input_path, output_path in files_to_copy:
- output_directory = os.path.dirname(output_path)
- if not os.path.isdir(output_directory):
- os.makedirs(output_directory)
- shutil.copy(input_path, output_path)
- return file_count, report, tree_errors
-
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -684,7 +238,7 @@ Simple usage:
default="report.txt")
args = parser.parse_args()
- upgrade = TensorFlowCodeUpgrader()
+ upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
report_text = None
report_filename = args.report_filename
files_processed = 0
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index de4e3de73c..ac838a2791 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -22,6 +22,7 @@ import tempfile
import six
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
+from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade
@@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def _upgrade(self, old_file_text):
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
- upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+ upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
@@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase):
upgraded = "tf.multiply(a, b)\n"
temp_file.write(original)
temp_file.close()
- upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+ upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
upgrader.process_file(temp_file.name, temp_file.name)
self.assertAllEqual(open(temp_file.name).read(), upgraded)
os.unlink(temp_file.name)