diff options
author | 2018-09-27 06:50:20 -0700 | |
---|---|---|
committer | 2018-09-27 06:54:41 -0700 | |
commit | 234229b014cb0cfe4bf8e9466db79d596085faba (patch) | |
tree | 96ae0a68d991da9598ac000ec988ac2d08a701a5 /tensorflow/python/ops | |
parent | 77e2686a2958eb76e0164828d5d536b86c72464b (diff) |
Update logic used in get_variable to populate custom_getter's kwargs.
The new implementation ensures that the 'constraints' kwarg is propagated by customer getters whose signature includes a keyworded, variable length argument dictionary, as well as those explicitly including the 'constraints' argument.
PiperOrigin-RevId: 214767296
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index a43676cd70..562e1ad6cb 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -515,8 +515,10 @@ class _VariableStore(object): "synchronization": synchronization, "aggregation": aggregation, } - # `fn_args` can handle functions, `functools.partial`, `lambda`. - if "constraint" in function_utils.fn_args(custom_getter): + # `fn_args` and `has_kwargs` can handle functions, `functools.partial`, + # `lambda`. + if ("constraint" in function_utils.fn_args(custom_getter) or + function_utils.has_kwargs(custom_getter)): custom_getter_kwargs["constraint"] = constraint return custom_getter(**custom_getter_kwargs) else: |