diff options
author | 2016-07-12 20:13:45 -0800 | |
---|---|---|
committer | 2016-07-12 21:17:40 -0700 | |
commit | 10ea197604b6760652773d2525f850bfd3238061 (patch) | |
tree | 388d97b6ca0450af104245e0b44cdcb0f6a076b8 | |
parent | b7c375735732641b376ac4bdcc84bd74f1e747e4 (diff) |
Partitioned vars: test for "save in N slices, restore in M slices" scenario.
Where N != M.
Previously, the Python tests were missing, so this functionality may not have been as
widely exposed to users as we'd like.
Change: 127276419
-rw-r--r-- | tensorflow/python/training/saver_test.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 3f7150539d..3934f74feb 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -376,6 +376,73 @@ class SaveRestoreShardedTest(tf.test.TestCase): sd = save.as_saver_def() self.assertTrue(sd.sharded) + def testPartitionedVariables(self): + var_full_shape = [10, 3] + # Allows save/restore mechanism to work w/ different slicings. + var_name = "my_var" + saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt") + + def _Save(slices): + with self.test_session(graph=tf.Graph()) as sess: + # Calls .eval() to return the ndarray that makes up the full variable. + rnd = tf.random_uniform(var_full_shape).eval() + + if slices: + vs = tf.create_partitioned_variables(var_full_shape, + slices, + rnd, + name=var_name) + else: + vs = [tf.Variable(rnd, name=var_name)] + + tf.initialize_all_variables().run() + saver = tf.train.Saver(vs) + actual_path = saver.save(sess, saved_path) + self.assertEqual(saved_path, actual_path) + + return rnd + + def _Restore(slices): + with self.test_session(graph=tf.Graph()) as sess: + if slices: + new_vs = tf.create_partitioned_variables( + var_full_shape, + slices, + tf.zeros(var_full_shape), # != original contents. + name=var_name) + else: + new_vs = [tf.Variable( + tf.zeros(shape=var_full_shape), # != original contents. + name=var_name)] + + tf.initialize_all_variables().run() + saver = tf.train.Saver(new_vs) + saver.restore(sess, saved_path) + + if slices and slices[0] != 1: + return tf.concat(0, new_vs).eval() + elif slices and slices[1] != 1: + return tf.concat(1, new_vs).eval() + else: # Non-sliced. + return new_vs[0].eval() + + # Saves 10 horizontal parts of a partitioned variable. + # Restores into a full variable, non-sliced. + saved_full = _Save(slices=[10, 1]) + restored_full = _Restore(slices=None) + self.assertAllEqual(saved_full, restored_full) + + # Restores into a different number/orientation of slices. + restored_full = _Restore(slices=[2, 1]) # 2 horizon parts. + self.assertAllEqual(saved_full, restored_full) + restored_full = _Restore(slices=[1, 3]) # 3 vertical parts. + self.assertAllEqual(saved_full, restored_full) + + # Now, saves a full variable and restores in slices. + saved_full = _Save(slices=None) + restored_full = _Restore(slices=[1, 3]) + self.assertAllEqual(saved_full, restored_full) + class MaxToKeepTest(tf.test.TestCase): |