aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Zongheng Yang <zongheng@google.com>2016-07-12 20:13:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-12 21:17:40 -0700
commit10ea197604b6760652773d2525f850bfd3238061 (patch)
tree388d97b6ca0450af104245e0b44cdcb0f6a076b8
parentb7c375735732641b376ac4bdcc84bd74f1e747e4 (diff)
Partitioned vars: test for "save in N slices, restore in M slices" scenario.
Where N != M. Previously, the Python tests were missing, so this functionality may not have been as widely exposed to users as we'd like. Change: 127276419
-rw-r--r--tensorflow/python/training/saver_test.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 3f7150539d..3934f74feb 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -376,6 +376,73 @@ class SaveRestoreShardedTest(tf.test.TestCase):
sd = save.as_saver_def()
self.assertTrue(sd.sharded)
+ def testPartitionedVariables(self):
+ var_full_shape = [10, 3]
+ # Allows save/restore mechanism to work w/ different slicings.
+ var_name = "my_var"
+ saved_path = os.path.join(_TestDir("partitioned_variables"), "ckpt")
+
+ def _Save(slices):
+ with self.test_session(graph=tf.Graph()) as sess:
+ # Calls .eval() to return the ndarray that makes up the full variable.
+ rnd = tf.random_uniform(var_full_shape).eval()
+
+ if slices:
+ vs = tf.create_partitioned_variables(var_full_shape,
+ slices,
+ rnd,
+ name=var_name)
+ else:
+ vs = [tf.Variable(rnd, name=var_name)]
+
+ tf.initialize_all_variables().run()
+ saver = tf.train.Saver(vs)
+ actual_path = saver.save(sess, saved_path)
+ self.assertEqual(saved_path, actual_path)
+
+ return rnd
+
+ def _Restore(slices):
+ with self.test_session(graph=tf.Graph()) as sess:
+ if slices:
+ new_vs = tf.create_partitioned_variables(
+ var_full_shape,
+ slices,
+ tf.zeros(var_full_shape), # != original contents.
+ name=var_name)
+ else:
+ new_vs = [tf.Variable(
+ tf.zeros(shape=var_full_shape), # != original contents.
+ name=var_name)]
+
+ tf.initialize_all_variables().run()
+ saver = tf.train.Saver(new_vs)
+ saver.restore(sess, saved_path)
+
+ if slices and slices[0] != 1:
+ return tf.concat(0, new_vs).eval()
+ elif slices and slices[1] != 1:
+ return tf.concat(1, new_vs).eval()
+ else: # Non-sliced.
+ return new_vs[0].eval()
+
+ # Saves 10 horizontal parts of a partitioned variable.
+ # Restores into a full variable, non-sliced.
+ saved_full = _Save(slices=[10, 1])
+ restored_full = _Restore(slices=None)
+ self.assertAllEqual(saved_full, restored_full)
+
+ # Restores into a different number/orientation of slices.
+ restored_full = _Restore(slices=[2, 1]) # 2 horizon parts.
+ self.assertAllEqual(saved_full, restored_full)
+ restored_full = _Restore(slices=[1, 3]) # 3 vertical parts.
+ self.assertAllEqual(saved_full, restored_full)
+
+ # Now, saves a full variable and restores in slices.
+ saved_full = _Save(slices=None)
+ restored_full = _Restore(slices=[1, 3])
+ self.assertAllEqual(saved_full, restored_full)
+
class MaxToKeepTest(tf.test.TestCase):