diff options
Diffstat (limited to 'tensorflow/contrib/framework/python/framework/checkpoint_utils.py')
-rw-r--r-- | tensorflow/contrib/framework/python/framework/checkpoint_utils.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 4cd3efafa0..5d078236ac 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -280,10 +280,11 @@ def init_from_checkpoint(checkpoint_dir, assignment_map): for var_name in scope_variables: # Lookup name with specified prefix and suffix from current variable. # If tensor_name given is '/' (root), don't use it for full name. + full_tensor_name = var_name[len(scopes):] + if current_var_or_name != "/": + full_tensor_name = full_tensor_name[1:] if tensor_name_in_ckpt != "/": - full_tensor_name = tensor_name_in_ckpt + var_name[len(scopes) + 1:] - else: - full_tensor_name = var_name[len(scopes) + 1:] + full_tensor_name = tensor_name_in_ckpt + full_tensor_name if full_tensor_name not in variable_map: raise ValueError( "Tensor %s (%s in %s) is not found in %s checkpoint" % ( |