diff options
author | 2018-09-14 18:22:52 -0700 | |
---|---|---|
committer | 2018-09-14 18:26:48 -0700 | |
commit | 08589aa0c4447b21dd73183cf5cfafff326324dc (patch) | |
tree | d1cb488f337d9a7cd5138fd07564549febcd07ca /tensorflow/python/eager | |
parent | 0d4cb43a540f08cb73c00fac662c961e4154ac32 (diff) |
Make accessed variable ordering deterministic again when constructing defuns
PiperOrigin-RevId: 213074939
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index f3fb48fd3b..e2874e25b6 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -888,13 +888,14 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): # Variables in `func_args`, `func_kwds` should be explicit inputs # to the function, not captured inputs. - variables = set(this_tape.watched_variables()) + tape_variables = this_tape.watched_variables() + arg_variables = set() inputs = [] for arg in nest.flatten(func_args) + nest.flatten(func_kwds): if isinstance(arg, resource_variable_ops.ResourceVariable): try: resource_placeholder = func_graph.captures.pop(arg.handle) - variables.remove(arg) + arg_variables.add(arg) except KeyError: # This case occurs if a Variable among the inputs is not actually # used by the function; we still add an explicit input for it @@ -904,6 +905,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): inputs.append(resource_placeholder) elif isinstance(arg, ops.Tensor): inputs.append(arg) + variables = [v for v in tape_variables if v not in arg_variables] func_graph.inputs = inputs + list(func_graph.captures.values()) func_graph.structured_outputs = func_outputs @@ -917,7 +919,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): # Instead of storing non-distributed component variables, we # store their distributed containers so we can retrieve the correct # component variables at call-time. - variables = list(variables) strategy = distribution_strategy_context.get_distribution_strategy() for i, variable in enumerate(variables): # If variable is not distributed value_container returns itself. |