aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-02-08 09:25:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-08 09:50:05 -0800
commit639b4e71f532761a4840b1cdbaea55ad0917c75b (patch)
tree5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/tools/compatibility
parent15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff)
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r--tensorflow/tools/compatibility/README.md15
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py128
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_test.py48
3 files changed, 177 insertions, 14 deletions
diff --git a/tensorflow/tools/compatibility/README.md b/tensorflow/tools/compatibility/README.md
index 3b66e73f9a..77e27531a9 100644
--- a/tensorflow/tools/compatibility/README.md
+++ b/tensorflow/tools/compatibility/README.md
@@ -36,6 +36,9 @@ particular, functions that have had reordered arguments like `tf.concat`,
`tf.split` will cause the script to incorrectly add keyword arguments that
mismap arguments.
+- This script wouldn't actually reorder arguments. Instead, the script will add
+keyword arguments to functions that had their arguments reordered.
+
- This script is not able to upgrade all functions. One notable example is
`tf.reverse()` which has been changed to take a list of indices rather than
a tensor of bools. If the script detects this, it will report this to stdout
@@ -43,6 +46,12 @@ a tensor of bools. If the script detects this, it will report this to stdout
`tf.reverse(a, [False, True, True])` you will need to manually change it to
`tf.reverse(a, [1, 2])`.
-
-
-
+- There are some syntaxes that are not handleable with this script as this
+script was designed to use only standard python packages. If the script fails
+with "A necessary keyword argument failed to be inserted." or
+"Failed to find keyword lexicographically. Fix manually.", you can try
+[@machrisaa's fork of this script](https://github.com/machrisaa/tf0to1).
+[@machrisaa](https://github.com/machrisaa) has used the
+[RedBaron Python refactoring engine](https://redbaron.readthedocs.io/en/latest/)
+which is able to localize syntactic elements more reliably than the built-in
+`ast` module this script is based upon.
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index 374be0475a..bcff10f21d 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -95,11 +95,15 @@ class APIChangeSpec(object):
"tf.split": {
"split_dim": "axis",
"num_split": "num_or_size_splits"
- }
+ },
+ "tf.concat": {
+ "concat_dim": "axis"
+ },
}
# Mapping from function to the new name of the function
self.function_renames = {
+ "tf.inv": "tf.reciprocal",
"tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
"tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
"tf.listdiff": "tf.setdiff1d",
@@ -142,6 +146,13 @@ class APIChangeSpec(object):
"tf.select": "tf.where",
"tf.complex_abs": "tf.abs",
"tf.batch_matmul": "tf.matmul",
+ "tf.pack": "tf.stack",
+ "tf.unpack": "tf.unstack",
+ }
+
+ self.change_to_function = {
+ "tf.ones_initializer",
+ "tf.zeros_initializer",
}
# Functions that were reordered should be changed to the new keyword args
@@ -149,6 +160,7 @@ class APIChangeSpec(object):
# positional arguments yourself, this could do the wrong thing.
self.function_reorders = {
"tf.split": ["axis", "num_or_size_splits", "value", "name"],
+ "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
"tf.concat": ["concat_dim", "values", "name"],
"tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
"tf.nn.softmax_cross_entropy_with_logits": [
@@ -335,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
items.append(curr.id)
return ".".join(reversed(items))
+ def _find_true_position(self, node):
+ """Return correct line number and column offset for a given node.
+
+ This is necessary mainly because ListComp's location reporting reports
+ the next token after the list comprehension list opening.
+
+ Args:
+ node: Node for which we wish to know the lineno and col_offset
+ """
+ import re
+ find_open = re.compile("^\s*(\\[).*$")
+ find_string_chars = re.compile("['\"]")
+
+ if isinstance(node, ast.ListComp):
+ # Strangely, ast.ListComp returns the col_offset of the first token
+ # after the '[' token which appears to be a bug. Workaround by
+ # explicitly finding the real start of the list comprehension.
+ line = node.lineno
+ col = node.col_offset
+ # loop over lines
+ while 1:
+ # Reverse the text to and regular expression search for whitespace
+ text = self._lines[line-1]
+ reversed_preceding_text = text[:col][::-1]
+ # First find if a [ can be found with only whitespace between it and
+ # col.
+ m = find_open.match(reversed_preceding_text)
+ if m:
+ new_col_offset = col - m.start(1) - 1
+ return line, new_col_offset
+ else:
+ if (reversed_preceding_text=="" or
+ reversed_preceding_text.isspace()):
+ line = line - 1
+ prev_line = self._lines[line - 1]
+ # TODO(aselle):
+ # this is poor comment detection, but it is good enough for
+ # cases where the comment does not contain string literal starting/
+ # ending characters. If ast gave us start and end locations of the
+ # ast nodes rather than just start, we could use string literal
+ # node ranges to filter out spurious #'s that appear in string
+ # literals.
+ comment_start = prev_line.find("#")
+ if comment_start == -1:
+ col = len(prev_line) -1
+ elif find_string_chars.search(prev_line[comment_start:]) is None:
+ col = comment_start
+ else:
+ return None, None
+ else:
+ return None, None
+ # Most other nodes return proper locations (with notably does not), but
+ # it is not possible to use that in an argument.
+ return node.lineno, node.col_offset
+
+
def visit_Call(self, node): # pylint: disable=invalid-name
"""Handle visiting a call node in the AST.
@@ -342,11 +410,13 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
node: Current Node
"""
- ast.NodeVisitor.generic_visit(self, node)
# Find a simple attribute name path e.g. "tf.foo.bar"
full_name = self._get_attribute_full_path(node.func)
+ # Make sure the func is marked as being part of a call
+ node.func.is_function_for_call = True
+
if full_name and full_name.startswith("tf."):
# Call special handlers
function_handles = self._api_change_spec.function_handle
@@ -356,27 +426,60 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
# Examine any non-keyword argument and make it into a keyword argument
# if reordering required.
function_reorders = self._api_change_spec.function_reorders
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
+
if full_name in function_reorders:
reordered = function_reorders[full_name]
for idx, arg in enumerate(node.args):
- self._file_edit.add("Added keyword %r to reordered function %r"
- % (reordered[idx], full_name), arg.lineno,
- arg.col_offset, "", reordered[idx] + "=")
+ lineno, col_offset = self._find_true_position(arg)
+ if lineno is None or col_offset is None:
+ self._file_edit.add(
+ "Failed to add keyword %r to reordered function %r"
+ % (reordered[idx], full_name), arg.lineno, arg.col_offset,
+ "", "",
+ error="A necessary keyword argument failed to be inserted.")
+ else:
+ keyword_arg = reordered[idx]
+ if (full_name in function_keyword_renames and
+ keyword_arg in function_keyword_renames[full_name]):
+ keyword_arg = function_keyword_renames[full_name][keyword_arg]
+ self._file_edit.add("Added keyword %r to reordered function %r"
+ % (reordered[idx], full_name), lineno,
+ col_offset, "", keyword_arg + "=")
# Examine each keyword argument and convert it to the final renamed form
- function_keyword_renames = (
- self._api_change_spec.function_keyword_renames)
renamed_keywords = ({} if full_name not in function_keyword_renames else
function_keyword_renames[full_name])
for keyword in node.keywords:
argkey = keyword.arg
argval = keyword.value
+
if argkey in renamed_keywords:
- self._file_edit.add("Renamed keyword argument from %r to %r" %
+ argval_lineno, argval_col_offset = self._find_true_position(argval)
+ if (argval_lineno is not None and argval_col_offset is not None):
+ # TODO(aselle): We should scan backward to find the start of the
+ # keyword key. Unfortunately ast does not give you the location of
+ # keyword keys, so we are forced to infer it from the keyword arg
+ # value.
+ key_start = argval_col_offset - len(argkey) - 1
+ key_end = key_start + len(argkey) + 1
+ if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
(argkey, renamed_keywords[argkey]),
- argval.lineno,
- argval.col_offset - len(argkey) - 1,
+ argval_lineno,
+ argval_col_offset - len(argkey) - 1,
argkey + "=", renamed_keywords[argkey] + "=")
+ continue
+ self._file_edit.add(
+ "Failed to rename keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ "", "",
+ error="Failed to find keyword lexographically. Fix manually.")
+
+ ast.NodeVisitor.generic_visit(self, node)
def visit_Attribute(self, node): # pylint: disable=invalid-name
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
@@ -387,6 +490,11 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
full_name = self._get_attribute_full_path(node)
if full_name and full_name.startswith("tf."):
self._rename_functions(node, full_name)
+ if full_name in self._api_change_spec.change_to_function:
+ if not hasattr(node, "is_function_for_call"):
+ new_text = full_name + "()"
+ self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+ node.lineno, node.col_offset, full_name, new_text)
ast.NodeVisitor.generic_visit(self, node)
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index 286c70f612..de4e3de73c 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")
+ def testRenamePack(self):
+ text = "tf.pack(a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.stack(a)\n")
+ text = "tf.unpack(a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.unstack(a)\n")
+
def testReorder(self):
text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
- self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
"tf.split(axis=a, num_or_size_splits=b, value=c)\n")
+ def testConcatReorderWithKeywordArgs(self):
+ text = "tf.concat(concat_dim=a, values=b)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+ text = "tf.concat(values=b, concat_dim=a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
+ text = "tf.concat(a, values=b)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+
+ def testConcatReorderNested(self):
+ text = "tf.concat(a, tf.concat(c, d))\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(
+ new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")
+
+ def testInitializers(self):
+ text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
+ "tf.ones_initializer;tf.ones_initializer ()\n")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(
+ new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
+ "tf.ones_initializer();tf.ones_initializer ()\n")
+
def testKeyword(self):
text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
@@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
self.assertEqual(new_text, new_text)
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
+ def testListComprehension(self):
+ def _test(input, output):
+ _, unused_report, errors, new_text = self._upgrade(input)
+ self.assertEqual(new_text, output)
+ _test("tf.concat(0, \t[x for x in y])\n",
+ "tf.concat(axis=0, \tvalues=[x for x in y])\n")
+ _test("tf.concat(0,[x for x in y])\n",
+ "tf.concat(axis=0,values=[x for x in y])\n")
+ _test("tf.concat(0,[\nx for x in y])\n",
+ "tf.concat(axis=0,values=[\nx for x in y])\n")
+ _test("tf.concat(0,[\n \tx for x in y])\n",
+ "tf.concat(axis=0,values=[\n \tx for x in y])\n")
+
# TODO(aselle): Explicitly not testing command line interface and process_tree
# for now, since this is a one off utility.