aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpoint_utils.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-29 12:42:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 12:50:05 -0800
commit0a34211774c8c45c8f290e6c51335b99873dcbb9 (patch)
treeef56acfeaf3287bde34dcd5af401b5612e424562 /tensorflow/python/training/checkpoint_utils.py
parentb78134d0d5ea7f17468bea9276c35fab4a9cb388 (diff)
Remove the trailing '/' in the tensor name when loading checkpoints
PiperOrigin-RevId: 183709590
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r--tensorflow/python/training/checkpoint_utils.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index b5d3e78797..63235a1454 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -242,6 +242,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
full_tensor_name = full_tensor_name[1:]
if tensor_name_in_ckpt != "/":
full_tensor_name = tensor_name_in_ckpt + full_tensor_name
+ # Remove trailing '/', if any, in the full_tensor_name
+ if full_tensor_name.endswith("/"):
+ full_tensor_name = full_tensor_name[:-1]
if full_tensor_name not in variable_map:
raise ValueError(
"Tensor %s (%s in %s) is not found in %s checkpoint" % (