aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 06:50:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 06:54:41 -0700
commit234229b014cb0cfe4bf8e9466db79d596085faba (patch)
tree96ae0a68d991da9598ac000ec988ac2d08a701a5 /tensorflow/python/util
parent77e2686a2958eb76e0164828d5d536b86c72464b (diff)
Update logic used in get_variable to populate custom_getter's kwargs.
The new implementation ensures that the 'constraints' kwarg is propagated by customer getters whose signature includes a keyworded, variable length argument dictionary, as well as those explicitly including the 'constraints' argument. PiperOrigin-RevId: 214767296
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/function_utils.py23
-rw-r--r--tensorflow/python/util/function_utils_test.py87
2 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 4e9b07e20a..a56dfbff8e 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -59,6 +59,29 @@ def fn_args(fn):
return tuple(args)
+def has_kwargs(fn):
+ """Returns whether the passed callable has **kwargs in its signature.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `bool`: if `fn` has **kwargs in its signature.
+
+ Raises:
+ `TypeError`: If fn is not a Function, or function-like object.
+ """
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ elif _is_callable_object(fn):
+ fn = fn.__call__
+ elif not callable(fn):
+ raise TypeError(
+ 'fn should be a function-like object, but is of type {}.'.format(
+ type(fn)))
+ return tf_inspect.getfullargspec(fn).varkw is not None
+
+
def get_func_name(func):
"""Returns name of passed callable."""
_, func = tf_decorator.unwrap(func)
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index 1588328c26..ce768637f5 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -135,6 +135,93 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(a=3))
+class HasKwargsTest(test.TestCase):
+
+ def test_simple_function(self):
+
+ fn_has_kwargs = lambda **x: x
+ self.assertTrue(function_utils.has_kwargs(fn_has_kwargs))
+
+ fn_has_no_kwargs = lambda x: x
+ self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs))
+
+ def test_callable(self):
+
+ class FooHasKwargs(object):
+
+ def __call__(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs()))
+
+ class FooHasNoKwargs(object):
+
+ def __call__(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs()))
+
+ def test_bounded_method(self):
+
+ class FooHasKwargs(object):
+
+ def fn(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn))
+
+ class FooHasNoKwargs(object):
+
+ def fn(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn))
+
+ def test_partial_function(self):
+ expected_test_arg = 123
+
+ def fn_has_kwargs(test_arg, **x):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123)
+ self.assertTrue(function_utils.has_kwargs(wrapped_fn))
+
+ def fn_has_no_kwargs(x, test_arg):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123)
+ self.assertFalse(function_utils.has_kwargs(wrapped_fn))
+
+ def test_double_partial(self):
+ expected_test_arg1 = 123
+ expected_test_arg2 = 456
+
+ def fn_has_kwargs(test_arg1, test_arg2, **x):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertTrue(function_utils.has_kwargs(double_wrapped_fn))
+
+ def fn_has_no_kwargs(x, test_arg1, test_arg2):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertFalse(function_utils.has_kwargs(double_wrapped_fn))
+
+ def test_raises_type_error(self):
+ with self.assertRaisesRegexp(
+ TypeError, 'fn should be a function-like object'):
+ function_utils.has_kwargs('not a function')
+
+
class GetFuncNameTest(test.TestCase):
def testWithSimpleFunction(self):