diff options
Diffstat (limited to 'tensorflow/python/training/moving_averages.py')
-rw-r--r-- | tensorflow/python/training/moving_averages.py | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 4b91d1e963..177a7ddfa5 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -363,10 +363,12 @@ class ExponentialMovingAverage(object): `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to `tf.global_variables()`. - Returns an op that updates all shadow variables as described above. + Returns an op that updates all shadow variables from the current value of + their associated variables. - Note that `apply()` can be called multiple times with different lists of - variables. + Note that `apply()` can be called multiple times. When eager execution is + enabled each call to apply will update the variables once, so this needs to + be called in a loop. Args: var_list: A list of Variable or Tensor objects. The variables @@ -389,31 +391,30 @@ class ExponentialMovingAverage(object): dtypes.float64]: raise TypeError("The variables must be half, float, or double: %s" % var.name) - if var in self._averages: - raise ValueError("Moving average already computed for: %s" % var.name) - # For variables: to lower communication bandwidth across devices we keep - # the moving averages on the same device as the variables. For other - # tensors, we rely on the existing device allocation mechanism. - with ops.init_scope(): - if isinstance(var, variables.Variable): - avg = slot_creator.create_slot(var, - var.initialized_value(), - self.name, - colocate_with_primary=True) - # NOTE(mrry): We only add `tf.Variable` objects to the - # `MOVING_AVERAGE_VARIABLES` collection. - ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var) - else: - avg = slot_creator.create_zeros_slot( - var, - self.name, - colocate_with_primary=(var.op.type in ["Variable", - "VariableV2", - "VarHandleOp"])) - if self._zero_debias: - zero_debias_true.add(avg) - self._averages[var] = avg + if var not in self._averages: + # For variables: to lower communication bandwidth across devices we keep + # the moving averages on the same device as the variables. For other + # tensors, we rely on the existing device allocation mechanism. + with ops.init_scope(): + if isinstance(var, variables.Variable): + avg = slot_creator.create_slot(var, + var.initialized_value(), + self.name, + colocate_with_primary=True) + # NOTE(mrry): We only add `tf.Variable` objects to the + # `MOVING_AVERAGE_VARIABLES` collection. + ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var) + else: + avg = slot_creator.create_zeros_slot( + var, + self.name, + colocate_with_primary=(var.op.type in ["Variable", + "VariableV2", + "VarHandleOp"])) + if self._zero_debias: + zero_debias_true.add(avg) + self._averages[var] = avg with ops.name_scope(self.name) as scope: decay = ops.convert_to_tensor(self._decay, name="decay") |