aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform
diff options
context:
space:
mode:
authorGravatar Yilei Yang <yileiyang@google.com>2017-11-06 12:49:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-06 12:52:50 -0800
commit2652704b576adc16b4d735f651cea1024e88b72e (patch)
tree82054cc15e3b7466b5efad5fbb8f8c935f47a9db /tensorflow/python/platform
parente120e50bf5fd883cc28b584c720d959341db502d (diff)
Replace the implementation of tf.flags with absl.flags.
Previous tf.flags implementation is based on argparse. It contains -h/--help flags, which displays all flags. absl.app's --help flag only displays flags defined in the main module. There is a --helpfull flag that displays all flags. So added --helpshort --helpfull flags. app.run now raises SystemError on unknown flags (fixes #11195). Accessing flags before flags are parsed will now raise an UnparsedFlagAccessError, instead of causing implicit flag parsing previously. PiperOrigin-RevId: 174747028
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r--tensorflow/python/platform/app.py103
-rw-r--r--tensorflow/python/platform/flags.py195
-rw-r--r--tensorflow/python/platform/flags_test.py97
3 files changed, 103 insertions, 292 deletions
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index 5ecaa1baaf..c01e1c9b1a 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import errno as _errno
import sys as _sys
from tensorflow.python.platform import flags
@@ -28,24 +29,108 @@ def _benchmark_tests_can_log_memory():
return True
+def _usage(shorthelp):
+ """Writes __main__'s docstring to stdout with some help text.
+
+ Args:
+ shorthelp: bool, if True, prints only flags from the main module,
+ rather than all flags.
+ """
+ doc = _sys.modules['__main__'].__doc__
+ if not doc:
+ doc = '\nUSAGE: %s [flags]\n' % _sys.argv[0]
+ doc = flags.text_wrap(doc, indent=' ', firstline_indent='')
+ else:
+ # Replace all '%s' with sys.argv[0], and all '%%' with '%'.
+ num_specifiers = doc.count('%') - 2 * doc.count('%%')
+ try:
+ doc %= (_sys.argv[0],) * num_specifiers
+ except (OverflowError, TypeError, ValueError):
+ # Just display the docstring as-is.
+ pass
+ if shorthelp:
+ flag_str = flags.FLAGS.main_module_help()
+ else:
+ flag_str = str(flags.FLAGS)
+ try:
+ _sys.stdout.write(doc)
+ if flag_str:
+ _sys.stdout.write('\nflags:\n')
+ _sys.stdout.write(flag_str)
+ _sys.stdout.write('\n')
+ except IOError as e:
+ # We avoid printing a huge backtrace if we get EPIPE, because
+ # "foo.par --help | less" is a frequent use case.
+ if e.errno != _errno.EPIPE:
+ raise
+
+
+class _HelpFlag(flags.BooleanFlag):
+ """Special boolean flag that displays usage and raises SystemExit."""
+ NAME = 'help'
+ SHORT_NAME = 'h'
+
+ def __init__(self):
+ super(_HelpFlag, self).__init__(
+ self.NAME, False, 'show this help', short_name=self.SHORT_NAME)
+
+ def parse(self, arg):
+ if arg:
+ _usage(shorthelp=True)
+ print()
+ print('Try --helpfull to get a list of all flags.')
+ _sys.exit(1)
+
+
+class _HelpshortFlag(_HelpFlag):
+ """--helpshort is an alias for --help."""
+ NAME = 'helpshort'
+ SHORT_NAME = None
+
+
+class _HelpfullFlag(flags.BooleanFlag):
+ """Display help for flags in main module and all dependent modules."""
+
+ def __init__(self):
+ super(_HelpfullFlag, self).__init__('helpfull', False, 'show full help')
+
+ def parse(self, arg):
+ if arg:
+ _usage(shorthelp=False)
+ _sys.exit(1)
+
+
+_define_help_flags_called = False
+
+
+def _define_help_flags():
+ global _define_help_flags_called
+ if not _define_help_flags_called:
+ flags.DEFINE_flag(_HelpFlag())
+ flags.DEFINE_flag(_HelpfullFlag())
+ flags.DEFINE_flag(_HelpshortFlag())
+ _define_help_flags_called = True
+
+
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
- f = flags.FLAGS
- # Extract the args from the optional `argv` list.
- args = argv[1:] if argv else None
+ # Define help flags.
+ _define_help_flags()
- # Parse the known flags from that list, or from the command
- # line otherwise.
- # pylint: disable=protected-access
- flags_passthrough = f._parse_flags(args=args)
- # pylint: enable=protected-access
+ # Parse flags.
+ try:
+ argv = flags.FLAGS(_sys.argv if argv is None else argv)
+ except flags.Error as error:
+ _sys.stderr.write('FATAL Flags parsing error: %s\n' % error)
+ _sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n')
+ _sys.exit(1)
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
- _sys.exit(main(_sys.argv[:1] + flags_passthrough))
+ _sys.exit(main(argv))
_allowed_symbols = [
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index 138a0ced97..e9a36ae75d 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -13,199 +13,10 @@
# limitations under the License.
# ==============================================================================
-"""Implementation of the flags interface."""
+"""Import router for absl.flags. See https://github.com/abseil/abseil-py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import argparse as _argparse
-
-from tensorflow.python.platform import tf_logging as _logging
-from tensorflow.python.util.all_util import remove_undocumented
-
-_global_parser = _argparse.ArgumentParser()
-
-
-# pylint: disable=invalid-name
-
-
-class _FlagValues(object):
- """Global container and accessor for flags and their values."""
-
- def __init__(self):
- self.__dict__['__flags'] = {}
- self.__dict__['__parsed'] = False
- self.__dict__['__required_flags'] = set()
-
- def _parse_flags(self, args=None):
- result, unparsed = _global_parser.parse_known_args(args=args)
- for flag_name, val in vars(result).items():
- self.__dict__['__flags'][flag_name] = val
- self.__dict__['__parsed'] = True
- self._assert_all_required()
- return unparsed
-
- def __getattr__(self, name):
- """Retrieves the 'value' attribute of the flag --name."""
- try:
- parsed = self.__dict__['__parsed']
- except KeyError:
- # May happen during pickle.load or copy.copy
- raise AttributeError(name)
- if not parsed:
- self._parse_flags()
- if name not in self.__dict__['__flags']:
- raise AttributeError(name)
- return self.__dict__['__flags'][name]
-
- def __setattr__(self, name, value):
- """Sets the 'value' attribute of the flag --name."""
- if not self.__dict__['__parsed']:
- self._parse_flags()
- self.__dict__['__flags'][name] = value
- self._assert_required(name)
-
- def _add_required_flag(self, item):
- self.__dict__['__required_flags'].add(item)
-
- def _assert_required(self, flag_name):
- if (flag_name not in self.__dict__['__flags'] or
- self.__dict__['__flags'][flag_name] is None):
- raise AttributeError('Flag --%s must be specified.' % flag_name)
-
- def _assert_all_required(self):
- for flag_name in self.__dict__['__required_flags']:
- self._assert_required(flag_name)
-
-
-def _define_helper(flag_name, default_value, docstring, flagtype):
- """Registers 'flag_name' with 'default_value' and 'docstring'."""
- _global_parser.add_argument('--' + flag_name,
- default=default_value,
- help=docstring,
- type=flagtype)
-
-
-# Provides the global object that can be used to access flags.
-FLAGS = _FlagValues()
-
-
-def DEFINE_string(flag_name, default_value, docstring):
- """Defines a flag of type 'string'.
-
- Args:
- flag_name: The name of the flag as a string.
- default_value: The default value the flag should take as a string.
- docstring: A helpful message explaining the use of the flag.
- """
- _define_helper(flag_name, default_value, docstring, str)
-
-
-def DEFINE_integer(flag_name, default_value, docstring):
- """Defines a flag of type 'int'.
-
- Args:
- flag_name: The name of the flag as a string.
- default_value: The default value the flag should take as an int.
- docstring: A helpful message explaining the use of the flag.
- """
- _define_helper(flag_name, default_value, docstring, int)
-
-
-def DEFINE_boolean(flag_name, default_value, docstring):
- """Defines a flag of type 'boolean'.
-
- Args:
- flag_name: The name of the flag as a string.
- default_value: The default value the flag should take as a boolean.
- docstring: A helpful message explaining the use of the flag.
- """
- # Register a custom function for 'bool' so --flag=True works.
- def str2bool(v):
- return v.lower() in ('true', 't', '1')
- _global_parser.add_argument('--' + flag_name,
- nargs='?',
- const=True,
- help=docstring,
- default=default_value,
- type=str2bool)
-
- # Add negated version, stay consistent with argparse with regard to
- # dashes in flag names.
- _global_parser.add_argument('--no' + flag_name,
- action='store_false',
- dest=flag_name.replace('-', '_'))
-
-
-# The internal google library defines the following alias, so we match
-# the API for consistency.
-DEFINE_bool = DEFINE_boolean # pylint: disable=invalid-name
-
-
-def DEFINE_float(flag_name, default_value, docstring):
- """Defines a flag of type 'float'.
-
- Args:
- flag_name: The name of the flag as a string.
- default_value: The default value the flag should take as a float.
- docstring: A helpful message explaining the use of the flag.
- """
- _define_helper(flag_name, default_value, docstring, float)
-
-
-def mark_flag_as_required(flag_name):
- """Ensures that flag is not None during program execution.
-
- It is recommended to call this method like this:
-
- if __name__ == '__main__':
- tf.flags.mark_flag_as_required('your_flag_name')
- tf.app.run()
-
- Args:
- flag_name: string, name of the flag to mark as required.
-
- Raises:
- AttributeError: if flag_name is not registered as a valid flag name.
- NOTE: The exception raised will change in the future.
- """
- if _global_parser.get_default(flag_name) is not None:
- _logging.warn(
- 'Flag %s has a non-None default value; therefore, '
- 'mark_flag_as_required will pass even if flag is not specified in the '
- 'command line!' % flag_name)
- FLAGS._add_required_flag(flag_name)
-
-
-def mark_flags_as_required(flag_names):
- """Ensures that flags are not None during program execution.
-
- Recommended usage:
-
- if __name__ == '__main__':
- tf.flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
- tf.app.run()
-
- Args:
- flag_names: a list/tuple of flag names to mark as required.
-
- Raises:
- AttributeError: If any of flag name has not already been defined as a flag.
- NOTE: The exception raised will change in the future.
- """
- for flag_name in flag_names:
- mark_flag_as_required(flag_name)
-
-
-_allowed_symbols = [
- # We rely on gflags documentation.
- 'DEFINE_bool',
- 'DEFINE_boolean',
- 'DEFINE_float',
- 'DEFINE_integer',
- 'DEFINE_string',
- 'FLAGS',
- 'mark_flag_as_required',
- 'mark_flags_as_required',
-]
-remove_undocumented(__name__, _allowed_symbols)
+# go/tf-wildcard-import
+from absl.flags import * # pylint: disable=wildcard-import
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py
index 7b08c3f8a6..23060e17d2 100644
--- a/tensorflow/python/platform/flags_test.py
+++ b/tensorflow/python/platform/flags_test.py
@@ -12,108 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for our flags implementation."""
+"""Sanity tests for tf.flags."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-import sys
import unittest
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-
-flags.DEFINE_string("string_foo", "default_val", "HelpString")
-flags.DEFINE_integer("int_foo", 42, "HelpString")
-flags.DEFINE_float("float_foo", 42.0, "HelpString")
+from absl import flags as absl_flags
-flags.DEFINE_boolean("bool_foo", True, "HelpString")
-flags.DEFINE_boolean("bool_negation", True, "HelpString")
-flags.DEFINE_boolean("bool-dash-negation", True, "HelpString")
-flags.DEFINE_boolean("bool_a", False, "HelpString")
-flags.DEFINE_boolean("bool_c", False, "HelpString")
-flags.DEFINE_boolean("bool_d", True, "HelpString")
-flags.DEFINE_bool("bool_e", True, "HelpString")
-flags.DEFINE_string("string_foo_required", "default_val", "HelpString")
-flags.DEFINE_string("none_string_foo_required", None, "HelpString")
-
-FLAGS = flags.FLAGS
+from tensorflow.python.platform import flags
class FlagsTest(unittest.TestCase):
- def testString(self):
- res = FLAGS.string_foo
- self.assertEqual(res, "default_val")
- FLAGS.string_foo = "bar"
- self.assertEqual("bar", FLAGS.string_foo)
-
- def testBool(self):
- res = FLAGS.bool_foo
- self.assertTrue(res)
- FLAGS.bool_foo = False
- self.assertFalse(FLAGS.bool_foo)
-
- def testBoolCommandLines(self):
- # Specified on command line with no args, sets to True,
- # even if default is False.
- self.assertEqual(True, FLAGS.bool_a)
-
- # --no before the flag forces it to False, even if the
- # default is True
- self.assertEqual(False, FLAGS.bool_negation)
-
- # --bool_flag=True sets to True
- self.assertEqual(True, FLAGS.bool_c)
-
- # --bool_flag=False sets to False
- self.assertEqual(False, FLAGS.bool_d)
-
- def testInt(self):
- res = FLAGS.int_foo
- self.assertEquals(res, 42)
- FLAGS.int_foo = -1
- self.assertEqual(-1, FLAGS.int_foo)
-
- def testFloat(self):
- res = FLAGS.float_foo
- self.assertEquals(42.0, res)
- FLAGS.float_foo = -1.0
- self.assertEqual(-1.0, FLAGS.float_foo)
-
- def test_copy(self):
- copied = copy.copy(FLAGS)
- self.assertEqual(copied.__dict__, FLAGS.__dict__)
-
- def testStringRequired(self):
- res = FLAGS.string_foo_required
- self.assertEqual(res, "default_val")
- FLAGS.string_foo_required = "bar"
- self.assertEqual("bar", FLAGS.string_foo_required)
-
- def testNoneStringRequired(self):
- res = FLAGS.none_string_foo_required
- self.assertEqual(res, "default_val")
- FLAGS.none_string_foo_required = "bar"
- self.assertEqual("bar", FLAGS.none_string_foo_required)
-
-
-def main(_):
- # unittest.main() tries to interpret the unknown flags, so use the
- # direct functions instead.
- runner = unittest.TextTestRunner()
- itersuite = unittest.TestLoader().loadTestsFromTestCase(FlagsTest)
- runner.run(itersuite)
+ def test_global_flags_object(self):
+ self.assertIs(flags.FLAGS, absl_flags.FLAGS)
if __name__ == "__main__":
- # Test command lines
- sys.argv.extend([
- "--bool_a", "--nobool_negation", "--bool_c=True", "--bool_d=False",
- "--none_string_foo_required=default_val",
- "and_argument"
- ])
- flags.mark_flag_as_required('string_foo_required')
- flags.mark_flags_as_required(['none_string_foo_required'])
- app.run()
+ unittest.main()