diff options
Diffstat (limited to 'tensorflow/python/platform/flags_test.py')
-rw-r--r-- | tensorflow/python/platform/flags_test.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py index 8b990975dd..7b08c3f8a6 100644 --- a/tensorflow/python/platform/flags_test.py +++ b/tensorflow/python/platform/flags_test.py @@ -35,6 +35,8 @@ flags.DEFINE_boolean("bool_a", False, "HelpString") flags.DEFINE_boolean("bool_c", False, "HelpString") flags.DEFINE_boolean("bool_d", True, "HelpString") flags.DEFINE_bool("bool_e", True, "HelpString") +flags.DEFINE_string("string_foo_required", "default_val", "HelpString") +flags.DEFINE_string("none_string_foo_required", None, "HelpString") FLAGS = flags.FLAGS @@ -84,6 +86,18 @@ class FlagsTest(unittest.TestCase): copied = copy.copy(FLAGS) self.assertEqual(copied.__dict__, FLAGS.__dict__) + def testStringRequired(self): + res = FLAGS.string_foo_required + self.assertEqual(res, "default_val") + FLAGS.string_foo_required = "bar" + self.assertEqual("bar", FLAGS.string_foo_required) + + def testNoneStringRequired(self): + res = FLAGS.none_string_foo_required + self.assertEqual(res, "default_val") + FLAGS.none_string_foo_required = "bar" + self.assertEqual("bar", FLAGS.none_string_foo_required) + def main(_): # unittest.main() tries to interpret the unknown flags, so use the @@ -97,7 +111,9 @@ if __name__ == "__main__": # Test command lines sys.argv.extend([ "--bool_a", "--nobool_negation", "--bool_c=True", "--bool_d=False", + "--none_string_foo_required=default_val", "and_argument" ]) - + flags.mark_flag_as_required('string_foo_required') + flags.mark_flags_as_required(['none_string_foo_required']) app.run() |