diff options
Diffstat (limited to 'tensorflow/tools/compatibility/tf_upgrade_test.py')
-rw-r--r-- | tensorflow/tools/compatibility/tf_upgrade_test.py | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index 286c70f612..de4e3de73c 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -59,12 +59,45 @@ class TestUpgrade(test_util.TensorFlowTestCase): _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n") + def testRenamePack(self): + text = "tf.pack(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.stack(a)\n") + text = "tf.unpack(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.unstack(a)\n") + def testReorder(self): text = "tf.concat(a, b)\ntf.split(a, b, c)\n" _, unused_report, unused_errors, new_text = self._upgrade(text) - self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n" + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n" "tf.split(axis=a, num_or_size_splits=b, value=c)\n") + def testConcatReorderWithKeywordArgs(self): + text = "tf.concat(concat_dim=a, values=b)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") + text = "tf.concat(values=b, concat_dim=a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n") + text = "tf.concat(a, values=b)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") + + def testConcatReorderNested(self): + text = "tf.concat(a, tf.concat(c, d))\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual( + new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n") + + def testInitializers(self): + text = ("tf.zeros_initializer;tf.zeros_initializer ()\n" + "tf.ones_initializer;tf.ones_initializer ()\n") + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual( + new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n" + "tf.ones_initializer();tf.ones_initializer ()\n") + def testKeyword(self): text = "tf.reduce_any(a, reduction_indices=[1, 2])\n" _, unused_report, unused_errors, new_text = self._upgrade(text) @@ -80,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase): self.assertEqual(new_text, new_text) self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."]) + def testListComprehension(self): + def _test(input, output): + _, unused_report, errors, new_text = self._upgrade(input) + self.assertEqual(new_text, output) + _test("tf.concat(0, \t[x for x in y])\n", + "tf.concat(axis=0, \tvalues=[x for x in y])\n") + _test("tf.concat(0,[x for x in y])\n", + "tf.concat(axis=0,values=[x for x in y])\n") + _test("tf.concat(0,[\nx for x in y])\n", + "tf.concat(axis=0,values=[\nx for x in y])\n") + _test("tf.concat(0,[\n \tx for x in y])\n", + "tf.concat(axis=0,values=[\n \tx for x in y])\n") + # TODO(aselle): Explicitly not testing command line interface and process_tree # for now, since this is a one off utility. |