diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-02-08 09:25:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-08 09:50:05 -0800 |
commit | 639b4e71f532761a4840b1cdbaea55ad0917c75b (patch) | |
tree | 5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/tools/compatibility | |
parent | 15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff) |
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r-- | tensorflow/tools/compatibility/README.md | 15 | ||||
-rw-r--r-- | tensorflow/tools/compatibility/tf_upgrade.py | 128 | ||||
-rw-r--r-- | tensorflow/tools/compatibility/tf_upgrade_test.py | 48 |
3 files changed, 177 insertions, 14 deletions
diff --git a/tensorflow/tools/compatibility/README.md b/tensorflow/tools/compatibility/README.md index 3b66e73f9a..77e27531a9 100644 --- a/tensorflow/tools/compatibility/README.md +++ b/tensorflow/tools/compatibility/README.md @@ -36,6 +36,9 @@ particular, functions that have had reordered arguments like `tf.concat`, `tf.split` will cause the script to incorrectly add keyword arguments that mismap arguments. +- This script does not actually reorder arguments. Instead, it adds +keyword arguments to calls of functions whose arguments were reordered. + - This script is not able to upgrade all functions. One notable example is `tf.reverse()` which has been changed to take a list of indices rather than a tensor of bools. If the script detects this, it will report this to stdout @@ -43,6 +46,12 @@ a tensor of bools. If the script detects this, it will report this to stdout `tf.reverse(a, [False, True, True])` you will need to manually change it to `tf.reverse(a, [1, 2])`. - - - +- Some syntax cannot be handled by this script, because the +script was designed to use only standard Python packages. If the script fails +with "A necessary keyword argument failed to be inserted." or +"Failed to find keyword lexicographically. Fix manually.", you can try +[@machrisaa's fork of this script](https://github.com/machrisaa/tf0to1). +[@machrisaa](https://github.com/machrisaa) has used the +[RedBaron Python refactoring engine](https://redbaron.readthedocs.io/en/latest/) +which is able to locate syntactic elements more reliably than the built-in +`ast` module this script is based upon. 
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 374be0475a..bcff10f21d 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -95,11 +95,15 @@ class APIChangeSpec(object): "tf.split": { "split_dim": "axis", "num_split": "num_or_size_splits" - } + }, + "tf.concat": { + "concat_dim": "axis" + }, } # Mapping from function to the new name of the function self.function_renames = { + "tf.inv": "tf.reciprocal", "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar", "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram", "tf.listdiff": "tf.setdiff1d", @@ -142,6 +146,13 @@ class APIChangeSpec(object): "tf.select": "tf.where", "tf.complex_abs": "tf.abs", "tf.batch_matmul": "tf.matmul", + "tf.pack": "tf.stack", + "tf.unpack": "tf.unstack", + } + + self.change_to_function = { + "tf.ones_initializer", + "tf.zeros_initializer", } # Functions that were reordered should be changed to the new keyword args @@ -149,6 +160,7 @@ class APIChangeSpec(object): # positional arguments yourself, this could do the wrong thing. self.function_reorders = { "tf.split": ["axis", "num_or_size_splits", "value", "name"], + "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"], "tf.concat": ["concat_dim", "values", "name"], "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], "tf.nn.softmax_cross_entropy_with_logits": [ @@ -335,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor): items.append(curr.id) return ".".join(reversed(items)) + def _find_true_position(self, node): + """Return correct line number and column offset for a given node. + + This is necessary mainly because ListComp's location reporting reports + the next token after the list comprehension list opening. 
+ + Args: + node: Node for which we wish to know the lineno and col_offset + """ + import re + find_open = re.compile("^\s*(\\[).*$") + find_string_chars = re.compile("['\"]") + + if isinstance(node, ast.ListComp): + # Strangely, ast.ListComp returns the col_offset of the first token + # after the '[' token which appears to be a bug. Workaround by + # explicitly finding the real start of the list comprehension. + line = node.lineno + col = node.col_offset + # loop over lines + while 1: + # Reverse the text to and regular expression search for whitespace + text = self._lines[line-1] + reversed_preceding_text = text[:col][::-1] + # First find if a [ can be found with only whitespace between it and + # col. + m = find_open.match(reversed_preceding_text) + if m: + new_col_offset = col - m.start(1) - 1 + return line, new_col_offset + else: + if (reversed_preceding_text=="" or + reversed_preceding_text.isspace()): + line = line - 1 + prev_line = self._lines[line - 1] + # TODO(aselle): + # this is poor comment detection, but it is good enough for + # cases where the comment does not contain string literal starting/ + # ending characters. If ast gave us start and end locations of the + # ast nodes rather than just start, we could use string literal + # node ranges to filter out spurious #'s that appear in string + # literals. + comment_start = prev_line.find("#") + if comment_start == -1: + col = len(prev_line) -1 + elif find_string_chars.search(prev_line[comment_start:]) is None: + col = comment_start + else: + return None, None + else: + return None, None + # Most other nodes return proper locations (with notably does not), but + # it is not possible to use that in an argument. + return node.lineno, node.col_offset + + def visit_Call(self, node): # pylint: disable=invalid-name """Handle visiting a call node in the AST. 
@@ -342,11 +410,13 @@ class TensorFlowCallVisitor(ast.NodeVisitor): node: Current Node """ - ast.NodeVisitor.generic_visit(self, node) # Find a simple attribute name path e.g. "tf.foo.bar" full_name = self._get_attribute_full_path(node.func) + # Make sure the func is marked as being part of a call + node.func.is_function_for_call = True + if full_name and full_name.startswith("tf."): # Call special handlers function_handles = self._api_change_spec.function_handle @@ -356,27 +426,60 @@ class TensorFlowCallVisitor(ast.NodeVisitor): # Examine any non-keyword argument and make it into a keyword argument # if reordering required. function_reorders = self._api_change_spec.function_reorders + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + if full_name in function_reorders: reordered = function_reorders[full_name] for idx, arg in enumerate(node.args): - self._file_edit.add("Added keyword %r to reordered function %r" - % (reordered[idx], full_name), arg.lineno, - arg.col_offset, "", reordered[idx] + "=") + lineno, col_offset = self._find_true_position(arg) + if lineno is None or col_offset is None: + self._file_edit.add( + "Failed to add keyword %r to reordered function %r" + % (reordered[idx], full_name), arg.lineno, arg.col_offset, + "", "", + error="A necessary keyword argument failed to be inserted.") + else: + keyword_arg = reordered[idx] + if (full_name in function_keyword_renames and + keyword_arg in function_keyword_renames[full_name]): + keyword_arg = function_keyword_renames[full_name][keyword_arg] + self._file_edit.add("Added keyword %r to reordered function %r" + % (reordered[idx], full_name), lineno, + col_offset, "", keyword_arg + "=") # Examine each keyword argument and convert it to the final renamed form - function_keyword_renames = ( - self._api_change_spec.function_keyword_renames) renamed_keywords = ({} if full_name not in function_keyword_renames else function_keyword_renames[full_name]) for keyword in node.keywords: 
argkey = keyword.arg argval = keyword.value + if argkey in renamed_keywords: - self._file_edit.add("Renamed keyword argument from %r to %r" % + argval_lineno, argval_col_offset = self._find_true_position(argval) + if (argval_lineno is not None and argval_col_offset is not None): + # TODO(aselle): We should scan backward to find the start of the + # keyword key. Unfortunately ast does not give you the location of + # keyword keys, so we are forced to infer it from the keyword arg + # value. + key_start = argval_col_offset - len(argkey) - 1 + key_end = key_start + len(argkey) + 1 + if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=": + self._file_edit.add("Renamed keyword argument from %r to %r" % (argkey, renamed_keywords[argkey]), - argval.lineno, - argval.col_offset - len(argkey) - 1, + argval_lineno, + argval_col_offset - len(argkey) - 1, argkey + "=", renamed_keywords[argkey] + "=") + continue + self._file_edit.add( + "Failed to rename keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + "", "", + error="Failed to find keyword lexographically. Fix manually.") + + ast.NodeVisitor.generic_visit(self, node) def visit_Attribute(self, node): # pylint: disable=invalid-name """Handle bare Attributes i.e. [tf.foo, tf.bar]. 
@@ -387,6 +490,11 @@ class TensorFlowCallVisitor(ast.NodeVisitor): full_name = self._get_attribute_full_path(node) if full_name and full_name.startswith("tf."): self._rename_functions(node, full_name) + if full_name in self._api_change_spec.change_to_function: + if not hasattr(node, "is_function_for_call"): + new_text = full_name + "()" + self._file_edit.add("Changed %r to %r"%(full_name, new_text), + node.lineno, node.col_offset, full_name, new_text) ast.NodeVisitor.generic_visit(self, node) diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index 286c70f612..de4e3de73c 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase): _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n") + def testRenamePack(self): + text = "tf.pack(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.stack(a)\n") + text = "tf.unpack(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.unstack(a)\n") + def testReorder(self): text = "tf.concat(a, b)\ntf.split(a, b, c)\n" _, unused_report, unused_errors, new_text = self._upgrade(text) - self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n" + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n" "tf.split(axis=a, num_or_size_splits=b, value=c)\n") + def testConcatReorderWithKeywordArgs(self): + text = "tf.concat(concat_dim=a, values=b)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") + text = "tf.concat(values=b, concat_dim=a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n") + text = "tf.concat(a, 
values=b)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") + + def testConcatReorderNested(self): + text = "tf.concat(a, tf.concat(c, d))\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual( + new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n") + + def testInitializers(self): + text = ("tf.zeros_initializer;tf.zeros_initializer ()\n" + "tf.ones_initializer;tf.ones_initializer ()\n") + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual( + new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n" + "tf.ones_initializer();tf.ones_initializer ()\n") + def testKeyword(self): text = "tf.reduce_any(a, reduction_indices=[1, 2])\n" _, unused_report, unused_errors, new_text = self._upgrade(text) @@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase): self.assertEqual(new_text, new_text) self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."]) + def testListComprehension(self): + def _test(input, output): + _, unused_report, errors, new_text = self._upgrade(input) + self.assertEqual(new_text, output) + _test("tf.concat(0, \t[x for x in y])\n", + "tf.concat(axis=0, \tvalues=[x for x in y])\n") + _test("tf.concat(0,[x for x in y])\n", + "tf.concat(axis=0,values=[x for x in y])\n") + _test("tf.concat(0,[\nx for x in y])\n", + "tf.concat(axis=0,values=[\nx for x in y])\n") + _test("tf.concat(0,[\n \tx for x in y])\n", + "tf.concat(axis=0,values=[\n \tx for x in y])\n") + # TODO(aselle): Explicitly not testing command line interface and process_tree # for now, since this is a one off utility. |