diff options
Diffstat (limited to 'tensorflow/python/platform/flags.py')
-rw-r--r-- | tensorflow/python/platform/flags.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py index 60ec4f84c4..138a0ced97 100644 --- a/tensorflow/python/platform/flags.py +++ b/tensorflow/python/platform/flags.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse as _argparse +from tensorflow.python.platform import tf_logging as _logging from tensorflow.python.util.all_util import remove_undocumented _global_parser = _argparse.ArgumentParser() @@ -34,12 +35,14 @@ class _FlagValues(object): def __init__(self): self.__dict__['__flags'] = {} self.__dict__['__parsed'] = False + self.__dict__['__required_flags'] = set() def _parse_flags(self, args=None): result, unparsed = _global_parser.parse_known_args(args=args) for flag_name, val in vars(result).items(): self.__dict__['__flags'][flag_name] = val self.__dict__['__parsed'] = True + self._assert_all_required() return unparsed def __getattr__(self, name): @@ -60,6 +63,19 @@ class _FlagValues(object): if not self.__dict__['__parsed']: self._parse_flags() self.__dict__['__flags'][name] = value + self._assert_required(name) + + def _add_required_flag(self, item): + self.__dict__['__required_flags'].add(item) + + def _assert_required(self, flag_name): + if (flag_name not in self.__dict__['__flags'] or + self.__dict__['__flags'][flag_name] is None): + raise AttributeError('Flag --%s must be specified.' % flag_name) + + def _assert_all_required(self): + for flag_name in self.__dict__['__required_flags']: + self._assert_required(flag_name) def _define_helper(flag_name, default_value, docstring, flagtype): @@ -136,6 +152,51 @@ def DEFINE_float(flag_name, default_value, docstring): """ _define_helper(flag_name, default_value, docstring, float) + +def mark_flag_as_required(flag_name): + """Ensures that flag is not None during program execution. + + It is recommended to call this method like this: + + if __name__ == '__main__': + tf.flags.mark_flag_as_required('your_flag_name') + tf.app.run() + + Args: + flag_name: string, name of the flag to mark as required. + + Raises: + AttributeError: if flag_name is not registered as a valid flag name. + NOTE: The exception raised will change in the future. + """ + if _global_parser.get_default(flag_name) is not None: + _logging.warn( + 'Flag %s has a non-None default value; therefore, ' + 'mark_flag_as_required will pass even if flag is not specified in the ' + 'command line!' % flag_name) + FLAGS._add_required_flag(flag_name) + + +def mark_flags_as_required(flag_names): + """Ensures that flags are not None during program execution. + + Recommended usage: + + if __name__ == '__main__': + tf.flags.mark_flags_as_required(['flag1', 'flag2', 'flag3']) + tf.app.run() + + Args: + flag_names: a list/tuple of flag names to mark as required. + + Raises: + AttributeError: If any of flag name has not already been defined as a flag. + NOTE: The exception raised will change in the future. + """ + for flag_name in flag_names: + mark_flag_as_required(flag_name) + + _allowed_symbols = [ # We rely on gflags documentation. 'DEFINE_bool', @@ -144,5 +205,7 @@ _allowed_symbols = [ 'DEFINE_integer', 'DEFINE_string', 'FLAGS', + 'mark_flag_as_required', + 'mark_flags_as_required', ] remove_undocumented(__name__, _allowed_symbols) |