author     Suharsh Sivakumar <suharshs@google.com>           2017-11-07 08:53:39 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-11-07 08:57:41 -0800
commit     955b506d24a63e77772ecda28af9cab1ceffb9e7
tree       629d0b6016ce403091c0b82d86b074af34525155
parent     da1b1d28faca9aa65e832b9bbada8d509ea2df7d
variables_to_restore: Differentiate python variables by string name rather than object.
variables_to_restore ensured that duplicate variables weren't added to the returned map by comparing Python Variable objects. Normally there is only one Variable object for each underlying variable, so this wasn't a problem. But when a graph is initialized by importing a GraphDef, a duplicate Python Variable object is created for each occurrence of a variable in a collection (say, global variables and moving average variables). This change fixes variables_to_restore to work with an imported GraphDef by comparing variable names instead of Variable objects.

PiperOrigin-RevId: 174861804
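To make the failure mode concrete, here is a minimal sketch (TF 1.x public API; the variable name "v" is hypothetical) of how round-tripping a graph through a MetaGraphDef yields two distinct Python Variable objects for the same underlying variable, so de-duplication by object identity breaks while de-duplication by name still works:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  # "v" is placed in both GLOBAL_VARIABLES and TRAINABLE_VARIABLES.
  tf.Variable(10.0, name="v")
  meta_graph = tf.train.export_meta_graph(
      graph=g, collection_list=g.get_all_collection_keys())

g_copy = tf.Graph()
with g_copy.as_default():
  tf.train.import_meta_graph(meta_graph)
  v_global = tf.global_variables()[0]
  v_trainable = tf.trainable_variables()[0]
  # Both wrappers refer to the same graph variable, but the importer built a
  # separate Python object per collection entry, so identity-based set
  # subtraction no longer treats them as duplicates; their names still match.
  assert v_global is not v_trainable
  assert v_global.name == v_trainable.name == "v:0"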
-rw-r--r--  tensorflow/python/training/moving_averages.py        7
-rw-r--r--  tensorflow/python/training/moving_averages_test.py  27
2 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index eb07343850..e34c759e89 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -498,8 +498,9 @@ class ExponentialMovingAverage(object):
     # Collect all the variables with moving average,
     for v in moving_avg_variables:
       name_map[self.average_name(v)] = v
-    # Make sure we restore variables without moving average as well.
-    for v in list(set(variables.global_variables()) - moving_avg_variables):
-      if v.op.name not in name_map:
+    # Make sure we restore variables without moving averages as well.
+    moving_avg_variable_names = set([v.name for v in moving_avg_variables])
+    for v in list(set(variables.global_variables())):
+      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
         name_map[v.op.name] = v
     return name_map
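For context, the name_map built here is typically handed to a Saver so that evaluation restores moving-average (shadow) values into the original variables. A minimal usage sketch, assuming the TF 1.x public API; the variable name "w" and the checkpoint path are illustrative only:

import tensorflow as tf

w = tf.Variable(0.0, name="w")
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_op = ema.apply([w])  # run during training to update the shadow variable

# variables_to_restore() maps "w/ExponentialMovingAverage" -> w (and every other
# global variable to its own name), so restoring a training checkpoint loads the
# averaged value into w for evaluation.
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")  # hypothetical checkpoint path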
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 63604cf19d..6efdeb2866 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import moving_averages
+from tensorflow.python.training import saver as saver_lib
 
 
 class MovingAveragesTest(test.TestCase):
@@ -392,6 +393,32 @@ class ExponentialMovingAverageTest(test.TestCase):
     self.assertEqual([b"loc:@v1"], ema.average(v1).op.colocation_groups())
     self.assertDeviceEqual("/job:default", ema.average(tensor2).device)
 
+  def _ExportAndImportGraph(self, graph):
+    """Export and import graph into a new graph."""
+    meta_graph = saver_lib.export_meta_graph(
+        graph=graph, collection_list=graph.get_all_collection_keys())
+    graph_copy = ops.Graph()
+    with graph_copy.as_default():
+      _ = saver_lib.import_meta_graph(meta_graph)
+    return graph_copy
+
+  def testImportedGraphVariablesToRestore(self):
+    g = ops.Graph()
+    with g.as_default():
+      variables.Variable(10.0, name="v")
+    # Export and import the graph into a new graph.
+    g_copy = self._ExportAndImportGraph(g)
+    with g_copy.as_default():
+      ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+      vars_to_restore = ema.variables_to_restore()
+      # There should only be one variable in vars_to_restore. This is important
+      # to check because when importing from a GraphDef, TF makes duplicate
+      # python Variable objects referring to the same underlying variable. We
+      # need to be sure that two variables referring to the same variable don't
+      # both get added to vars_to_restore.
+      self.assertEqual(len(vars_to_restore), 1)
+      self.assertTrue("v/foo_avg" in vars_to_restore)
+
 
 if __name__ == "__main__":
   test.main()
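The new test mirrors the situation that motivated the change: a graph rebuilt from a MetaGraphDef whose collections now contain duplicate Variable wrappers. A sketch of the corresponding evaluation-time usage, assuming the training run applied an ExponentialMovingAverage under the default name and saved a checkpoint; the paths and decay value are hypothetical:

import tensorflow as tf

eval_graph = tf.Graph()
with eval_graph.as_default():
  # Recreate the training graph, including its variable collections.
  tf.train.import_meta_graph("/tmp/model.ckpt.meta")  # hypothetical path
  ema = tf.train.ExponentialMovingAverage(0.999)
  # With this change, each underlying variable appears exactly once in the map,
  # even though import_meta_graph created duplicate Python Variable objects.
  saver = tf.train.Saver(ema.variables_to_restore())
  with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")  # hypothetical checkpoint path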