diff options
author | 2018-07-26 13:11:51 -0700 | |
---|---|---|
committer | 2018-07-26 13:16:27 -0700 | |
commit | 53449eb635c4abce62c5f00d05fad1b1c8d1d9ab (patch) | |
tree | c8585d643b81992753a61a30dc126922c4d33666 /tensorflow/python/training/checkpoint_utils.py | |
parent | b5ee7b1bc75928825991957c189ded0b970a1081 (diff) |
Restore tower local variables correctly in init_from_checkpoint.
PiperOrigin-RevId: 206208637
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_utils.py | 25 |
1 files changed, 6 insertions, 19 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 883f4fd910..799729b017 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -311,29 +311,16 @@ def _set_checkpoint_initializer(variable, # TODO(priyag, allenl): Use `SaveableObject.restore` instead here. if resource_variable_ops.is_resource_variable(variable): init_op = variable.assign(restore_op, read_value=False) + # TODO(priyag): Remove this when using `SaveableObject.restore` instead. + if hasattr(init_op, "_index"): + init_op = distribute_lib.get_distribution_strategy().group(init_op) else: init_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access - # We need special handling for `DistributedVariable`s as they contain - # mutliple actual variables. `assign` on a `DistributedVariable` returns a - # combined `init_op` which contains initializers for all the contained - # variables. We then set each underlying variable's `_initializer_op` using - # the corresponding `init_op`. - # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class - # moves out of contrib. - if any(base.__name__ == "DistributedVariable" - for base in variable.__class__.__bases__): - assert distribute_lib.get_cross_tower_context() - assert hasattr(variable, "_index") - for (d, v) in six.iteritems(variable._index): - v._initializer_op = init_op._index[d] - restore_op.set_shape(v.shape) - v._initial_value = restore_op - else: - variable._initializer_op = init_op - restore_op.set_shape(variable.shape) - variable._initial_value = restore_op + variable._initializer_op = init_op + restore_op.set_shape(variable.shape) + variable._initial_value = restore_op # pylint:enable=protected-access |