aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform/flags_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/platform/flags_test.py')
-rw-r--r--tensorflow/python/platform/flags_test.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py
index 8b990975dd..7b08c3f8a6 100644
--- a/tensorflow/python/platform/flags_test.py
+++ b/tensorflow/python/platform/flags_test.py
@@ -35,6 +35,8 @@ flags.DEFINE_boolean("bool_a", False, "HelpString")
flags.DEFINE_boolean("bool_c", False, "HelpString")
flags.DEFINE_boolean("bool_d", True, "HelpString")
flags.DEFINE_bool("bool_e", True, "HelpString")
+flags.DEFINE_string("string_foo_required", "default_val", "HelpString")
+flags.DEFINE_string("none_string_foo_required", None, "HelpString")
FLAGS = flags.FLAGS
@@ -84,6 +86,18 @@ class FlagsTest(unittest.TestCase):
copied = copy.copy(FLAGS)
self.assertEqual(copied.__dict__, FLAGS.__dict__)
+ def testStringRequired(self):
+ res = FLAGS.string_foo_required
+ self.assertEqual(res, "default_val")
+ FLAGS.string_foo_required = "bar"
+ self.assertEqual("bar", FLAGS.string_foo_required)
+
+ def testNoneStringRequired(self):
+ res = FLAGS.none_string_foo_required
+ self.assertEqual(res, "default_val")
+ FLAGS.none_string_foo_required = "bar"
+ self.assertEqual("bar", FLAGS.none_string_foo_required)
+
def main(_):
# unittest.main() tries to interpret the unknown flags, so use the
@@ -97,7 +111,9 @@ if __name__ == "__main__":
# Test command lines
sys.argv.extend([
"--bool_a", "--nobool_negation", "--bool_c=True", "--bool_d=False",
+ "--none_string_foo_required=default_val",
"and_argument"
])
-
+ flags.mark_flag_as_required('string_foo_required')
+ flags.mark_flags_as_required(['none_string_foo_required'])
app.run()