diff options
author | 2017-11-07 08:53:39 -0800 | |
---|---|---|
committer | 2017-11-07 08:57:41 -0800 | |
commit | 955b506d24a63e77772ecda28af9cab1ceffb9e7 (patch) | |
tree | 629d0b6016ce403091c0b82d86b074af34525155 | |
parent | da1b1d28faca9aa65e832b9bbada8d509ea2df7d (diff) |
variables_to_restore: Differentiate python variables by string name rather than object.
variables_to_restore ensured that duplicate variables weren't added to the return map by comparing python variable object. Normally there is only one Variable object for each underlying variable, so this wasn't a problem. But when one initializes a graph by importing a GraphDef, duplicate python Variable objects are created for each occurrence of a variable in a collection (say, global variables and moving average variables).
This change fixes variables_to_restore to work with an imported graph def by not comparing Variable objects.
PiperOrigin-RevId: 174861804
-rw-r--r-- | tensorflow/python/training/moving_averages.py | 7 | ||||
-rw-r--r-- | tensorflow/python/training/moving_averages_test.py | 27 |
2 files changed, 31 insertions, 3 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index eb07343850..e34c759e89 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -498,8 +498,9 @@ class ExponentialMovingAverage(object): # Collect all the variables with moving average, for v in moving_avg_variables: name_map[self.average_name(v)] = v - # Make sure we restore variables without moving average as well. - for v in list(set(variables.global_variables()) - moving_avg_variables): - if v.op.name not in name_map: + # Make sure we restore variables without moving averages as well. + moving_avg_variable_names = set([v.name for v in moving_avg_variables]) + for v in list(set(variables.global_variables())): + if v.name not in moving_avg_variable_names and v.op.name not in name_map: name_map[v.op.name] = v return name_map diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 63604cf19d..6efdeb2866 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import moving_averages +from tensorflow.python.training import saver as saver_lib class MovingAveragesTest(test.TestCase): @@ -392,6 +393,32 @@ class ExponentialMovingAverageTest(test.TestCase): self.assertEqual([b"loc:@v1"], ema.average(v1).op.colocation_groups()) self.assertDeviceEqual("/job:default", ema.average(tensor2).device) + def _ExportAndImportGraph(self, graph): + """Export and import graph into a new graph.""" + meta_graph = saver_lib.export_meta_graph( + graph=graph, collection_list=graph.get_all_collection_keys()) + graph_copy = ops.Graph() + with graph_copy.as_default(): + _ = saver_lib.import_meta_graph(meta_graph) + return graph_copy + + def testImportedGraphVariablesToRestore(self): + g = ops.Graph() + with g.as_default(): + variables.Variable(10.0, name="v") + # Export and import the graph into a new graph. + g_copy = self._ExportAndImportGraph(g) + with g_copy.as_default(): + ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg") + vars_to_restore = ema.variables_to_restore() + # There should only be one variable in vars_to_restore. This is important + # to check because when importing from a GraphDef, TF makes duplicate + # python Variable objects referring to the same underlying variable. We + # need to be sure that two variables referring to the same variable don't + # both get added to vars_to_restore. + self.assertEqual(len(vars_to_restore), 1) + self.assertTrue("v/foo_avg" in vars_to_restore) + if __name__ == "__main__": test.main() |