diff options
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 5eab12c41d..3aedeb6acd 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -73,15 +73,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): A (non-trainable) variable initialized to zero, or if inside a `DistributionStrategy` scope a tower-local variable container. """ - with distribute_lib.get_tower_context().tower_local_var_scope('sum'): - # Note that "tower local" implies trainable=False. - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + # Note that synchronization "ON_READ" implies trainable=False. + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): |