diff options
-rw-r--r-- | tensorflow/python/platform/flags_test.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py index bc2d012d03..d2b7da7ad2 100644 --- a/tensorflow/python/platform/flags_test.py +++ b/tensorflow/python/platform/flags_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import sys import unittest @@ -81,7 +82,17 @@ class FlagsTest(unittest.TestCase): self.assertEqual(-1.0, FLAGS.float_foo) -def main(_): +def main(argv): + # Test that argparse can parse flags that aren't registered + # with tf.flags. + parser = argparse.ArgumentParser() + parser.add_argument("--argparse_val", type=int, default=1000, + help="Test flag") + argparse_flags, _ = parser.parse_known_args(argv) + if argparse_flags.argparse_val != 10: + raise ValueError("argparse flag was not parsed: got %d", + argparse_flags.argparse_val) + # unittest.main() tries to interpret the unknown flags, so use the # direct functions instead. runner = unittest.TextTestRunner() @@ -93,6 +104,7 @@ if __name__ == "__main__": # Test command lines sys.argv.extend(["--bool_a", "--nobool_negation", "--bool_c=True", "--bool_d=False", - "--unknown_flag", "and_argument"]) + "--unknown_flag", "--argparse_val=10", + "and_argument"]) app.run() |