aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distribute/python/values.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 0dd78ba185..472cb4230c 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -475,6 +475,11 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
self._aggregation = aggregation
# Needed for GradientTape
self._trainable = self._primary_var.trainable
+ # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
+ # initializer is composed of the initializers of the components variables.
+ # However, in some cases, such as when restoring from a checkpoint, we may
+ # set the _initializer_op property on the entire `TPUMirroredVariable`.
+ self._initializer_op = None
def _get(self, device=None):
"""Returns the value for the current device or raises a ValueError."""
@@ -704,8 +709,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
@property
def initializer(self):
- return control_flow_ops.group(
- [v.initializer for v in nest.flatten(self._index)])
+ if self._initializer_op:
+ init_op = self._initializer_op
+ else:
+ init_op = control_flow_ops.group(
+ [v.initializer for v in self._index.values()])
+ return init_op
@property
def graph(self):