diff options
author | Yilei Yang <yileiyang@google.com> | 2017-11-06 12:49:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-06 12:52:50 -0800 |
commit | 2652704b576adc16b4d735f651cea1024e88b72e (patch) | |
tree | 82054cc15e3b7466b5efad5fbb8f8c935f47a9db /tensorflow/python/platform | |
parent | e120e50bf5fd883cc28b584c720d959341db502d (diff) |
Replace the implementation of tf.flags with absl.flags.
Previous tf.flags implementation is based on argparse. It contains -h/--help flags, which displays all flags.
absl.app's --help flag only displays flags defined in the main module. There is a --helpfull flag that displays all flags.
So added --helpshort --helpfull flags.
app.run now raises SystemError on unknown flags (fixes #11195).
Accessing flags before flags are parsed will now raise an UnparsedFlagAccessError, instead of causing implicit flag parsing previously.
PiperOrigin-RevId: 174747028
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r-- | tensorflow/python/platform/app.py | 103 | ||||
-rw-r--r-- | tensorflow/python/platform/flags.py | 195 | ||||
-rw-r--r-- | tensorflow/python/platform/flags_test.py | 97 |
3 files changed, 103 insertions, 292 deletions
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index 5ecaa1baaf..c01e1c9b1a 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import errno as _errno import sys as _sys from tensorflow.python.platform import flags @@ -28,24 +29,108 @@ def _benchmark_tests_can_log_memory(): return True +def _usage(shorthelp): + """Writes __main__'s docstring to stdout with some help text. + + Args: + shorthelp: bool, if True, prints only flags from the main module, + rather than all flags. + """ + doc = _sys.modules['__main__'].__doc__ + if not doc: + doc = '\nUSAGE: %s [flags]\n' % _sys.argv[0] + doc = flags.text_wrap(doc, indent=' ', firstline_indent='') + else: + # Replace all '%s' with sys.argv[0], and all '%%' with '%'. + num_specifiers = doc.count('%') - 2 * doc.count('%%') + try: + doc %= (_sys.argv[0],) * num_specifiers + except (OverflowError, TypeError, ValueError): + # Just display the docstring as-is. + pass + if shorthelp: + flag_str = flags.FLAGS.main_module_help() + else: + flag_str = str(flags.FLAGS) + try: + _sys.stdout.write(doc) + if flag_str: + _sys.stdout.write('\nflags:\n') + _sys.stdout.write(flag_str) + _sys.stdout.write('\n') + except IOError as e: + # We avoid printing a huge backtrace if we get EPIPE, because + # "foo.par --help | less" is a frequent use case. + if e.errno != _errno.EPIPE: + raise + + +class _HelpFlag(flags.BooleanFlag): + """Special boolean flag that displays usage and raises SystemExit.""" + NAME = 'help' + SHORT_NAME = 'h' + + def __init__(self): + super(_HelpFlag, self).__init__( + self.NAME, False, 'show this help', short_name=self.SHORT_NAME) + + def parse(self, arg): + if arg: + _usage(shorthelp=True) + print() + print('Try --helpfull to get a list of all flags.') + _sys.exit(1) + + +class _HelpshortFlag(_HelpFlag): + """--helpshort is an alias for --help.""" + NAME = 'helpshort' + SHORT_NAME = None + + +class _HelpfullFlag(flags.BooleanFlag): + """Display help for flags in main module and all dependent modules.""" + + def __init__(self): + super(_HelpfullFlag, self).__init__('helpfull', False, 'show full help') + + def parse(self, arg): + if arg: + _usage(shorthelp=False) + _sys.exit(1) + + +_define_help_flags_called = False + + +def _define_help_flags(): + global _define_help_flags_called + if not _define_help_flags_called: + flags.DEFINE_flag(_HelpFlag()) + flags.DEFINE_flag(_HelpfullFlag()) + flags.DEFINE_flag(_HelpshortFlag()) + _define_help_flags_called = True + + def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" - f = flags.FLAGS - # Extract the args from the optional `argv` list. - args = argv[1:] if argv else None + # Define help flags. + _define_help_flags() - # Parse the known flags from that list, or from the command - # line otherwise. - # pylint: disable=protected-access - flags_passthrough = f._parse_flags(args=args) - # pylint: enable=protected-access + # Parse flags. + try: + argv = flags.FLAGS(_sys.argv if argv is None else argv) + except flags.Error as error: + _sys.stderr.write('FATAL Flags parsing error: %s\n' % error) + _sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n') + _sys.exit(1) main = main or _sys.modules['__main__'].main # Call the main function, passing through any arguments # to the final program. - _sys.exit(main(_sys.argv[:1] + flags_passthrough)) + _sys.exit(main(argv)) _allowed_symbols = [ diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py index 138a0ced97..e9a36ae75d 100644 --- a/tensorflow/python/platform/flags.py +++ b/tensorflow/python/platform/flags.py @@ -13,199 +13,10 @@ # limitations under the License. # ============================================================================== -"""Implementation of the flags interface.""" +"""Import router for absl.flags. See https://github.com/abseil/abseil-py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse as _argparse - -from tensorflow.python.platform import tf_logging as _logging -from tensorflow.python.util.all_util import remove_undocumented - -_global_parser = _argparse.ArgumentParser() - - -# pylint: disable=invalid-name - - -class _FlagValues(object): - """Global container and accessor for flags and their values.""" - - def __init__(self): - self.__dict__['__flags'] = {} - self.__dict__['__parsed'] = False - self.__dict__['__required_flags'] = set() - - def _parse_flags(self, args=None): - result, unparsed = _global_parser.parse_known_args(args=args) - for flag_name, val in vars(result).items(): - self.__dict__['__flags'][flag_name] = val - self.__dict__['__parsed'] = True - self._assert_all_required() - return unparsed - - def __getattr__(self, name): - """Retrieves the 'value' attribute of the flag --name.""" - try: - parsed = self.__dict__['__parsed'] - except KeyError: - # May happen during pickle.load or copy.copy - raise AttributeError(name) - if not parsed: - self._parse_flags() - if name not in self.__dict__['__flags']: - raise AttributeError(name) - return self.__dict__['__flags'][name] - - def __setattr__(self, name, value): - """Sets the 'value' attribute of the flag --name.""" - if not self.__dict__['__parsed']: - self._parse_flags() - self.__dict__['__flags'][name] = value - self._assert_required(name) - - def _add_required_flag(self, item): - self.__dict__['__required_flags'].add(item) - - def _assert_required(self, flag_name): - if (flag_name not in self.__dict__['__flags'] or - self.__dict__['__flags'][flag_name] is None): - raise AttributeError('Flag --%s must be specified.' % flag_name) - - def _assert_all_required(self): - for flag_name in self.__dict__['__required_flags']: - self._assert_required(flag_name) - - -def _define_helper(flag_name, default_value, docstring, flagtype): - """Registers 'flag_name' with 'default_value' and 'docstring'.""" - _global_parser.add_argument('--' + flag_name, - default=default_value, - help=docstring, - type=flagtype) - - -# Provides the global object that can be used to access flags. -FLAGS = _FlagValues() - - -def DEFINE_string(flag_name, default_value, docstring): - """Defines a flag of type 'string'. - - Args: - flag_name: The name of the flag as a string. - default_value: The default value the flag should take as a string. - docstring: A helpful message explaining the use of the flag. - """ - _define_helper(flag_name, default_value, docstring, str) - - -def DEFINE_integer(flag_name, default_value, docstring): - """Defines a flag of type 'int'. - - Args: - flag_name: The name of the flag as a string. - default_value: The default value the flag should take as an int. - docstring: A helpful message explaining the use of the flag. - """ - _define_helper(flag_name, default_value, docstring, int) - - -def DEFINE_boolean(flag_name, default_value, docstring): - """Defines a flag of type 'boolean'. - - Args: - flag_name: The name of the flag as a string. - default_value: The default value the flag should take as a boolean. - docstring: A helpful message explaining the use of the flag. - """ - # Register a custom function for 'bool' so --flag=True works. - def str2bool(v): - return v.lower() in ('true', 't', '1') - _global_parser.add_argument('--' + flag_name, - nargs='?', - const=True, - help=docstring, - default=default_value, - type=str2bool) - - # Add negated version, stay consistent with argparse with regard to - # dashes in flag names. - _global_parser.add_argument('--no' + flag_name, - action='store_false', - dest=flag_name.replace('-', '_')) - - -# The internal google library defines the following alias, so we match -# the API for consistency. -DEFINE_bool = DEFINE_boolean # pylint: disable=invalid-name - - -def DEFINE_float(flag_name, default_value, docstring): - """Defines a flag of type 'float'. - - Args: - flag_name: The name of the flag as a string. - default_value: The default value the flag should take as a float. - docstring: A helpful message explaining the use of the flag. - """ - _define_helper(flag_name, default_value, docstring, float) - - -def mark_flag_as_required(flag_name): - """Ensures that flag is not None during program execution. - - It is recommended to call this method like this: - - if __name__ == '__main__': - tf.flags.mark_flag_as_required('your_flag_name') - tf.app.run() - - Args: - flag_name: string, name of the flag to mark as required. - - Raises: - AttributeError: if flag_name is not registered as a valid flag name. - NOTE: The exception raised will change in the future. - """ - if _global_parser.get_default(flag_name) is not None: - _logging.warn( - 'Flag %s has a non-None default value; therefore, ' - 'mark_flag_as_required will pass even if flag is not specified in the ' - 'command line!' % flag_name) - FLAGS._add_required_flag(flag_name) - - -def mark_flags_as_required(flag_names): - """Ensures that flags are not None during program execution. - - Recommended usage: - - if __name__ == '__main__': - tf.flags.mark_flags_as_required(['flag1', 'flag2', 'flag3']) - tf.app.run() - - Args: - flag_names: a list/tuple of flag names to mark as required. - - Raises: - AttributeError: If any of flag name has not already been defined as a flag. - NOTE: The exception raised will change in the future. - """ - for flag_name in flag_names: - mark_flag_as_required(flag_name) - - -_allowed_symbols = [ - # We rely on gflags documentation. - 'DEFINE_bool', - 'DEFINE_boolean', - 'DEFINE_float', - 'DEFINE_integer', - 'DEFINE_string', - 'FLAGS', - 'mark_flag_as_required', - 'mark_flags_as_required', -] -remove_undocumented(__name__, _allowed_symbols) +# go/tf-wildcard-import +from absl.flags import * # pylint: disable=wildcard-import diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py index 7b08c3f8a6..23060e17d2 100644 --- a/tensorflow/python/platform/flags_test.py +++ b/tensorflow/python/platform/flags_test.py @@ -12,108 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for our flags implementation.""" +"""Sanity tests for tf.flags.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy -import sys import unittest -from tensorflow.python.platform import app -from tensorflow.python.platform import flags - -flags.DEFINE_string("string_foo", "default_val", "HelpString") -flags.DEFINE_integer("int_foo", 42, "HelpString") -flags.DEFINE_float("float_foo", 42.0, "HelpString") +from absl import flags as absl_flags -flags.DEFINE_boolean("bool_foo", True, "HelpString") -flags.DEFINE_boolean("bool_negation", True, "HelpString") -flags.DEFINE_boolean("bool-dash-negation", True, "HelpString") -flags.DEFINE_boolean("bool_a", False, "HelpString") -flags.DEFINE_boolean("bool_c", False, "HelpString") -flags.DEFINE_boolean("bool_d", True, "HelpString") -flags.DEFINE_bool("bool_e", True, "HelpString") -flags.DEFINE_string("string_foo_required", "default_val", "HelpString") -flags.DEFINE_string("none_string_foo_required", None, "HelpString") - -FLAGS = flags.FLAGS +from tensorflow.python.platform import flags class FlagsTest(unittest.TestCase): - def testString(self): - res = FLAGS.string_foo - self.assertEqual(res, "default_val") - FLAGS.string_foo = "bar" - self.assertEqual("bar", FLAGS.string_foo) - - def testBool(self): - res = FLAGS.bool_foo - self.assertTrue(res) - FLAGS.bool_foo = False - self.assertFalse(FLAGS.bool_foo) - - def testBoolCommandLines(self): - # Specified on command line with no args, sets to True, - # even if default is False. - self.assertEqual(True, FLAGS.bool_a) - - # --no before the flag forces it to False, even if the - # default is True - self.assertEqual(False, FLAGS.bool_negation) - - # --bool_flag=True sets to True - self.assertEqual(True, FLAGS.bool_c) - - # --bool_flag=False sets to False - self.assertEqual(False, FLAGS.bool_d) - - def testInt(self): - res = FLAGS.int_foo - self.assertEquals(res, 42) - FLAGS.int_foo = -1 - self.assertEqual(-1, FLAGS.int_foo) - - def testFloat(self): - res = FLAGS.float_foo - self.assertEquals(42.0, res) - FLAGS.float_foo = -1.0 - self.assertEqual(-1.0, FLAGS.float_foo) - - def test_copy(self): - copied = copy.copy(FLAGS) - self.assertEqual(copied.__dict__, FLAGS.__dict__) - - def testStringRequired(self): - res = FLAGS.string_foo_required - self.assertEqual(res, "default_val") - FLAGS.string_foo_required = "bar" - self.assertEqual("bar", FLAGS.string_foo_required) - - def testNoneStringRequired(self): - res = FLAGS.none_string_foo_required - self.assertEqual(res, "default_val") - FLAGS.none_string_foo_required = "bar" - self.assertEqual("bar", FLAGS.none_string_foo_required) - - -def main(_): - # unittest.main() tries to interpret the unknown flags, so use the - # direct functions instead. - runner = unittest.TextTestRunner() - itersuite = unittest.TestLoader().loadTestsFromTestCase(FlagsTest) - runner.run(itersuite) + def test_global_flags_object(self): + self.assertIs(flags.FLAGS, absl_flags.FLAGS) if __name__ == "__main__": - # Test command lines - sys.argv.extend([ - "--bool_a", "--nobool_negation", "--bool_c=True", "--bool_d=False", - "--none_string_foo_required=default_val", - "and_argument" - ]) - flags.mark_flag_as_required('string_foo_required') - flags.mark_flags_as_required(['none_string_foo_required']) - app.run() + unittest.main() |