aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2016-12-07 16:29:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-07 16:43:36 -0800
commit6809ac80b2367a993cccc2596fd99de09b2f8fd8 (patch)
treeded0836c2e8ad912e0f06e9a4c4aea1911394037
parent539fafd32ebf2a804e720811167cd6e902190e7d (diff)
Add a decorator to enforce function be called with keyworded args (only).
Change: 141373569
-rw-r--r--tensorflow/python/BUILD9
-rw-r--r--tensorflow/python/util/keyword_args.py52
-rw-r--r--tensorflow/python/util/keyword_args_test.py51
3 files changed, 112 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0b88404560..544e27a0b1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -244,6 +244,15 @@ py_test(
],
)
+py_test(
+ name = "keyword_args_test",
+ srcs = ["util/keyword_args_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
cc_library(
name = "python_op_gen",
srcs = ["framework/python_op_gen.cc"],
diff --git a/tensorflow/python/util/keyword_args.py b/tensorflow/python/util/keyword_args.py
new file mode 100644
index 0000000000..56bd0a63e6
--- /dev/null
+++ b/tensorflow/python/util/keyword_args.py
@@ -0,0 +1,52 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Keyword args functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.util import decorator_utils
+
+
+def keyword_args_only(func):
+ """Decorator for marking specific function accepting keyword args only.
+
+ This decorator raises a `ValueError` if the input `func` is called with any
+ non-keyword args. This prevents the caller from providing the arguments in
+ wrong order.
+
+ Args:
+ func: The function or method needed to be decorated.
+
+ Returns:
+ Decorated function or method.
+
+ Raises:
+ ValueError: If `func` is not callable.
+ """
+
+ decorator_utils.validate_callable(func, "keyword_args_only")
+ @functools.wraps(func)
+ def new_func(*args, **kwargs):
+ """Keyword args only wrapper."""
+ if args:
+ raise ValueError(
+ "Must use keyword args to call {}.".format(func.__name__))
+ return func(**kwargs)
+ return new_func
diff --git a/tensorflow/python/util/keyword_args_test.py b/tensorflow/python/util/keyword_args_test.py
new file mode 100644
index 0000000000..08a24b15fa
--- /dev/null
+++ b/tensorflow/python/util/keyword_args_test.py
@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keyword args tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.python.util import keyword_args
+
+
+class KeywordArgsTest(tf.test.TestCase):
+
+ def test_keyword_args_only(self):
+ def func_without_decorator(a, b):
+ return a+b
+
+ @keyword_args.keyword_args_only
+ def func_with_decorator(a, b):
+ return func_without_decorator(a, b)
+
+ self.assertEqual(3, func_without_decorator(1, 2))
+ self.assertEqual(3, func_without_decorator(a=1, b=2))
+ self.assertEqual(3, func_with_decorator(a=1, b=2))
+
+ # Providing non-keyword args should fail.
+ with self.assertRaisesRegexp(
+ ValueError, "Must use keyword args to call func_with_decorator."):
+ self.assertEqual(3, func_with_decorator(1, 2))
+
+ # Partially providing keyword args should fail.
+ with self.assertRaisesRegexp(
+ ValueError, "Must use keyword args to call func_with_decorator."):
+ self.assertEqual(3, func_with_decorator(1, b=2))
+
+
+if __name__ == "__main__":
+ tf.test.main()