diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/partitioned_variables_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/partitioned_variables_test.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index ba9359d923..1d0c2dceba 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -34,6 +36,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib class PartitionerCreatorsTest(test.TestCase): @@ -622,6 +625,38 @@ class PartitionedVariablesTestCase(test.TestCase): variables.global_variables_initializer().run() self.assertAllClose([-0.4, -0.4], x.eval()) + def testMetaGraphSaveLoad(self): + save_prefix = os.path.join(self.get_temp_dir(), "ckpt") + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) + with variable_scope.variable_scope("root", partitioner=partitioner): + v0 = variable_scope.get_variable( + "v0", dtype=dtypes.float32, shape=(10, 10)) + v0_list = v0._get_variable_list() + v0_part = v0._get_partitions() + self.assertEqual(len(v0_list), 5) + self.assertAllEqual(v0_part, (5, 1)) + variables.global_variables_initializer().run() + + save_graph.get_collection_ref("partvar").append(v0) + saver = saver_lib.Saver() + save_graph.finalize() + save_path = saver.save(sess=session, save_path=save_prefix) + previous_value = session.run( + save_graph.get_tensor_by_name(v0.name + ":0")) + + restore_graph = ops.Graph() + with restore_graph.as_default(), self.test_session( + graph=restore_graph) as session: + saver = saver_lib.import_meta_graph(save_path + ".meta") + saver.restore(sess=session, save_path=save_path) + v0, = save_graph.get_collection_ref("partvar") + self.assertIsInstance(v0, variables.PartitionedVariable) + self.assertAllEqual( + previous_value, + session.run(restore_graph.get_tensor_by_name(v0.name + ":0"))) if __name__ == "__main__": test.main() |