aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-28 10:51:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 10:56:51 -0800
commitc81a8ae591cf43b6d10b887dfb22a780af3beec0 (patch)
tree0e765e40a607f26f832ce1301618cf42bba9cccd /tensorflow/python
parent4cb754e0513262e6d89eacc90eb3673f2b405234 (diff)
Make sure that additional ops added by Savers to read ResourceVariables are added to the graph in a deterministic way.
For ResourceVariables (op "VarHandleOp"), ops.internal_convert_to_tensor will add new ops such as "Read_8/ReadVariableOp". If op_list is cast to a set, as before this change, then adding these new ops made graph construction non-deterministic. PiperOrigin-RevId: 177185279
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/training/saver.py5
-rw-r--r--tensorflow/python/training/saver_test.py12
2 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 5bddde1698..bd47736d4b 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -523,7 +523,10 @@ class BaseSaverBuilder(object):
if not isinstance(op_list, (list, tuple, set)):
raise TypeError("Variables to save should be passed in a dict or a "
"list: %s" % op_list)
- op_list = set(op_list)
+ # When ResourceVariables are converted to Tensors, read ops are added to the
+ # graph. Sorting the op_list ensures that the resulting graph is always
+ # constructed in a deterministic way:
+ op_list = sorted(op_list, key=lambda x: x.name)
names_to_saveables = {}
# pylint: disable=protected-access
for var in op_list:
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 744b17dd22..98ac197204 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -164,6 +164,18 @@ class SaverTest(test.TestCase):
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
+ def testResourceVariableReadOpsAddedDeterministically(self):
+ graph_defs = []
+ num_graphs = 10
+ for _ in range(num_graphs):
+ with ops_lib.Graph().as_default() as g:
+ for i in range(20):
+ resource_variable_ops.ResourceVariable(i, name="var%s" % i)
+ saver_module.Saver()
+ graph_defs.append(g.as_graph_def())
+ for i in range(num_graphs - 1):
+ self.assertEqual(graph_defs[i], graph_defs[i + 1])
+
def testEagerBasic(self):
with context.eager_mode():
ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")