diff options
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 0dd78ba185..472cb4230c 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -475,6 +475,11 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): self._aggregation = aggregation # Needed for GradientTape self._trainable = self._primary_var.trainable + # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s + # initializer is composed of the initializers of the components variables. + # However, in some cases, such as when restoring from a checkpoint, we may + # set the _initializer_op property on the entire `TPUMirroredVariable`. + self._initializer_op = None def _get(self, device=None): """Returns the value for the current device or raises a ValueError.""" @@ -704,8 +709,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): @property def initializer(self): - return control_flow_ops.group( - [v.initializer for v in nest.flatten(self._index)]) + if self._initializer_op: + init_op = self._initializer_op + else: + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): |