| author | Ruoxin Sang <rxsang@google.com> | 2018-10-09 15:02:51 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 15:06:41 -0700 |
| commit | c1093a3757224257fed0f7a1959d0fc99d5c757f (patch) | |
| tree | 96c32fa021a1ee16c29173abb382404cc872bfda | |
| parent | a6fcb9d3d81e9207650eda1c899051ccbb97dec7 (diff) | |
In TPUMirroredVariable, when setting the _initializer_op and _initial_value attributes, set them on all of the contained component variables. This fixes a bug where tf.train.init_from_checkpoint did not correctly overwrite the initialization values of a TPUMirroredVariable.
PiperOrigin-RevId: 216429476
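The shape of the fix: `TPUMirroredVariable.initializer` now prefers an explicitly installed `_initializer_op` (such as the restore op that `tf.train.init_from_checkpoint` installs) and only falls back to grouping the per-device component initializers. Below is a minimal pure-Python sketch of that fallback pattern; `FakeVariable`, `group()`, and the device keys are illustrative stand-ins, not TensorFlow APIs.

```python
# Pure-Python sketch of the fallback this commit adds to
# TPUMirroredVariable.initializer. FakeVariable, group(), and the device
# keys are illustrative stand-ins, not TensorFlow APIs.

def group(ops):
  """Stand-in for control_flow_ops.group: bundles several ops into one."""
  return "group(%s)" % ", ".join(ops)

class FakeVariable(object):
  def __init__(self, name):
    self.initializer = "init(%s)" % name

class MirroredValue(object):
  def __init__(self, index):
    self._index = index          # maps device name -> component variable
    self._initializer_op = None  # may be set later, e.g. on checkpoint restore

  @property
  def initializer(self):
    # Prefer an explicitly installed initializer op; otherwise fall back to
    # grouping the per-device component initializers.
    if self._initializer_op:
      return self._initializer_op
    return group(v.initializer for v in self._index.values())

v = MirroredValue({"/device:TPU:0": FakeVariable("a"),
                   "/device:TPU:1": FakeVariable("b")})
print(v.initializer)                     # group(init(a), init(b))
v._initializer_op = "restore_from_ckpt"  # what init_from_checkpoint would install
print(v.initializer)                     # restore_from_ckpt
```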
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 13 |
1 file changed, 11 insertions(+), 2 deletions(-)
```diff
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 0dd78ba185..472cb4230c 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -475,6 +475,11 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
     self._aggregation = aggregation
     # Needed for GradientTape
     self._trainable = self._primary_var.trainable
+    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
+    # initializer is composed of the initializers of the components variables.
+    # However, in some cases, such as when restoring from a checkpoint, we may
+    # set the _initializer_op property on the entire `TPUMirroredVariable`.
+    self._initializer_op = None
 
   def _get(self, device=None):
     """Returns the value for the current device or raises a ValueError."""
@@ -704,8 +709,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
 
   @property
   def initializer(self):
-    return control_flow_ops.group(
-        [v.initializer for v in nest.flatten(self._index)])
+    if self._initializer_op:
+      init_op = self._initializer_op
+    else:
+      init_op = control_flow_ops.group(
+          [v.initializer for v in self._index.values()])
+    return init_op
 
   @property
   def graph(self):
```
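For context, here is a minimal single-device sketch of the behavior this commit makes work for TPUMirroredVariable: tf.train.init_from_checkpoint replaces a variable's initializer so that running it loads the checkpoint value instead of the default initial value. The sketch assumes a TF 1.x runtime (matching this contrib-era code); the paths and variable names are illustrative.

```python
import os
import tempfile

import tensorflow as tf  # sketch assumes a TF 1.x runtime, as in this contrib-era code

ckpt_dir = tempfile.mkdtemp()

# Write a checkpoint in which `w` holds a known, non-default value.
with tf.Graph().as_default():
  w = tf.get_variable("w", initializer=tf.constant([1.0, 2.0]))
  with tf.Session() as sess:
    sess.run(w.initializer)
    tf.train.Saver().save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# In a fresh graph, init_from_checkpoint rewires `w`'s initializer so that
# running it restores the checkpoint value rather than the zeros default.
with tf.Graph().as_default():
  w = tf.get_variable("w", initializer=tf.constant([0.0, 0.0]))
  tf.train.init_from_checkpoint(ckpt_dir, {"w": w})
  with tf.Session() as sess:
    sess.run(w.initializer)  # runs the checkpoint-restore op
    print(sess.run(w))       # -> [1. 2.]
```

Per the commit message, before this change the analogous override was not reflected correctly for a TPUMirroredVariable: its initializer property kept returning only the grouped per-device initializers.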