aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility/tf_upgrade_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/compatibility/tf_upgrade_test.py')
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_test.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index 3d02eacba6..66325ea2ad 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -22,6 +22,7 @@ import tempfile
import six
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
+from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade
@@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def _upgrade(self, old_file_text):
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
- upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
+ upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
@@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase):
upgraded = "tf.multiply(a, b)\n"
temp_file.write(original)
temp_file.close()
- upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
+ upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
upgrader.process_file(temp_file.name, temp_file.name)
self.assertAllEqual(open(temp_file.name).read(), upgraded)
os.unlink(temp_file.name)