diff options
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 3aedeb6acd..9461a01515 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): Furthermore, the final answer should be computed once instead of in every replica/tower. Both of these are accomplished by running the computation of the final result value inside - `tf.contrib.distribute.get_tower_context().merge_call(fn)`. + `tf.contrib.distribution_strategy_context.get_tower_context( + ).merge_call(fn)`. Inside the `merge_call()`, ops are only added to the graph once and access to a tower-local variable in a computation returns the sum across all replicas/towers. @@ -373,7 +374,7 @@ def mean(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -618,7 +619,7 @@ def _aggregate_variable(v, collections): ops.add_to_collections(collections, value) return value - return distribute_lib.get_tower_context().merge_call(f, v) + return distribution_strategy_context.get_tower_context().merge_call(f, v) @tf_export('metrics.auc') @@ -813,7 +814,7 @@ def auc(labels, ops.add_to_collections(metrics_collections, auc_value) return auc_value - auc_value = distribute_lib.get_tower_context().merge_call( + auc_value = distribution_strategy_context.get_tower_context().merge_call( aggregate_auc, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') @@ -1053,8 +1054,8 @@ def mean_per_class_accuracy(labels, ops.add_to_collections(metrics_collections, mean_accuracy_v) return mean_accuracy_v - mean_accuracy_v = distribute_lib.get_tower_context().merge_call( - aggregate_mean_accuracy, count, total) + mean_accuracy_v = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_mean_accuracy, count, total) update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: @@ -1160,7 +1161,7 @@ def mean_iou(labels, ops.add_to_collections(metrics_collections, mean_iou_v) return mean_iou_v - mean_iou_v = distribute_lib.get_tower_context().merge_call( + mean_iou_v = distribution_strategy_context.get_tower_context().merge_call( mean_iou_across_towers, total_cm) if updates_collections: @@ -1376,7 +1377,7 @@ def mean_tensor(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -2008,7 +2009,7 @@ def precision(labels, ops.add_to_collections(metrics_collections, p) return p - p = distribute_lib.get_tower_context().merge_call( + p = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_p) update_op = compute_precision(true_positives_update_op, @@ -2092,7 +2093,7 @@ def precision_at_thresholds(labels, ops.add_to_collections(metrics_collections, prec) return prec - prec = distribute_lib.get_tower_context().merge_call( + prec = distribution_strategy_context.get_tower_context().merge_call( precision_across_towers, values) update_op = compute_precision(update_ops['tp'], update_ops['fp'], @@ -2188,7 +2189,7 @@ def recall(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_n) update_op = compute_recall(true_positives_update_op, @@ -2627,7 +2628,7 @@ def recall_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fn) update = math_ops.div( @@ -2708,7 +2709,7 @@ def recall_at_thresholds(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( recall_across_towers, values) update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') @@ -2783,7 +2784,7 @@ def root_mean_squared_error(labels, ops.add_to_collections(metrics_collections, rmse) return rmse - rmse = distribute_lib.get_tower_context().merge_call( + rmse = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, mse) update_rmse_op = math_ops.sqrt(update_mse_op) @@ -2886,7 +2887,7 @@ def sensitivity_at_specificity(labels, ops.add_to_collections(metrics_collections, sensitivity) return sensitivity - sensitivity = distribute_lib.get_tower_context().merge_call( + sensitivity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_sensitivity_at_specificity( @@ -3162,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels, ops.add_to_collections(metrics_collections, mean_average_precision) return mean_average_precision - mean_average_precision = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total_var, max_var) + mean_average_precision = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_across_towers, total_var, max_var) update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: @@ -3448,7 +3449,7 @@ def precision_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fp) update = math_ops.div( @@ -3687,7 +3688,7 @@ def specificity_at_sensitivity(labels, ops.add_to_collections(metrics_collections, specificity) return specificity - specificity = distribute_lib.get_tower_context().merge_call( + specificity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_specificity_at_sensitivity( |