Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py'):
 tensorflow/python/training/checkpoint_utils.py | 52 ++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 5b372e82b3..883f4fd910 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -179,6 +180,16 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
ValueError: If missing variables in current graph.
"""
+ if distribute_lib.get_cross_tower_context():
+ _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
+ else:
+ distribute_lib.get_tower_context().merge_call(
+ _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
+
+
+def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
+ """See `init_from_checkpoint` for documentation."""
+
ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
reader = load_checkpoint(ckpt_dir_or_file)
variable_map = reader.get_variable_to_shape_map()
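The dispatch added above is the standard tower/cross-tower pattern from the
distribute library: in a cross-tower context the implementation runs directly,
while inside a tower `merge_call` switches to the cross-tower context and
invokes the merge function with the `DistributionStrategy` as its first
argument, which is why `_init_from_checkpoint` takes and ignores a leading
`_`. A minimal sketch of the pattern, assuming a hypothetical `_do_work`
implementation function:

    def public_api(*args):
      if distribute_lib.get_cross_tower_context():
        # Already in cross-tower context: call the implementation directly,
        # filling the strategy slot that merge_call would otherwise populate.
        _do_work(None, *args)
      else:
        # Inside a tower: merge_call runs _do_work(strategy, *args) once in
        # the cross-tower context.
        distribute_lib.get_tower_context().merge_call(_do_work, *args)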
@@ -187,10 +198,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
var = None
# Check if this is Variable object or list of Variable objects (in case of
# partitioned variables).
- is_var = lambda x: isinstance(x, variables.Variable)
- if is_var(current_var_or_name) or (
+ if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
- and all(is_var(v) for v in current_var_or_name)):
+ and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
else:
store_vars = vs._get_default_variable_store()._vars # pylint:disable=protected-access
@@ -205,7 +215,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
))
- if is_var(var):
+ if _is_variable(var):
# Additional at-call-time checks.
if not var.get_shape().is_compatible_with(
variable_map[tensor_name_in_ckpt]):
@@ -297,13 +307,34 @@ def _set_checkpoint_initializer(variable,
with ops.device(variable.device), ops.device("/cpu:0"):
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
- if isinstance(variable, resource_variable_ops.ResourceVariable):
+
+ # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
+ if resource_variable_ops.is_resource_variable(variable):
init_op = variable.assign(restore_op, read_value=False)
else:
init_op = state_ops.assign(variable, restore_op)
- variable._initializer_op = init_op # pylint:disable=protected-access
- restore_op.set_shape(variable.shape)
- variable._initial_value = restore_op # pylint:disable=protected-access
+
+ # pylint:disable=protected-access
+ # We need special handling for `DistributedVariable`s as they contain
+ # multiple actual variables. `assign` on a `DistributedVariable` returns a
+ # combined `init_op` which contains initializers for all the contained
+ # variables. We then set each underlying variable's `_initializer_op` using
+ # the corresponding `init_op`.
+ # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
+ # moves out of contrib.
+ if any(base.__name__ == "DistributedVariable"
+ for base in variable.__class__.__bases__):
+ assert distribute_lib.get_cross_tower_context()
+ assert hasattr(variable, "_index")
+ for (d, v) in six.iteritems(variable._index):
+ v._initializer_op = init_op._index[d]
+ restore_op.set_shape(v.shape)
+ v._initial_value = restore_op
+ else:
+ variable._initializer_op = init_op
+ restore_op.set_shape(variable.shape)
+ variable._initial_value = restore_op
+ # pylint:enable=protected-access
def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
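The new branch above matters when `init_from_checkpoint` runs under a
distribution strategy, where one logical variable is backed by a per-device
variable on each tower and a single restore op has to initialize all of them.
A hedged usage sketch (the contrib-era `MirroredStrategy` constructor, the
checkpoint path, and the variable names are illustrative assumptions):

    import tensorflow as tf

    strategy = tf.contrib.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])
    with strategy.scope():
      # One logical variable, mirrored onto both devices.
      w = tf.get_variable("new_scope/w", shape=[10, 10])
      # Each per-device variable gets its _initializer_op wired to the
      # restore op, via the _index handling in the change above.
      tf.train.init_from_checkpoint("/tmp/ckpt",
                                    {"old_scope/w": "new_scope/w"})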
@@ -337,6 +368,11 @@ def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
_set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
+def _is_variable(x):
+ return (isinstance(x, variables.Variable) or
+ resource_variable_ops.is_resource_variable(x))
+
+
def _collect_partitioned_variable(name, all_vars):
"""Returns list of `tf.Variable` that comprise the partitioned variable."""
if name + "/part_0" in all_vars:
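The new `_is_variable` helper widens the old `is_var` lambda: besides
`variables.Variable` instances it now accepts anything
`resource_variable_ops.is_resource_variable` recognizes, which also covers
objects that merely duck-type as resource variables (those exposing a
`_should_act_as_resource_variable` attribute), assuming that is still how the
check is implemented. A small illustration under that assumption:

    class DuckTypedVar(object):
      # Duck-types as a resource variable; is_resource_variable() is assumed
      # to accept any object exposing this attribute.
      _should_act_as_resource_variable = True

    assert _is_variable(DuckTypedVar())           # via is_resource_variable
    assert _is_variable(variables.Variable(1.0))  # via the isinstance branch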