aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/metrics_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r--tensorflow/python/ops/metrics_impl.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 5eab12c41d..3aedeb6acd 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -73,15 +73,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
A (non-trainable) variable initialized to zero, or if inside a
`DistributionStrategy` scope a tower-local variable container.
"""
- with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
- # Note that "tower local" implies trainable=False.
- return variable_scope.variable(
- lambda: array_ops.zeros(shape, dtype),
- collections=[
- ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
- ],
- validate_shape=validate_shape,
- name=name)
+ # Note that synchronization "ON_READ" implies trainable=False.
+ return variable_scope.variable(
+ lambda: array_ops.zeros(shape, dtype),
+ collections=[
+ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+ ],
+ validate_shape=validate_shape,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM,
+ name=name)
def _remove_squeezable_dimensions(predictions, labels, weights):