aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpoint_utils.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-07-26 13:11:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 13:16:27 -0700
commit53449eb635c4abce62c5f00d05fad1b1c8d1d9ab (patch)
treec8585d643b81992753a61a30dc126922c4d33666 /tensorflow/python/training/checkpoint_utils.py
parentb5ee7b1bc75928825991957c189ded0b970a1081 (diff)
Restore tower local variables correctly in init_from_checkpoint.
PiperOrigin-RevId: 206208637
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r--tensorflow/python/training/checkpoint_utils.py25
1 files changed, 6 insertions, 19 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 883f4fd910..799729b017 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -311,29 +311,16 @@ def _set_checkpoint_initializer(variable,
# TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
if resource_variable_ops.is_resource_variable(variable):
init_op = variable.assign(restore_op, read_value=False)
+ # TODO(priyag): Remove this when using `SaveableObject.restore` instead.
+ if hasattr(init_op, "_index"):
+ init_op = distribute_lib.get_distribution_strategy().group(init_op)
else:
init_op = state_ops.assign(variable, restore_op)
# pylint:disable=protected-access
- # We need special handling for `DistributedVariable`s as they contain
- # mutliple actual variables. `assign` on a `DistributedVariable` returns a
- # combined `init_op` which contains initializers for all the contained
- # variables. We then set each underlying variable's `_initializer_op` using
- # the corresponding `init_op`.
- # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
- # moves out of contrib.
- if any(base.__name__ == "DistributedVariable"
- for base in variable.__class__.__bases__):
- assert distribute_lib.get_cross_tower_context()
- assert hasattr(variable, "_index")
- for (d, v) in six.iteritems(variable._index):
- v._initializer_op = init_op._index[d]
- restore_op.set_shape(v.shape)
- v._initial_value = restore_op
- else:
- variable._initializer_op = init_op
- restore_op.set_shape(variable.shape)
- variable._initial_value = restore_op
+ variable._initializer_op = init_op
+ restore_op.set_shape(variable.shape)
+ variable._initial_value = restore_op
# pylint:enable=protected-access