aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/saver_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r--tensorflow/python/training/saver_test.py72
1 files changed, 47 insertions, 25 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f235300eb5..ecce8ae6bd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -24,10 +24,8 @@ import math
import os
import random
import shutil
-import sys
import tempfile
import time
-import traceback
import numpy as np
import six
@@ -176,6 +174,24 @@ class SaverTest(test.TestCase):
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
+ def testResourceColocation(self):
+ partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
+ with ops_lib.device("/job:ps/device:GPU:0"):
+ v = variable_scope.get_variable("v0",
+ shape=[10, 2],
+ partitioner=partitioner,
+ use_resource=True)
+ saver_module.Saver({"v0": v}).build()
+ save_op = None
+ for op in ops_lib.get_default_graph().get_operations():
+ if op.type == "SaveV2":
+ save_op = op
+ break
+ assert save_op is not None
+ for save_inp in save_op.inputs[3:]:
+ # Input to SaveV2 op is placed on CPU of the same device as the Variable.
+ self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
+
def testResourceVariableReadOpsAddedDeterministically(self):
graph_defs = []
num_graphs = 10
@@ -369,8 +385,8 @@ class SaverTest(test.TestCase):
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.test_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
- with self.assertRaisesRegexp(errors.NotFoundError,
- "Failed to find any matching files for"):
+ with self.assertRaisesRegexp(
+ ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path")
def testInt64(self):
@@ -3139,27 +3155,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError, "Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path)
- def testCheckpointNotFoundErrorRaised(self):
- # Restore does some tricky exception handling to figure out if it should
- # load an object-based checkpoint. Tests that the exception handling isn't
- # too broad.
- a = resource_variable_ops.ResourceVariable(1., name="a")
- saver = saver_module.Saver([a])
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.NotFoundError,
- "Failed to find any matching files for path_which_does_not_exist"):
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- try:
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- except errors.NotFoundError:
- # Make sure we don't have a confusing "During handling of the above
- # exception" block in Python 3.
- # pylint: disable=no-value-for-parameter
- exception_string = "\n".join(
- traceback.format_exception(*sys.exc_info()))
- # pylint: enable=no-value-for-parameter
- self.assertNotIn("NewCheckpointReader", exception_string)
+ with self.assertRaises(errors.NotFoundError) as cs:
+ b_saver.restore(sess=sess, save_path=save_path)
+
+ # Make sure we don't have a confusing "During handling of the above
+ # exception" block in Python 3.
+ self.assertNotIn("NewCheckpointReader", cs.exception.message)
+
+ def testGraphChangedForRestoreErrorRaised(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable(1., name="a")
+ a_saver = saver_module.Saver([a])
+
+ with self.test_session(graph=g) as sess:
+ sess.run(a.initializer)
+ save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable([1.], name="a")
+ a_saver = saver_module.Saver([a])
+ with self.test_session(graph=g) as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "a mismatch between the current graph and the graph"):
+ a_saver.restore(sess=sess, save_path=save_path)
def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir()