diff options
author | 2018-01-29 12:42:09 -0800 | |
---|---|---|
committer | 2018-01-29 12:50:05 -0800 | |
commit | 0a34211774c8c45c8f290e6c51335b99873dcbb9 (patch) | |
tree | ef56acfeaf3287bde34dcd5af401b5612e424562 /tensorflow/python/training/checkpoint_utils.py | |
parent | b78134d0d5ea7f17468bea9276c35fab4a9cb388 (diff) |
Remove the trailing '/' in the tensor name when loading checkpoints
PiperOrigin-RevId: 183709590
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_utils.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index b5d3e78797..63235a1454 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -242,6 +242,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): full_tensor_name = full_tensor_name[1:] if tensor_name_in_ckpt != "/": full_tensor_name = tensor_name_in_ckpt + full_tensor_name + # Remove trailing '/', if any, in the full_tensor_name + if full_tensor_name.endswith("/"): + full_tensor_name = full_tensor_name[:-1] if full_tensor_name not in variable_map: raise ValueError( "Tensor %s (%s in %s) is not found in %s checkpoint" % ( |