aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/framework/checkpoint_utils.py')
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
index 4cd3efafa0..5d078236ac 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
@@ -280,10 +280,11 @@ def init_from_checkpoint(checkpoint_dir, assignment_map):
for var_name in scope_variables:
# Lookup name with specified prefix and suffix from current variable.
# If tensor_name given is '/' (root), don't use it for full name.
+ full_tensor_name = var_name[len(scopes):]
+ if current_var_or_name != "/":
+ full_tensor_name = full_tensor_name[1:]
if tensor_name_in_ckpt != "/":
- full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:]
- else:
- full_tensor_name = var_name[len(scopes) + 1:]
+ full_tensor_name = tensor_name_in_ckpt + full_tensor_name
if full_tensor_name not in variable_map:
raise ValueError(
"Tensor %s (%s in %s) is not found in %s checkpoint" % (