diff options
author | Yilei Yang <yileiyang@google.com> | 2017-11-28 12:26:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-28 12:30:17 -0800 |
commit | b911049edfbb4a4eb07b3b46ed144da6cd33f9c1 (patch) | |
tree | 1675dc765c5879ee162533d5601b5f91d9b026dd /tensorflow/python/platform | |
parent | ba87a8030aa30f24c354cf705e79734658bb0a8b (diff) |
Continue to allow old argument names specified in tf.flags.DEFINE functions.
There are more DEFINE functions in absl.flags, they only accept the absl names.
PiperOrigin-RevId: 177199982
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r-- | tensorflow/python/platform/flags.py | 48 | ||||
-rw-r--r-- | tensorflow/python/platform/flags_test.py | 41 |
2 files changed, 88 insertions, 1 deletions
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py index e9a36ae75d..abd6f3d855 100644 --- a/tensorflow/python/platform/flags.py +++ b/tensorflow/python/platform/flags.py @@ -18,5 +18,53 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging as _logging + # go/tf-wildcard-import from absl.flags import * # pylint: disable=wildcard-import +import six as _six + +from tensorflow.python.util import tf_decorator + + +# Since we wrap absl.flags DEFINE functions, we need to declare this module +# does not affect key flags. +disclaim_key_flags() # pylint: disable=undefined-variable + + +_RENAMED_ARGUMENTS = { + 'flag_name': 'name', + 'default_value': 'default', + 'docstring': 'help', +} + + +def _wrap_define_function(original_function): + """Wraps absl.flags's define functions so tf.flags accepts old names.""" + + def wrapper(*args, **kwargs): + """Wrapper function that turns old keyword names to new ones.""" + has_old_names = False + for old_name, new_name in _six.iteritems(_RENAMED_ARGUMENTS): + if old_name in kwargs: + has_old_names = True + value = kwargs.pop(old_name) + kwargs[new_name] = value + if has_old_names: + _logging.warning( + 'Use of the keyword argument names (flag_name, default_value, ' + 'docstring) is deprecated, please use (name, default, help) instead.') + return original_function(*args, **kwargs) + + return tf_decorator.make_decorator(original_function, wrapper) + + +# pylint: disable=invalid-name,used-before-assignment +# absl.flags APIs use `default` as the name of the default value argument. +# Allow the following functions continue to accept `default_value`. +DEFINE_string = _wrap_define_function(DEFINE_string) +DEFINE_boolean = _wrap_define_function(DEFINE_boolean) +DEFINE_bool = DEFINE_boolean +DEFINE_float = _wrap_define_function(DEFINE_float) +DEFINE_integer = _wrap_define_function(DEFINE_integer) +# pylint: enable=invalid-name,used-before-assignment diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py index 23060e17d2..e8200142dd 100644 --- a/tensorflow/python/platform/flags_test.py +++ b/tensorflow/python/platform/flags_test.py @@ -24,11 +24,50 @@ from absl import flags as absl_flags from tensorflow.python.platform import flags +flags.DEFINE_string( + flag_name='old_string', default_value='default', docstring='docstring') +flags.DEFINE_string( + name='new_string', default='default', help='docstring') +flags.DEFINE_integer( + flag_name='old_integer', default_value=1, docstring='docstring') +flags.DEFINE_integer( + name='new_integer', default=1, help='docstring') +flags.DEFINE_float( + flag_name='old_float', default_value=1.5, docstring='docstring') +flags.DEFINE_float( + name='new_float', default=1.5, help='docstring') +flags.DEFINE_bool( + flag_name='old_bool', default_value=True, docstring='docstring') +flags.DEFINE_bool( + name='new_bool', default=True, help='docstring') +flags.DEFINE_boolean( + flag_name='old_boolean', default_value=False, docstring='docstring') +flags.DEFINE_boolean( + name='new_boolean', default=False, help='docstring') + + class FlagsTest(unittest.TestCase): def test_global_flags_object(self): self.assertIs(flags.FLAGS, absl_flags.FLAGS) + def test_keyword_arguments(self): + test_cases = ( + ('old_string', 'default'), + ('new_string', 'default'), + ('old_integer', 1), + ('new_integer', 1), + ('old_float', 1.5), + ('new_float', 1.5), + ('old_bool', True), + ('new_bool', True), + ('old_boolean', False), + ('new_boolean', False), + ) + for flag_name, default_value in test_cases: + self.assertEqual(default_value, absl_flags.FLAGS[flag_name].default) + self.assertEqual('docstring', absl_flags.FLAGS[flag_name].help) + -if __name__ == "__main__": +if __name__ == '__main__': unittest.main() |