diff options
author | 2017-01-12 15:49:07 -0800 | |
---|---|---|
committer | 2017-01-12 16:04:50 -0800 | |
commit | e2fa82466e55cd52f24f7daee1bd1d45a58b5282 (patch) | |
tree | 3bce9226edac46f22284b5ca1c0e14e4a29bfe7f /tensorflow/tools/compatibility | |
parent | 3b5a37e14c49dd1a0128eb5b02a423b9a7956321 (diff) |
Improve TensorFlow upgrade script
- Handle more functions:
tf.svd
tf.batch_matmul
tf.nn.softmax_cross_entropy_with_logits
tf.nn.sparse_softmax_cross_entropy_with_logits
tf.nn.sigmoid_cross_entropy_with_logits
- Handle in-place file modification correctly (and add test).
- Handle raw attribute lookups, i.e. lists of functions:
`foo = [tf.mul]` can be upgraded to `foo = [tf.multiply]`
Change: 144381716
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r-- | tensorflow/tools/compatibility/testdata/test_file_v0_11.py | 27 | ||||
-rw-r--r-- | tensorflow/tools/compatibility/tf_upgrade.py | 175 | ||||
-rw-r--r-- | tensorflow/tools/compatibility/tf_upgrade_test.py | 18 |
3 files changed, 144 insertions, 76 deletions
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py index 37d914c648..01f37d8768 100644 --- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py +++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py @@ -163,6 +163,33 @@ class TestUpgrade(test_util.TensorFlowTestCase): # # TODO(aselle): (tf.batch_*) # ] + def testBatchAndSvd(self): + with self.test_session(): + mat = [[1., 2.], [2., 3.]] + batched_mat = tf.expand_dims(mat, [0]) + result = tf.matmul(mat, mat).eval() + result_batched = tf.batch_matmul(batched_mat, batched_mat).eval() + self.assertAllEqual(result_batched, np.expand_dims(result, 0)) + self.assertAllEqual( + tf.svd(mat, False, True).eval(), + tf.svd(mat, compute_uv=False, full_matrices=True).eval()) + + def testCrossEntropy(self): + # TODO(aselle): Test sparse_softmax_... + with self.test_session(): + labels = [.8, .5, .2, .1] + logits = [.9, .1, .3, .1] + self.assertAllEqual( + tf.nn.softmax_cross_entropy_with_logits( + logits, labels).eval(), + tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits).eval()) + self.assertAllEqual( + tf.nn.sigmoid_cross_entropy_with_logits( + logits, labels).eval(), + tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits).eval()) + def testVariables(self): with self.test_session() as s: diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 223f8cd5f5..374be0475a 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -21,14 +21,11 @@ import argparse import ast import collections import os +import shutil import sys +import tempfile import traceback -# TODO(aselle): Add SVD, Concat -# TODO(aselle): summary merge all (can we detect this?) -# TODO(aselle): batch_matmul -# TODO(wicke): tf.nn.{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits? 
- class APIChangeSpec(object): """List of maps that describe what changed in the API.""" @@ -143,7 +140,8 @@ class APIChangeSpec(object): "tf.batch_fft3d": "tf.fft3d", "tf.batch_ifft3d": "tf.ifft3d", "tf.select": "tf.where", - "tf.complex_abs": "tf.abs" + "tf.complex_abs": "tf.abs", + "tf.batch_matmul": "tf.matmul", } # Functions that were reordered should be changed to the new keyword args @@ -151,7 +149,14 @@ class APIChangeSpec(object): # positional arguments yourself, this could do the wrong thing. self.function_reorders = { "tf.split": ["axis", "num_or_size_splits", "value", "name"], - "tf.concat": ["concat_dim", "values", "name"] + "tf.concat": ["concat_dim", "values", "name"], + "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], + "tf.nn.softmax_cross_entropy_with_logits": [ + "logits", "labels", "dim", "name"], + "tf.nn.sparse_softmax_cross_entropy_with_logits": [ + "logits", "labels", "name"], + "tf.nn.sigmoid_cross_entropy_with_logits": [ + "logits", "labels", "name"] } # Specially handled functions. 
@@ -223,7 +228,7 @@ class FileEditRecorder(object): char_array = list(text[line - 1]) # Record a description of the change - change_report += "%s Line %d\n" % (self._filename, line) + change_report += "%r Line %d\n" % (self._filename, line) change_report += "-" * 80 + "\n\n" for e in edits: change_report += "%s\n" % e.comment @@ -243,7 +248,7 @@ class FileEditRecorder(object): # Make sure the edit is changing what it should be changing old_actual = "".join(char_array[start_eff:end_eff]) if old_actual != e.old: - raise ValueError("Expected text '%s' but got '%s'" % + raise ValueError("Expected text %r but got %r" % ("".join(e.old), "".join(old_actual))) # Make the edit char_array[start_eff:end_eff] = list(e.new) @@ -278,7 +283,7 @@ class FileEditRecorder(object): self._line_to_edit[line].append( FileEditTuple(comment, line, start, old, new)) - if error is not None: + if error: self._errors.append("%s:%d: %s" % (self._filename, line, error)) @@ -302,11 +307,33 @@ class TensorFlowCallVisitor(ast.NodeVisitor): def _rename_functions(self, node, full_name): function_renames = self._api_change_spec.function_renames - if full_name in function_renames: + try: new_name = function_renames[full_name] - self._file_edit.add("Renamed function `%s` to `%s`" % (full_name, - new_name), + self._file_edit.add("Renamed function %r to %r" % (full_name, + new_name), node.lineno, node.col_offset, full_name, new_name) + except KeyError: + pass + + def _get_attribute_full_path(self, node): + """Traverse an attribute to generate a full name e.g. tf.foo.bar. + + Args: + node: A Node of type Attribute. + + Returns: + a '.'-delimited full-name or None if the tree was not a simple form. + i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". 
+ """ + curr = node + items = [] + while not isinstance(curr, ast.Name): + if not isinstance(curr, ast.Attribute): + return None + items.append(curr.attr) + curr = curr.value + items.append(curr.id) + return ".".join(reversed(items)) def visit_Call(self, node): # pylint: disable=invalid-name """Handle visiting a call node in the AST. @@ -315,59 +342,51 @@ class TensorFlowCallVisitor(ast.NodeVisitor): node: Current Node """ - # Find call string (this is not perfectly accurate, - # but should cover tf.x*) - curr = node.func - items = [] - valid = True - while not isinstance(curr, ast.Name): - if isinstance(curr, ast.Attribute): - items.append(curr.attr) - else: - # We cannot just return, because we need to keep walking. - # TODO(aselle): Would it be cleaner to use an exception here with else? - valid = False - break - curr = curr.value - if valid: - items.append(curr.id) - - if valid: - # Conversion logic - full_name = ".".join(items[::-1]) - if full_name.startswith("tf."): - # Call special handlers - function_handles = self._api_change_spec.function_handle - if full_name in function_handles: - function_handles[full_name](self._file_edit, node) - - # Check for renames - self._rename_functions(node, full_name) - - # Examine any non-keyword argument and make it into a keyword argument - # if reordering required. 
- function_reorders = self._api_change_spec.function_reorders - if full_name in function_reorders: - reordered = function_reorders[full_name] - for idx, arg in enumerate(node.args): - self._file_edit.add("Added keyword `%s` to reordered function `%s`" - % (reordered[idx], full_name), arg.lineno, - arg.col_offset, "", reordered[idx] + "=") - - # Examine each keyword argument and convert it to the final renamed form - function_keyword_renames = ( - self._api_change_spec.function_keyword_renames) - renamed_keywords = ({} if full_name not in function_keyword_renames else - function_keyword_renames[full_name]) - for keyword in node.keywords: - argkey = keyword.arg - argval = keyword.value - if argkey in renamed_keywords: - self._file_edit.add("Renamed keyword argument from `%s` to `%s`" % - (argkey, renamed_keywords[argkey]), - argval.lineno, - argval.col_offset - len(argkey) - 1, - argkey + "=", renamed_keywords[argkey] + "=") + ast.NodeVisitor.generic_visit(self, node) + + # Find a simple attribute name path e.g. "tf.foo.bar" + full_name = self._get_attribute_full_path(node.func) + + if full_name and full_name.startswith("tf."): + # Call special handlers + function_handles = self._api_change_spec.function_handle + if full_name in function_handles: + function_handles[full_name](self._file_edit, node) + + # Examine any non-keyword argument and make it into a keyword argument + # if reordering required. 
+ function_reorders = self._api_change_spec.function_reorders + if full_name in function_reorders: + reordered = function_reorders[full_name] + for idx, arg in enumerate(node.args): + self._file_edit.add("Added keyword %r to reordered function %r" + % (reordered[idx], full_name), arg.lineno, + arg.col_offset, "", reordered[idx] + "=") + + # Examine each keyword argument and convert it to the final renamed form + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + renamed_keywords = ({} if full_name not in function_keyword_renames else + function_keyword_renames[full_name]) + for keyword in node.keywords: + argkey = keyword.arg + argval = keyword.value + if argkey in renamed_keywords: + self._file_edit.add("Renamed keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + argkey + "=", renamed_keywords[argkey] + "=") + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar]. + + Args: + node: Node that is of type ast.Attribute + """ + full_name = self._get_attribute_full_path(node) + if full_name and full_name.startswith("tf."): + self._rename_functions(node, full_name) ast.NodeVisitor.generic_visit(self, node) @@ -387,11 +406,15 @@ class TensorFlowCodeUpgrader(object): Returns: A tuple representing number of files processed, log of actions, errors """ - in_file = open(in_filename, "r") - out_file = open(out_filename, "w") if out_filename else None - return self.process_opened_file( - in_filename, in_file, out_filename, out_file) + # Write to a temporary file, just in case we are doing an implace modify. 
+ with open(in_filename, "r") as in_file, \ + tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + ret = self.process_opened_file( + in_filename, in_file, out_filename, temp_file) + + shutil.move(temp_file.name, out_filename) + return ret # Broad exceptions are required here because ast throws whatever it wants. # pylint: disable=broad-except @@ -411,7 +434,7 @@ class TensorFlowCodeUpgrader(object): """ process_errors = [] text = "-" * 80 + "\n" - text += "Processing file %s\n outputting to %s\n" % (in_filename, + text += "Processing file %r\n outputting to %r\n" % (in_filename, out_filename) text += "-" * 80 + "\n\n" @@ -420,7 +443,7 @@ class TensorFlowCodeUpgrader(object): try: parsed_ast = ast.parse("".join(lines)) except Exception: - text += "Failed to parse %s\n\n" % in_filename + text += "Failed to parse %r\n\n" % in_filename text += traceback.format_exc() if parsed_ast: visitor = TensorFlowCallVisitor(in_filename, lines) @@ -448,7 +471,7 @@ class TensorFlowCodeUpgrader(object): # make sure output directory doesn't exist if output_root_directory and os.path.exists(output_root_directory): - print("Output directory '%s' must not already exist." % ( + print("Output directory %r must not already exist." 
% ( output_root_directory)) sys.exit(1) @@ -456,7 +479,7 @@ class TensorFlowCodeUpgrader(object): norm_root = os.path.split(os.path.normpath(root_directory)) norm_output = os.path.split(os.path.normpath(output_root_directory)) if norm_root == norm_output: - print("Output directory '%s' same as input directory '%s"'' % ( + print("Output directory %r same as input directory %r" % ( root_directory, output_root_directory)) sys.exit(1) @@ -475,7 +498,7 @@ class TensorFlowCodeUpgrader(object): tree_errors = [] report = "" report += ("=" * 80) + "\n" - report += "Input tree: %s\n" % root_directory + report += "Input tree: %r\n" % root_directory report += ("=" * 80) + "\n" for input_path, output_path in files_to_process: @@ -547,4 +570,4 @@ Simple usage: print("Detected %d errors that require attention" % len(errors)) print("-" * 80) print("\n".join(errors)) - print("\nMake sure to read the detailed log %s\n" % report_filename) + print("\nMake sure to read the detailed log %r\n" % report_filename) diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index 7548b38b91..286c70f612 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile import six from tensorflow.python.framework import test_util from tensorflow.python.platform import test as test_lib @@ -81,5 +83,21 @@ class TestUpgrade(test_util.TensorFlowTestCase): # TODO(aselle): Explicitly not testing command line interface and process_tree # for now, since this is a one off utility. 
+ +class TestUpgradeFiles(test_util.TensorFlowTestCase): + + def testInplace(self): + """Check to make sure we don't have a file system race.""" + temp_file = tempfile.NamedTemporaryFile("w", delete=False) + original = "tf.mul(a, b)\n" + upgraded = "tf.multiply(a, b)\n" + temp_file.write(original) + temp_file.close() + upgrader = tf_upgrade.TensorFlowCodeUpgrader() + upgrader.process_file(temp_file.name, temp_file.name) + self.assertAllEqual(open(temp_file.name).read(), upgraded) + os.unlink(temp_file.name) + + if __name__ == "__main__": test_lib.main() |