aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/partitioned_variables_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index ba9359d923..1d0c2dceba 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -34,6 +36,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
@@ -622,6 +625,38 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose([-0.4, -0.4], x.eval())
+ def testMetaGraphSaveLoad(self):
+ save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
+ with variable_scope.variable_scope("root", partitioner=partitioner):
+ v0 = variable_scope.get_variable(
+ "v0", dtype=dtypes.float32, shape=(10, 10))
+ v0_list = v0._get_variable_list()
+ v0_part = v0._get_partitions()
+ self.assertEqual(len(v0_list), 5)
+ self.assertAllEqual(v0_part, (5, 1))
+ variables.global_variables_initializer().run()
+
+ save_graph.get_collection_ref("partvar").append(v0)
+ saver = saver_lib.Saver()
+ save_graph.finalize()
+ save_path = saver.save(sess=session, save_path=save_prefix)
+ previous_value = session.run(
+ save_graph.get_tensor_by_name(v0.name + ":0"))
+
+ restore_graph = ops.Graph()
+ with restore_graph.as_default(), self.test_session(
+ graph=restore_graph) as session:
+ saver = saver_lib.import_meta_graph(save_path + ".meta")
+ saver.restore(sess=session, save_path=save_path)
+ v0, = save_graph.get_collection_ref("partvar")
+ self.assertIsInstance(v0, variables.PartitionedVariable)
+ self.assertAllEqual(
+ previous_value,
+ session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
if __name__ == "__main__":
test.main()