aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform
diff options
context:
space:
mode:
authorGravatar Yilei Yang <yileiyang@google.com>2017-11-28 12:26:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 12:30:17 -0800
commitb911049edfbb4a4eb07b3b46ed144da6cd33f9c1 (patch)
tree1675dc765c5879ee162533d5601b5f91d9b026dd /tensorflow/python/platform
parentba87a8030aa30f24c354cf705e79734658bb0a8b (diff)
Continue to allow old argument names specified in tf.flags.DEFINE functions.
There are more DEFINE functions in absl.flags, they only accept the absl names. PiperOrigin-RevId: 177199982
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r--tensorflow/python/platform/flags.py48
-rw-r--r--tensorflow/python/platform/flags_test.py41
2 files changed, 88 insertions, 1 deletions
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index e9a36ae75d..abd6f3d855 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -18,5 +18,53 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging as _logging
+
# go/tf-wildcard-import
from absl.flags import * # pylint: disable=wildcard-import
+import six as _six
+
+from tensorflow.python.util import tf_decorator
+
+
+# Since we wrap absl.flags DEFINE functions, we need to declare this module
+# does not affect key flags.
+disclaim_key_flags() # pylint: disable=undefined-variable
+
+
+_RENAMED_ARGUMENTS = {
+ 'flag_name': 'name',
+ 'default_value': 'default',
+ 'docstring': 'help',
+}
+
+
+def _wrap_define_function(original_function):
+ """Wraps absl.flags's define functions so tf.flags accepts old names."""
+
+ def wrapper(*args, **kwargs):
+ """Wrapper function that turns old keyword names to new ones."""
+ has_old_names = False
+ for old_name, new_name in _six.iteritems(_RENAMED_ARGUMENTS):
+ if old_name in kwargs:
+ has_old_names = True
+ value = kwargs.pop(old_name)
+ kwargs[new_name] = value
+ if has_old_names:
+ _logging.warning(
+ 'Use of the keyword argument names (flag_name, default_value, '
+ 'docstring) is deprecated, please use (name, default, help) instead.')
+ return original_function(*args, **kwargs)
+
+ return tf_decorator.make_decorator(original_function, wrapper)
+
+
+# pylint: disable=invalid-name,used-before-assignment
+# absl.flags APIs use `default` as the name of the default value argument.
+# Allow the following functions continue to accept `default_value`.
+DEFINE_string = _wrap_define_function(DEFINE_string)
+DEFINE_boolean = _wrap_define_function(DEFINE_boolean)
+DEFINE_bool = DEFINE_boolean
+DEFINE_float = _wrap_define_function(DEFINE_float)
+DEFINE_integer = _wrap_define_function(DEFINE_integer)
+# pylint: enable=invalid-name,used-before-assignment
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py
index 23060e17d2..e8200142dd 100644
--- a/tensorflow/python/platform/flags_test.py
+++ b/tensorflow/python/platform/flags_test.py
@@ -24,11 +24,50 @@ from absl import flags as absl_flags
from tensorflow.python.platform import flags
+flags.DEFINE_string(
+ flag_name='old_string', default_value='default', docstring='docstring')
+flags.DEFINE_string(
+ name='new_string', default='default', help='docstring')
+flags.DEFINE_integer(
+ flag_name='old_integer', default_value=1, docstring='docstring')
+flags.DEFINE_integer(
+ name='new_integer', default=1, help='docstring')
+flags.DEFINE_float(
+ flag_name='old_float', default_value=1.5, docstring='docstring')
+flags.DEFINE_float(
+ name='new_float', default=1.5, help='docstring')
+flags.DEFINE_bool(
+ flag_name='old_bool', default_value=True, docstring='docstring')
+flags.DEFINE_bool(
+ name='new_bool', default=True, help='docstring')
+flags.DEFINE_boolean(
+ flag_name='old_boolean', default_value=False, docstring='docstring')
+flags.DEFINE_boolean(
+ name='new_boolean', default=False, help='docstring')
+
+
class FlagsTest(unittest.TestCase):
def test_global_flags_object(self):
self.assertIs(flags.FLAGS, absl_flags.FLAGS)
+ def test_keyword_arguments(self):
+ test_cases = (
+ ('old_string', 'default'),
+ ('new_string', 'default'),
+ ('old_integer', 1),
+ ('new_integer', 1),
+ ('old_float', 1.5),
+ ('new_float', 1.5),
+ ('old_bool', True),
+ ('new_bool', True),
+ ('old_boolean', False),
+ ('new_boolean', False),
+ )
+ for flag_name, default_value in test_cases:
+ self.assertEqual(default_value, absl_flags.FLAGS[flag_name].default)
+ self.assertEqual('docstring', absl_flags.FLAGS[flag_name].help)
+
-if __name__ == "__main__":
+if __name__ == '__main__':
unittest.main()