diff options
author | 2017-11-28 10:51:26 -0800 | |
---|---|---|
committer | 2017-11-28 10:56:51 -0800 | |
commit | c81a8ae591cf43b6d10b887dfb22a780af3beec0 (patch) | |
tree | 0e765e40a607f26f832ce1301618cf42bba9cccd | |
parent | 4cb754e0513262e6d89eacc90eb3673f2b405234 (diff) |
Make sure that additional ops added by Savers to read ResourceVariables are added to the graph in a deterministic way.
For ResourceVariables (op "VarHandleOp"), ops.internal_convert_to_tensor will add new ops such as "Read_8/ReadVariableOp". If op_list is cast to a set, as before this change, then adding these new ops made graph construction non-deterministic.
PiperOrigin-RevId: 177185279
-rw-r--r-- | tensorflow/python/training/saver.py | 5 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 12 |
2 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 5bddde1698..bd47736d4b 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -523,7 +523,10 @@ class BaseSaverBuilder(object): if not isinstance(op_list, (list, tuple, set)): raise TypeError("Variables to save should be passed in a dict or a " "list: %s" % op_list) - op_list = set(op_list) + # When ResourceVariables are converted to Tensors, read ops are added to the + # graph. Sorting the op_list ensures that the resulting graph is always + # constructed in a deterministic way: + op_list = sorted(op_list, key=lambda x: x.name) names_to_saveables = {} # pylint: disable=protected-access for var in op_list: diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 744b17dd22..98ac197204 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -164,6 +164,18 @@ class SaverTest(test.TestCase): def testResourceBasic(self): self.basicSaveRestore(resource_variable_ops.ResourceVariable) + def testResourceVariableReadOpsAddedDeterministically(self): + graph_defs = [] + num_graphs = 10 + for _ in range(num_graphs): + with ops_lib.Graph().as_default() as g: + for i in range(20): + resource_variable_ops.ResourceVariable(i, name="var%s" % i) + saver_module.Saver() + graph_defs.append(g.as_graph_def()) + for i in range(num_graphs - 1): + self.assertEqual(graph_defs[i], graph_defs[i + 1]) + def testEagerBasic(self): with context.eager_mode(): ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") |