aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/metrics_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r--tensorflow/python/ops/metrics_impl.py41
1 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 3aedeb6acd..9461a01515 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
Furthermore, the final answer should be computed once instead of
in every replica/tower. Both of these are accomplished by
running the computation of the final result value inside
- `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+ `tf.contrib.distribution_strategy_context.get_tower_context(
+ ).merge_call(fn)`.
Inside the `merge_call()`, ops are only added to the graph once
and access to a tower-local variable in a computation returns
the sum across all replicas/towers.
@@ -373,7 +374,7 @@ def mean(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -618,7 +619,7 @@ def _aggregate_variable(v, collections):
ops.add_to_collections(collections, value)
return value
- return distribute_lib.get_tower_context().merge_call(f, v)
+ return distribution_strategy_context.get_tower_context().merge_call(f, v)
@tf_export('metrics.auc')
@@ -813,7 +814,7 @@ def auc(labels,
ops.add_to_collections(metrics_collections, auc_value)
return auc_value
- auc_value = distribute_lib.get_tower_context().merge_call(
+ auc_value = distribution_strategy_context.get_tower_context().merge_call(
aggregate_auc, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1053,8 +1054,8 @@ def mean_per_class_accuracy(labels,
ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
- aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1160,7 +1161,7 @@ def mean_iou(labels,
ops.add_to_collections(metrics_collections, mean_iou_v)
return mean_iou_v
- mean_iou_v = distribute_lib.get_tower_context().merge_call(
+ mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
mean_iou_across_towers, total_cm)
if updates_collections:
@@ -1376,7 +1377,7 @@ def mean_tensor(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -2008,7 +2009,7 @@ def precision(labels,
ops.add_to_collections(metrics_collections, p)
return p
- p = distribute_lib.get_tower_context().merge_call(
+ p = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_p)
update_op = compute_precision(true_positives_update_op,
@@ -2092,7 +2093,7 @@ def precision_at_thresholds(labels,
ops.add_to_collections(metrics_collections, prec)
return prec
- prec = distribute_lib.get_tower_context().merge_call(
+ prec = distribution_strategy_context.get_tower_context().merge_call(
precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
@@ -2188,7 +2189,7 @@ def recall(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
@@ -2627,7 +2628,7 @@ def recall_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fn)
update = math_ops.div(
@@ -2708,7 +2709,7 @@ def recall_at_thresholds(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
@@ -2783,7 +2784,7 @@ def root_mean_squared_error(labels,
ops.add_to_collections(metrics_collections, rmse)
return rmse
- rmse = distribute_lib.get_tower_context().merge_call(
+ rmse = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2886,7 +2887,7 @@ def sensitivity_at_specificity(labels,
ops.add_to_collections(metrics_collections, sensitivity)
return sensitivity
- sensitivity = distribute_lib.get_tower_context().merge_call(
+ sensitivity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_sensitivity_at_specificity(
@@ -3162,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, mean_average_precision)
return mean_average_precision
- mean_average_precision = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total_var, max_var)
+ mean_average_precision = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3448,7 +3449,7 @@ def precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fp)
update = math_ops.div(
@@ -3687,7 +3688,7 @@ def specificity_at_sensitivity(labels,
ops.add_to_collections(metrics_collections, specificity)
return specificity
- specificity = distribute_lib.get_tower_context().merge_call(
+ specificity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_specificity_at_sensitivity(