From 10ea197604b6760652773d2525f850bfd3238061 Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Tue, 12 Jul 2016 20:13:45 -0800 Subject: Partitioned vars: test for "save in N slices, restore in M slices" scenario. Where N != M. Previously, the Python tests were missing, so this functionality may not have been as widely exposed to users as we'd like. Change: 127276419 --- tensorflow/python/training/saver_test.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 3f7150539d..3934f74feb 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -376,6 +376,73 @@ class SaveRestoreShardedTest(tf.test.TestCase): sd = save.as_saver_def() self.assertTrue(sd.sharded) + def testPartitionedVariables(self): + var_full_shape = [10, 3] + # Allows save/restore mechanism to work w/ different slicings. + var_name = "my_var" + saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt") + + def _Save(slices): + with self.test_session(graph=tf.Graph()) as sess: + # Calls .eval() to return the ndarray that makes up the full variable. + rnd = tf.random_uniform(var_full_shape).eval() + + if slices: + vs = tf.create_partitioned_variables(var_full_shape, + slices, + rnd, + name=var_name) + else: + vs = [tf.Variable(rnd, name=var_name)] + + tf.initialize_all_variables().run() + saver = tf.train.Saver(vs) + actual_path = saver.save(sess, saved_path) + self.assertEqual(saved_path, actual_path) + + return rnd + + def _Restore(slices): + with self.test_session(graph=tf.Graph()) as sess: + if slices: + new_vs = tf.create_partitioned_variables( + var_full_shape, + slices, + tf.zeros(var_full_shape), # != original contents. + name=var_name) + else: + new_vs = [tf.Variable( + tf.zeros(shape=var_full_shape), # != original contents. + name=var_name)] + + tf.initialize_all_variables().run() + saver = tf.train.Saver(new_vs) + saver.restore(sess, saved_path) + + if slices and slices[0] != 1: + return tf.concat(0, new_vs).eval() + elif slices and slices[1] != 1: + return tf.concat(1, new_vs).eval() + else: # Non-sliced. + return new_vs[0].eval() + + # Saves 10 horizontal parts of a partitioned variable. + # Restores into a full variable, non-sliced. + saved_full = _Save(slices=[10, 1]) + restored_full = _Restore(slices=None) + self.assertAllEqual(saved_full, restored_full) + + # Restores into a different number/orientation of slices. + restored_full = _Restore(slices=[2, 1]) # 2 horizon parts. + self.assertAllEqual(saved_full, restored_full) + restored_full = _Restore(slices=[1, 3]) # 3 vertical parts. + self.assertAllEqual(saved_full, restored_full) + + # Now, saves a full variable and restores in slices. + saved_full = _Save(slices=None) + restored_full = _Restore(slices=[1, 3]) + self.assertAllEqual(saved_full, restored_full) + class MaxToKeepTest(tf.test.TestCase): -- cgit v1.2.3