aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility/tf_upgrade_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/compatibility/tf_upgrade_test.py')
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_test.py48
1 files changed, 47 insertions, 1 deletions
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index 286c70f612..de4e3de73c 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")
+ def testRenamePack(self):
+ text = "tf.pack(a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.stack(a)\n")
+ text = "tf.unpack(a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.unstack(a)\n")
+
def testReorder(self):
text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
- self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n"
"tf.split(axis=a, num_or_size_splits=b, value=c)\n")
+ def testConcatReorderWithKeywordArgs(self):
+ text = "tf.concat(concat_dim=a, values=b)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+ text = "tf.concat(values=b, concat_dim=a)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n")
+ text = "tf.concat(a, values=b)\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n")
+
+ def testConcatReorderNested(self):
+ text = "tf.concat(a, tf.concat(c, d))\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(
+ new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n")
+
+ def testInitializers(self):
+ text = ("tf.zeros_initializer;tf.zeros_initializer ()\n"
+ "tf.ones_initializer;tf.ones_initializer ()\n")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(
+ new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n"
+ "tf.ones_initializer();tf.ones_initializer ()\n")
+
def testKeyword(self):
text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
@@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
self.assertEqual(new_text, new_text)
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
+ def testListComprehension(self):
+ def _test(input, output):
+ _, unused_report, errors, new_text = self._upgrade(input)
+ self.assertEqual(new_text, output)
+ _test("tf.concat(0, \t[x for x in y])\n",
+ "tf.concat(axis=0, \tvalues=[x for x in y])\n")
+ _test("tf.concat(0,[x for x in y])\n",
+ "tf.concat(axis=0,values=[x for x in y])\n")
+ _test("tf.concat(0,[\nx for x in y])\n",
+ "tf.concat(axis=0,values=[\nx for x in y])\n")
+ _test("tf.concat(0,[\n \tx for x in y])\n",
+ "tf.concat(axis=0,values=[\n \tx for x in y])\n")
+
# TODO(aselle): Explicitly not testing command line interface and process_tree
# for now, since this is a one off utility.