aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-14 18:22:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 18:26:48 -0700
commit08589aa0c4447b21dd73183cf5cfafff326324dc (patch)
treed1cb488f337d9a7cd5138fd07564549febcd07ca /tensorflow/python/eager
parent0d4cb43a540f08cb73c00fac662c961e4154ac32 (diff)
Make accessed variable ordering deterministic again when constructing defuns
PiperOrigin-RevId: 213074939
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f3fb48fd3b..e2874e25b6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -888,13 +888,14 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
# Variables in `func_args`, `func_kwds` should be explicit inputs
# to the function, not captured inputs.
- variables = set(this_tape.watched_variables())
+ tape_variables = this_tape.watched_variables()
+ arg_variables = set()
inputs = []
for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
if isinstance(arg, resource_variable_ops.ResourceVariable):
try:
resource_placeholder = func_graph.captures.pop(arg.handle)
- variables.remove(arg)
+ arg_variables.add(arg)
except KeyError:
# This case occurs if a Variable among the inputs is not actually
# used by the function; we still add an explicit input for it
@@ -904,6 +905,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
inputs.append(resource_placeholder)
elif isinstance(arg, ops.Tensor):
inputs.append(arg)
+ variables = [v for v in tape_variables if v not in arg_variables]
func_graph.inputs = inputs + list(func_graph.captures.values())
func_graph.structured_outputs = func_outputs
@@ -917,7 +919,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
# Instead of storing non-distributed component variables, we
# store their distributed containers so we can retrieve the correct
# component variables at call-time.
- variables = list(variables)
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.