aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2017-01-12 15:49:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-12 16:04:50 -0800
commite2fa82466e55cd52f24f7daee1bd1d45a58b5282 (patch)
tree3bce9226edac46f22284b5ca1c0e14e4a29bfe7f /tensorflow/tools/compatibility
parent3b5a37e14c49dd1a0128eb5b02a423b9a7956321 (diff)
Improve TensorFlow upgrade script
- Handle more functions: tf.svd tf.batch_matmul tf.nn.softmax_cross_entropy_with_logits, tf.nn.sparse_softmax_cross_entropy_with_logits, tf.nn.sigmoid_cross_entropy_with_logits": [ - Handle in-place file modification correctly (and add test). - Handle raw attribute lookups i.e. lists of functions `foo = [tf.mul]` can be upgraded to `foo = [tf.multiply]` Change: 144381716
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v0_11.py27
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py175
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_test.py18
3 files changed, 144 insertions, 76 deletions
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 37d914c648..01f37d8768 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -163,6 +163,33 @@ class TestUpgrade(test_util.TensorFlowTestCase):
# # TODO(aselle): (tf.batch_*)
# ]
+ def testBatchAndSvd(self):
+ with self.test_session():
+ mat = [[1., 2.], [2., 3.]]
+ batched_mat = tf.expand_dims(mat, [0])
+ result = tf.matmul(mat, mat).eval()
+ result_batched = tf.batch_matmul(batched_mat, batched_mat).eval()
+ self.assertAllEqual(result_batched, np.expand_dims(result, 0))
+ self.assertAllEqual(
+ tf.svd(mat, False, True).eval(),
+ tf.svd(mat, compute_uv=False, full_matrices=True).eval())
+
+ def testCrossEntropy(self):
+ # TODO(aselle): Test sparse_softmax_...
+ with self.test_session():
+ labels = [.8, .5, .2, .1]
+ logits = [.9, .1, .3, .1]
+ self.assertAllEqual(
+ tf.nn.softmax_cross_entropy_with_logits(
+ logits, labels).eval(),
+ tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits).eval())
+ self.assertAllEqual(
+ tf.nn.sigmoid_cross_entropy_with_logits(
+ logits, labels).eval(),
+ tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=labels, logits=logits).eval())
+
def testVariables(self):
with self.test_session() as s:
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index 223f8cd5f5..374be0475a 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -21,14 +21,11 @@ import argparse
import ast
import collections
import os
+import shutil
import sys
+import tempfile
import traceback
-# TODO(aselle): Add SVD, Concat
-# TODO(aselle): summary merge all (can we detect this?)
-# TODO(aselle): batch_matmul
-# TODO(wicke): tf.nn.{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits?
-
class APIChangeSpec(object):
"""List of maps that describe what changed in the API."""
@@ -143,7 +140,8 @@ class APIChangeSpec(object):
"tf.batch_fft3d": "tf.fft3d",
"tf.batch_ifft3d": "tf.ifft3d",
"tf.select": "tf.where",
- "tf.complex_abs": "tf.abs"
+ "tf.complex_abs": "tf.abs",
+ "tf.batch_matmul": "tf.matmul",
}
# Functions that were reordered should be changed to the new keyword args
@@ -151,7 +149,14 @@ class APIChangeSpec(object):
# positional arguments yourself, this could do the wrong thing.
self.function_reorders = {
"tf.split": ["axis", "num_or_size_splits", "value", "name"],
- "tf.concat": ["concat_dim", "values", "name"]
+ "tf.concat": ["concat_dim", "values", "name"],
+ "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
+ "tf.nn.softmax_cross_entropy_with_logits": [
+ "logits", "labels", "dim", "name"],
+ "tf.nn.sparse_softmax_cross_entropy_with_logits": [
+ "logits", "labels", "name"],
+ "tf.nn.sigmoid_cross_entropy_with_logits": [
+ "logits", "labels", "name"]
}
# Specially handled functions.
@@ -223,7 +228,7 @@ class FileEditRecorder(object):
char_array = list(text[line - 1])
# Record a description of the change
- change_report += "%s Line %d\n" % (self._filename, line)
+ change_report += "%r Line %d\n" % (self._filename, line)
change_report += "-" * 80 + "\n\n"
for e in edits:
change_report += "%s\n" % e.comment
@@ -243,7 +248,7 @@ class FileEditRecorder(object):
# Make sure the edit is changing what it should be changing
old_actual = "".join(char_array[start_eff:end_eff])
if old_actual != e.old:
- raise ValueError("Expected text '%s' but got '%s'" %
+ raise ValueError("Expected text %r but got %r" %
("".join(e.old), "".join(old_actual)))
# Make the edit
char_array[start_eff:end_eff] = list(e.new)
@@ -278,7 +283,7 @@ class FileEditRecorder(object):
self._line_to_edit[line].append(
FileEditTuple(comment, line, start, old, new))
- if error is not None:
+ if error:
self._errors.append("%s:%d: %s" % (self._filename, line, error))
@@ -302,11 +307,33 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
def _rename_functions(self, node, full_name):
function_renames = self._api_change_spec.function_renames
- if full_name in function_renames:
+ try:
new_name = function_renames[full_name]
- self._file_edit.add("Renamed function `%s` to `%s`" % (full_name,
- new_name),
+ self._file_edit.add("Renamed function %r to %r" % (full_name,
+ new_name),
node.lineno, node.col_offset, full_name, new_name)
+ except KeyError:
+ pass
+
+ def _get_attribute_full_path(self, node):
+ """Traverse an attribute to generate a full name e.g. tf.foo.bar.
+
+ Args:
+ node: A Node of type Attribute.
+
+ Returns:
+ a '.'-delimited full-name or None if the tree was not a simple form.
+ i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
+ """
+ curr = node
+ items = []
+ while not isinstance(curr, ast.Name):
+ if not isinstance(curr, ast.Attribute):
+ return None
+ items.append(curr.attr)
+ curr = curr.value
+ items.append(curr.id)
+ return ".".join(reversed(items))
def visit_Call(self, node): # pylint: disable=invalid-name
"""Handle visiting a call node in the AST.
@@ -315,59 +342,51 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
node: Current Node
"""
- # Find call string (this is not perfectly accurate,
- # but should cover tf.x*)
- curr = node.func
- items = []
- valid = True
- while not isinstance(curr, ast.Name):
- if isinstance(curr, ast.Attribute):
- items.append(curr.attr)
- else:
- # We cannot just return, because we need to keep walking.
- # TODO(aselle): Would it be cleaner to use an exception here with else?
- valid = False
- break
- curr = curr.value
- if valid:
- items.append(curr.id)
-
- if valid:
- # Conversion logic
- full_name = ".".join(items[::-1])
- if full_name.startswith("tf."):
- # Call special handlers
- function_handles = self._api_change_spec.function_handle
- if full_name in function_handles:
- function_handles[full_name](self._file_edit, node)
-
- # Check for renames
- self._rename_functions(node, full_name)
-
- # Examine any non-keyword argument and make it into a keyword argument
- # if reordering required.
- function_reorders = self._api_change_spec.function_reorders
- if full_name in function_reorders:
- reordered = function_reorders[full_name]
- for idx, arg in enumerate(node.args):
- self._file_edit.add("Added keyword `%s` to reordered function `%s`"
- % (reordered[idx], full_name), arg.lineno,
- arg.col_offset, "", reordered[idx] + "=")
-
- # Examine each keyword argument and convert it to the final renamed form
- function_keyword_renames = (
- self._api_change_spec.function_keyword_renames)
- renamed_keywords = ({} if full_name not in function_keyword_renames else
- function_keyword_renames[full_name])
- for keyword in node.keywords:
- argkey = keyword.arg
- argval = keyword.value
- if argkey in renamed_keywords:
- self._file_edit.add("Renamed keyword argument from `%s` to `%s`" %
- (argkey, renamed_keywords[argkey]),
- argval.lineno,
- argval.col_offset - len(argkey) - 1,
- argkey + "=", renamed_keywords[argkey] + "=")
+ ast.NodeVisitor.generic_visit(self, node)
+
+ # Find a simple attribute name path e.g. "tf.foo.bar"
+ full_name = self._get_attribute_full_path(node.func)
+
+ if full_name and full_name.startswith("tf."):
+ # Call special handlers
+ function_handles = self._api_change_spec.function_handle
+ if full_name in function_handles:
+ function_handles[full_name](self._file_edit, node)
+
+ # Examine any non-keyword argument and make it into a keyword argument
+ # if reordering required.
+ function_reorders = self._api_change_spec.function_reorders
+ if full_name in function_reorders:
+ reordered = function_reorders[full_name]
+ for idx, arg in enumerate(node.args):
+ self._file_edit.add("Added keyword %r to reordered function %r"
+ % (reordered[idx], full_name), arg.lineno,
+ arg.col_offset, "", reordered[idx] + "=")
+
+ # Examine each keyword argument and convert it to the final renamed form
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
+ renamed_keywords = ({} if full_name not in function_keyword_renames else
+ function_keyword_renames[full_name])
+ for keyword in node.keywords:
+ argkey = keyword.arg
+ argval = keyword.value
+ if argkey in renamed_keywords:
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ argkey + "=", renamed_keywords[argkey] + "=")
+
+ def visit_Attribute(self, node): # pylint: disable=invalid-name
+ """Handle bare Attributes i.e. [tf.foo, tf.bar].
+
+ Args:
+ node: Node that is of type ast.Attribute
+ """
+ full_name = self._get_attribute_full_path(node)
+ if full_name and full_name.startswith("tf."):
+ self._rename_functions(node, full_name)
ast.NodeVisitor.generic_visit(self, node)
@@ -387,11 +406,15 @@ class TensorFlowCodeUpgrader(object):
Returns:
A tuple representing number of files processed, log of actions, errors
"""
- in_file = open(in_filename, "r")
- out_file = open(out_filename, "w") if out_filename else None
- return self.process_opened_file(
- in_filename, in_file, out_filename, out_file)
+ # Write to a temporary file, just in case we are doing an implace modify.
+ with open(in_filename, "r") as in_file, \
+ tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
+ ret = self.process_opened_file(
+ in_filename, in_file, out_filename, temp_file)
+
+ shutil.move(temp_file.name, out_filename)
+ return ret
# Broad exceptions are required here because ast throws whatever it wants.
# pylint: disable=broad-except
@@ -411,7 +434,7 @@ class TensorFlowCodeUpgrader(object):
"""
process_errors = []
text = "-" * 80 + "\n"
- text += "Processing file %s\n outputting to %s\n" % (in_filename,
+ text += "Processing file %r\n outputting to %r\n" % (in_filename,
out_filename)
text += "-" * 80 + "\n\n"
@@ -420,7 +443,7 @@ class TensorFlowCodeUpgrader(object):
try:
parsed_ast = ast.parse("".join(lines))
except Exception:
- text += "Failed to parse %s\n\n" % in_filename
+ text += "Failed to parse %r\n\n" % in_filename
text += traceback.format_exc()
if parsed_ast:
visitor = TensorFlowCallVisitor(in_filename, lines)
@@ -448,7 +471,7 @@ class TensorFlowCodeUpgrader(object):
# make sure output directory doesn't exist
if output_root_directory and os.path.exists(output_root_directory):
- print("Output directory '%s' must not already exist." % (
+ print("Output directory %r must not already exist." % (
output_root_directory))
sys.exit(1)
@@ -456,7 +479,7 @@ class TensorFlowCodeUpgrader(object):
norm_root = os.path.split(os.path.normpath(root_directory))
norm_output = os.path.split(os.path.normpath(output_root_directory))
if norm_root == norm_output:
- print("Output directory '%s' same as input directory '%s"'' % (
+ print("Output directory %r same as input directory %r" % (
root_directory, output_root_directory))
sys.exit(1)
@@ -475,7 +498,7 @@ class TensorFlowCodeUpgrader(object):
tree_errors = []
report = ""
report += ("=" * 80) + "\n"
- report += "Input tree: %s\n" % root_directory
+ report += "Input tree: %r\n" % root_directory
report += ("=" * 80) + "\n"
for input_path, output_path in files_to_process:
@@ -547,4 +570,4 @@ Simple usage:
print("Detected %d errors that require attention" % len(errors))
print("-" * 80)
print("\n".join(errors))
- print("\nMake sure to read the detailed log %s\n" % report_filename)
+ print("\nMake sure to read the detailed log %r\n" % report_filename)
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index 7548b38b91..286c70f612 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -17,6 +17,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import tempfile
import six
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
@@ -81,5 +83,21 @@ class TestUpgrade(test_util.TensorFlowTestCase):
# TODO(aselle): Explicitly not testing command line interface and process_tree
# for now, since this is a one off utility.
+
+class TestUpgradeFiles(test_util.TensorFlowTestCase):
+
+ def testInplace(self):
+ """Check to make sure we don't have a file system race."""
+ temp_file = tempfile.NamedTemporaryFile("w", delete=False)
+ original = "tf.mul(a, b)\n"
+ upgraded = "tf.multiply(a, b)\n"
+ temp_file.write(original)
+ temp_file.close()
+ upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+ upgrader.process_file(temp_file.name, temp_file.name)
+ self.assertAllEqual(open(temp_file.name).read(), upgraded)
+ os.unlink(temp_file.name)
+
+
if __name__ == "__main__":
test_lib.main()