aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility/tf_upgrade.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/compatibility/tf_upgrade.py')
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py128
1 files changed, 118 insertions, 10 deletions
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index 374be0475a..bcff10f21d 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -95,11 +95,15 @@ class APIChangeSpec(object):
"tf.split": {
"split_dim": "axis",
"num_split": "num_or_size_splits"
- }
+ },
+ "tf.concat": {
+ "concat_dim": "axis"
+ },
}
# Mapping from function to the new name of the function
self.function_renames = {
+ "tf.inv": "tf.reciprocal",
"tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
"tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
"tf.listdiff": "tf.setdiff1d",
@@ -142,6 +146,13 @@ class APIChangeSpec(object):
"tf.select": "tf.where",
"tf.complex_abs": "tf.abs",
"tf.batch_matmul": "tf.matmul",
+ "tf.pack": "tf.stack",
+ "tf.unpack": "tf.unstack",
+ }
+
+ self.change_to_function = {
+ "tf.ones_initializer",
+ "tf.zeros_initializer",
}
# Functions that were reordered should be changed to the new keyword args
@@ -149,6 +160,7 @@ class APIChangeSpec(object):
# positional arguments yourself, this could do the wrong thing.
self.function_reorders = {
"tf.split": ["axis", "num_or_size_splits", "value", "name"],
+ "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
"tf.concat": ["concat_dim", "values", "name"],
"tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
"tf.nn.softmax_cross_entropy_with_logits": [
@@ -335,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
items.append(curr.id)
return ".".join(reversed(items))
+ def _find_true_position(self, node):
+ """Return correct line number and column offset for a given node.
+
+ This is necessary mainly because ListComp's location reporting reports
+ the next token after the list comprehension list opening.
+
+ Args:
+ node: Node for which we wish to know the lineno and col_offset
+ """
+ import re
+ find_open = re.compile("^\s*(\\[).*$")
+ find_string_chars = re.compile("['\"]")
+
+ if isinstance(node, ast.ListComp):
+ # Strangely, ast.ListComp returns the col_offset of the first token
+ # after the '[' token which appears to be a bug. Workaround by
+ # explicitly finding the real start of the list comprehension.
+ line = node.lineno
+ col = node.col_offset
+ # loop over lines
+ while 1:
+ # Reverse the text to and regular expression search for whitespace
+ text = self._lines[line-1]
+ reversed_preceding_text = text[:col][::-1]
+ # First find if a [ can be found with only whitespace between it and
+ # col.
+ m = find_open.match(reversed_preceding_text)
+ if m:
+ new_col_offset = col - m.start(1) - 1
+ return line, new_col_offset
+ else:
+ if (reversed_preceding_text=="" or
+ reversed_preceding_text.isspace()):
+ line = line - 1
+ prev_line = self._lines[line - 1]
+ # TODO(aselle):
+ # this is poor comment detection, but it is good enough for
+ # cases where the comment does not contain string literal starting/
+ # ending characters. If ast gave us start and end locations of the
+ # ast nodes rather than just start, we could use string literal
+ # node ranges to filter out spurious #'s that appear in string
+ # literals.
+ comment_start = prev_line.find("#")
+ if comment_start == -1:
+ col = len(prev_line) -1
+ elif find_string_chars.search(prev_line[comment_start:]) is None:
+ col = comment_start
+ else:
+ return None, None
+ else:
+ return None, None
+ # Most other nodes return proper locations (with notably does not), but
+ # it is not possible to use that in an argument.
+ return node.lineno, node.col_offset
+
+
def visit_Call(self, node): # pylint: disable=invalid-name
"""Handle visiting a call node in the AST.
@@ -342,11 +410,13 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
node: Current Node
"""
- ast.NodeVisitor.generic_visit(self, node)
# Find a simple attribute name path e.g. "tf.foo.bar"
full_name = self._get_attribute_full_path(node.func)
+ # Make sure the func is marked as being part of a call
+ node.func.is_function_for_call = True
+
if full_name and full_name.startswith("tf."):
# Call special handlers
function_handles = self._api_change_spec.function_handle
@@ -356,27 +426,60 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
# Examine any non-keyword argument and make it into a keyword argument
# if reordering required.
function_reorders = self._api_change_spec.function_reorders
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
+
if full_name in function_reorders:
reordered = function_reorders[full_name]
for idx, arg in enumerate(node.args):
- self._file_edit.add("Added keyword %r to reordered function %r"
- % (reordered[idx], full_name), arg.lineno,
- arg.col_offset, "", reordered[idx] + "=")
+ lineno, col_offset = self._find_true_position(arg)
+ if lineno is None or col_offset is None:
+ self._file_edit.add(
+ "Failed to add keyword %r to reordered function %r"
+ % (reordered[idx], full_name), arg.lineno, arg.col_offset,
+ "", "",
+ error="A necessary keyword argument failed to be inserted.")
+ else:
+ keyword_arg = reordered[idx]
+ if (full_name in function_keyword_renames and
+ keyword_arg in function_keyword_renames[full_name]):
+ keyword_arg = function_keyword_renames[full_name][keyword_arg]
+ self._file_edit.add("Added keyword %r to reordered function %r"
+ % (reordered[idx], full_name), lineno,
+ col_offset, "", keyword_arg + "=")
# Examine each keyword argument and convert it to the final renamed form
- function_keyword_renames = (
- self._api_change_spec.function_keyword_renames)
renamed_keywords = ({} if full_name not in function_keyword_renames else
function_keyword_renames[full_name])
for keyword in node.keywords:
argkey = keyword.arg
argval = keyword.value
+
if argkey in renamed_keywords:
- self._file_edit.add("Renamed keyword argument from %r to %r" %
+ argval_lineno, argval_col_offset = self._find_true_position(argval)
+ if (argval_lineno is not None and argval_col_offset is not None):
+ # TODO(aselle): We should scan backward to find the start of the
+ # keyword key. Unfortunately ast does not give you the location of
+ # keyword keys, so we are forced to infer it from the keyword arg
+ # value.
+ key_start = argval_col_offset - len(argkey) - 1
+ key_end = key_start + len(argkey) + 1
+ if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
(argkey, renamed_keywords[argkey]),
- argval.lineno,
- argval.col_offset - len(argkey) - 1,
+ argval_lineno,
+ argval_col_offset - len(argkey) - 1,
argkey + "=", renamed_keywords[argkey] + "=")
+ continue
+ self._file_edit.add(
+ "Failed to rename keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ "", "",
+ error="Failed to find keyword lexographically. Fix manually.")
+
+ ast.NodeVisitor.generic_visit(self, node)
def visit_Attribute(self, node): # pylint: disable=invalid-name
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
@@ -387,6 +490,11 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
full_name = self._get_attribute_full_path(node)
if full_name and full_name.startswith("tf."):
self._rename_functions(node, full_name)
+ if full_name in self._api_change_spec.change_to_function:
+ if not hasattr(node, "is_function_for_call"):
+ new_text = full_name + "()"
+ self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+ node.lineno, node.col_offset, full_name, new_text)
ast.NodeVisitor.generic_visit(self, node)