diff options
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_utils.py | 52 |
1 files changed, 44 insertions, 8 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 5b372e82b3..883f4fd910 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -179,6 +180,16 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): tf.errors.OpError: If missing checkpoints or tensors in checkpoints. ValueError: If missing variables in current graph. """ + if distribute_lib.get_cross_tower_context(): + _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map) + else: + distribute_lib.get_tower_context().merge_call( + _init_from_checkpoint, ckpt_dir_or_file, assignment_map) + + +def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map): + """See `init_from_checkpoint` for documentation.""" + ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file) reader = load_checkpoint(ckpt_dir_or_file) variable_map = reader.get_variable_to_shape_map() @@ -187,10 +198,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): var = None # Check if this is Variable object or list of Variable objects (in case of # partitioned variables). 
- is_var = lambda x: isinstance(x, variables.Variable) - if is_var(current_var_or_name) or ( + if _is_variable(current_var_or_name) or ( isinstance(current_var_or_name, list) - and all(is_var(v) for v in current_var_or_name)): + and all(_is_variable(v) for v in current_var_or_name)): var = current_var_or_name else: store_vars = vs._get_default_variable_store()._vars # pylint:disable=protected-access @@ -205,7 +215,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): raise ValueError("Tensor %s is not found in %s checkpoint %s" % ( tensor_name_in_ckpt, ckpt_dir_or_file, variable_map )) - if is_var(var): + if _is_variable(var): # Additional at-call-time checks. if not var.get_shape().is_compatible_with( variable_map[tensor_name_in_ckpt]): @@ -297,13 +307,34 @@ def _set_checkpoint_initializer(variable, with ops.device(variable.device), ops.device("/cpu:0"): restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] - if isinstance(variable, resource_variable_ops.ResourceVariable): + + # TODO(priyag, allenl): Use `SaveableObject.restore` instead here. + if resource_variable_ops.is_resource_variable(variable): init_op = variable.assign(restore_op, read_value=False) else: init_op = state_ops.assign(variable, restore_op) - variable._initializer_op = init_op # pylint:disable=protected-access - restore_op.set_shape(variable.shape) - variable._initial_value = restore_op # pylint:disable=protected-access + + # pylint:disable=protected-access + # We need special handling for `DistributedVariable`s as they contain + # multiple actual variables. `assign` on a `DistributedVariable` returns a + # combined `init_op` which contains initializers for all the contained + # variables. We then set each underlying variable's `_initializer_op` using + # the corresponding `init_op`. + # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class + # moves out of contrib. 
+ if any(base.__name__ == "DistributedVariable" + for base in variable.__class__.__bases__): + assert distribute_lib.get_cross_tower_context() + assert hasattr(variable, "_index") + for (d, v) in six.iteritems(variable._index): + v._initializer_op = init_op._index[d] + restore_op.set_shape(v.shape) + v._initial_value = restore_op + else: + variable._initializer_op = init_op + restore_op.set_shape(variable.shape) + variable._initial_value = restore_op + # pylint:enable=protected-access def _set_variable_or_list_initializer(variable_or_list, ckpt_file, @@ -337,6 +368,11 @@ def _set_variable_or_list_initializer(variable_or_list, ckpt_file, _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "") +def _is_variable(x): + return (isinstance(x, variables.Variable) or + resource_variable_ops.is_resource_variable(x)) + + def _collect_partitioned_variable(name, all_vars): """Returns list of `tf.Variable` that comprise the partitioned variable.""" if name + "/part_0" in all_vars: |