aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/ops/variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables.py')
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index a9d47ac9b9..0754c3e0e3 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -25,6 +25,7 @@ import re
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.contrib.util import loader
+from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
@@ -684,7 +685,8 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
'Variable %s missing in checkpoint %s', var, model_path)
var_list = available_vars
if var_list:
- saver = tf_saver.Saver(var_list, reshape=reshape_variables)
+ saver = tf_saver.Saver(var_list, reshape=reshape_variables,
+ write_version=saver_pb2.SaverDef.V1)
def callback(session):
saver.restore(session, model_path)
return callback