diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 06:50:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 06:54:41 -0700 |
commit | 234229b014cb0cfe4bf8e9466db79d596085faba (patch) | |
tree | 96ae0a68d991da9598ac000ec988ac2d08a701a5 /tensorflow | |
parent | 77e2686a2958eb76e0164828d5d536b86c72464b (diff) |
Update logic used in get_variable to populate custom_getter's kwargs.
The new implementation ensures that the 'constraints' kwarg is propagated by customer getters whose signature includes a keyworded, variable length argument dictionary, as well as those explicitly including the 'constraints' argument.
PiperOrigin-RevId: 214767296
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 6 | ||||
-rw-r--r-- | tensorflow/python/util/function_utils.py | 23 | ||||
-rw-r--r-- | tensorflow/python/util/function_utils_test.py | 87 |
3 files changed, 114 insertions, 2 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index a43676cd70..562e1ad6cb 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -515,8 +515,10 @@ class _VariableStore(object): "synchronization": synchronization, "aggregation": aggregation, } - # `fn_args` can handle functions, `functools.partial`, `lambda`. - if "constraint" in function_utils.fn_args(custom_getter): + # `fn_args` and `has_kwargs` can handle functions, `functools.partial`, + # `lambda`. + if ("constraint" in function_utils.fn_args(custom_getter) or + function_utils.has_kwargs(custom_getter)): custom_getter_kwargs["constraint"] = constraint return custom_getter(**custom_getter_kwargs) else: diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py index 4e9b07e20a..a56dfbff8e 100644 --- a/tensorflow/python/util/function_utils.py +++ b/tensorflow/python/util/function_utils.py @@ -59,6 +59,29 @@ def fn_args(fn): return tuple(args) +def has_kwargs(fn): + """Returns whether the passed callable has **kwargs in its signature. + + Args: + fn: Function, or function-like object (e.g., result of `functools.partial`). + + Returns: + `bool`: if `fn` has **kwargs in its signature. + + Raises: + `TypeError`: If fn is not a Function, or function-like object. + """ + if isinstance(fn, functools.partial): + fn = fn.func + elif _is_callable_object(fn): + fn = fn.__call__ + elif not callable(fn): + raise TypeError( + 'fn should be a function-like object, but is of type {}.'.format( + type(fn))) + return tf_inspect.getfullargspec(fn).varkw is not None + + def get_func_name(func): """Returns name of passed callable.""" _, func = tf_decorator.unwrap(func) diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py index 1588328c26..ce768637f5 100644 --- a/tensorflow/python/util/function_utils_test.py +++ b/tensorflow/python/util/function_utils_test.py @@ -135,6 +135,93 @@ class FnArgsTest(test.TestCase): self.assertEqual(3, double_wrapped_fn(a=3)) +class HasKwargsTest(test.TestCase): + + def test_simple_function(self): + + fn_has_kwargs = lambda **x: x + self.assertTrue(function_utils.has_kwargs(fn_has_kwargs)) + + fn_has_no_kwargs = lambda x: x + self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs)) + + def test_callable(self): + + class FooHasKwargs(object): + + def __call__(self, **x): + del x + self.assertTrue(function_utils.has_kwargs(FooHasKwargs())) + + class FooHasNoKwargs(object): + + def __call__(self, x): + del x + self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs())) + + def test_bounded_method(self): + + class FooHasKwargs(object): + + def fn(self, **x): + del x + self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn)) + + class FooHasNoKwargs(object): + + def fn(self, x): + del x + self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn)) + + def test_partial_function(self): + expected_test_arg = 123 + + def fn_has_kwargs(test_arg, **x): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return x + + wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123) + self.assertTrue(function_utils.has_kwargs(wrapped_fn)) + + def fn_has_no_kwargs(x, test_arg): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return x + + wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123) + self.assertFalse(function_utils.has_kwargs(wrapped_fn)) + + def test_double_partial(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn_has_kwargs(test_arg1, test_arg2, **x): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial does not work correctly') + return x + + wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) + + self.assertTrue(function_utils.has_kwargs(double_wrapped_fn)) + + def fn_has_no_kwargs(x, test_arg1, test_arg2): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial does not work correctly') + return x + + wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) + + self.assertFalse(function_utils.has_kwargs(double_wrapped_fn)) + + def test_raises_type_error(self): + with self.assertRaisesRegexp( + TypeError, 'fn should be a function-like object'): + function_utils.has_kwargs('not a function') + + class GetFuncNameTest(test.TestCase): def testWithSimpleFunction(self): |