aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-21 15:43:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 15:45:42 -0700
commitb28938c3672db9c23f84298160658787d0ccf69d (patch)
tree6d459c42999488532ca822b698ef86fde3e151ee
parent564fcf1e48224f518fa5b62bbc5e80f84270d0fb (diff)
Remove object-based checkpointing probes from Python 3 tf.train.Saver "name not found" stack traces
PiperOrigin-RevId: 197473101
-rw-r--r--tensorflow/python/training/saver.py6
-rw-r--r--tensorflow/python/training/saver_test.py12
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 294adbb74b..fc89f88063 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1754,13 +1754,17 @@ class Saver(object):
exception_type, exception_value, exception_traceback = sys.exc_info()
# The checkpoint would not be loaded successfully as is. Try to parse it
# as an object-based checkpoint.
+ should_reraise = False
try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError:
# This is not an object-based checkpoint, or the checkpoint doesn't
- # exist. Re-raise the original exception.
+ # exist. Re-raise the original exception, but do it outside the except
+ # block so the object graph lookup isn't included in the stack trace.
+ should_reraise = True
+ if should_reraise:
six.reraise(exception_type, exception_value, exception_traceback)
del exception_traceback # avoid reference cycles
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index e7f7addf81..f1991093e0 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -24,8 +24,10 @@ import math
import os
import random
import shutil
+import sys
import tempfile
import time
+import traceback
import numpy as np
import six
@@ -3093,6 +3095,16 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError,
"Failed to find any matching files for path_which_does_not_exist"):
saver.restore(sess=sess, save_path="path_which_does_not_exist")
+ try:
+ saver.restore(sess=sess, save_path="path_which_does_not_exist")
+ except errors.NotFoundError:
+ # Make sure we don't have a confusing "During handling of the above
+ # exception" block in Python 3.
+ # pylint: disable=no-value-for-parameter
+ exception_string = "\n".join(
+ traceback.format_exception(*sys.exc_info()))
+ # pylint: enable=no-value-for-parameter
+ self.assertNotIn("NewCheckpointReader", exception_string)
def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir()