From c1093a3757224257fed0f7a1959d0fc99d5c757f Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Tue, 9 Oct 2018 15:02:51 -0700
Subject: In TPUMirroredVariable, when setting _initializer_op and
 _initial_value attributes, set the attributes of all the contained variables.

This fixes a bug where tf.train.init_from_checkpoint did not overwrite the
initialization values correctly for TPUMirroredVariable.

PiperOrigin-RevId: 216429476
---
 tensorflow/contrib/distribute/python/values.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 0dd78ba185..472cb4230c 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -475,6 +475,11 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
     self._aggregation = aggregation
     # Needed for GradientTape
     self._trainable = self._primary_var.trainable
+    # Typically, like `DistributedVariable`, a `TPUMirroredVariable`'s
+    # initializer is composed of the initializers of the component variables.
+    # However, in some cases, such as when restoring from a checkpoint, we may
+    # set the _initializer_op property on the entire `TPUMirroredVariable`.
+    self._initializer_op = None
 
   def _get(self, device=None):
     """Returns the value for the current device or raises a ValueError."""
@@ -704,8 +709,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
 
   @property
   def initializer(self):
-    return control_flow_ops.group(
-        [v.initializer for v in nest.flatten(self._index)])
+    if self._initializer_op:
+      init_op = self._initializer_op
+    else:
+      init_op = control_flow_ops.group(
+          [v.initializer for v in self._index.values()])
+    return init_op
 
   @property
   def graph(self):
-- 
cgit v1.2.3
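
To illustrate the fallback the patch adds to the `initializer` property, here is a
minimal, hypothetical Python sketch. The class and helpers below (`MirroredVarSketch`,
`_FakeVar`, `_group`) are stand-ins invented for this example, not the real TensorFlow
objects; they only mimic the control flow: when `_initializer_op` has been set on the
whole variable (as a checkpoint restore such as tf.train.init_from_checkpoint would do),
it takes precedence, otherwise the initializer is the group of the component variables'
initializers.

# Hypothetical, simplified stand-ins -- not the real TensorFlow classes.
class _FakeVar(object):
  """A per-device variable stub that only exposes an `initializer` op."""

  def __init__(self, name):
    self.initializer = "init_op:%s" % name


def _group(ops):
  """Stand-in for control_flow_ops.group: bundles ops into a single value."""
  return tuple(ops)


class MirroredVarSketch(object):
  """Sketch of the initializer fallback introduced by this change."""

  def __init__(self, index):
    self._index = index          # maps device string -> component variable
    self._initializer_op = None  # may be set later, e.g. on checkpoint restore

  @property
  def initializer(self):
    # If an initializer op was set on the whole mirrored variable
    # (e.g. by checkpoint restoration), prefer it; otherwise fall back
    # to grouping the component variables' initializers.
    if self._initializer_op:
      return self._initializer_op
    return _group(v.initializer for v in self._index.values())


if __name__ == "__main__":
  v = MirroredVarSketch({"/device:TPU:0": _FakeVar("tpu0"),
                         "/device:TPU:1": _FakeVar("tpu1")})
  print(v.initializer)            # grouped component initializers
  v._initializer_op = "restore_op_from_checkpoint"
  print(v.initializer)            # the explicitly set op wins

Before the patch, the property always regrouped the component initializers, so an
initializer op installed on the mirrored variable itself was silently ignored during
checkpoint-based initialization.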