diff options
author | 2018-08-16 15:50:56 -0700 | |
---|---|---|
committer | 2018-08-16 15:54:47 -0700 | |
commit | 51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch) | |
tree | ef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow | |
parent | 985cbd7ff220795abc4a50839144c177924d469c (diff) |
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow')
5 files changed, 121 insertions, 113 deletions
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 2f3d6bdd3f..8163494c8e 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -68,6 +68,8 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() +# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using +# TowerLocalVariables on TPUs. Submit http://cl/208914352. def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 72d1c6b7dd..edd5c6d17a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. 
+ self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + cond = lambda i, *args: i < iterations i = constant_op.constant(0) loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 86833ad851..68561b5bbf 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -88,13 +88,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + # TODO(priyag): Use max_iterations instead of an explicit counter. cond = lambda i, *args: i < iterations i = constant_op.constant(0) - # TODO(priyag): Use max_iterations instead of an explicit counter. 
loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 50b5555ba5..77fc56de36 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -159,8 +159,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def iterate_on_tpu(): return training_loop.repeat(iterations, run_fn, initial_loop_values) + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop and TPU replicate context. This is useful in cases + # where we might need to exit these contexts and get back to the outer + # context to do some things, for e.g. create an op which should be + # evaluated only once at the end of the loop on the host. One such usage + # is in creating metrics' value op. 
+ self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + replicate_inputs = [[]] * self.num_towers replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) # Filter out any ops from the outputs, typically this would be the case diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 9461a01515..763877c2d2 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -301,6 +301,40 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None): return total_cm, update_op +def _aggregate_across_towers(metrics_collections, metric_value_fn, *args): + """Aggregate metric value across towers.""" + def fn(distribution, *a): + """Call `metric_value_fn` in the correct control flow context.""" + if hasattr(distribution, '_outer_control_flow_context'): + # If there was an outer context captured before this method was called, + # then we enter that context to create the metric value op. If the + # captured context is `None`, ops.control_dependencies(None) gives the + # desired behavior. Else we use `Enter` and `Exit` to enter and exit the + # captured context. + # This special handling is needed because sometimes the metric is created + # inside a while_loop (and perhaps a TPU rewrite context). But we don't + # want the value op to be evaluated every step or on the TPU. So we + # create it outside so that it can be evaluated at the end on the host, + # once the update ops have been evaluated. 
+ + # pylint: disable=protected-access + if distribution._outer_control_flow_context is None: + with ops.control_dependencies(None): + metric_value = metric_value_fn(distribution, *a) + else: + distribution._outer_control_flow_context.Enter() + metric_value = metric_value_fn(distribution, *a) + distribution._outer_control_flow_context.Exit() + # pylint: enable=protected-access + else: + metric_value = metric_value_fn(distribution, *a) + if metrics_collections: + ops.add_to_collections(metrics_collections, metric_value) + return metric_value + + return distribution_strategy_context.get_tower_context().merge_call(fn, *args) + + @tf_export('metrics.mean') def mean(values, weights=None, @@ -368,14 +402,10 @@ def mean(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - def aggregate_across_towers(_, t, c): - mean_t = _safe_div(t, c, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) - return mean_t + compute_mean = lambda _, t, c: _safe_div(t, c, 'value') - mean_t = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, total, count) + mean_t = _aggregate_across_towers( + metrics_collections, compute_mean, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: @@ -612,14 +642,8 @@ def _confusion_matrix_at_thresholds(labels, def _aggregate_variable(v, collections): - - def f(distribution, value): - value = distribution.read_var(value) - if collections: - ops.add_to_collections(collections, value) - return value - - return distribution_strategy_context.get_tower_context().merge_call(f, v) + f = lambda distribution, value: distribution.read_var(value) + return _aggregate_across_towers(collections, f, v) @tf_export('metrics.auc') @@ -807,15 +831,12 @@ def auc(labels, raise ValueError('Invalid summation_method: %s' % summation_method) # sum up the areas of all the trapeziums - def aggregate_auc(_, 
values): - auc_value = compute_auc(values['tp'], values['fn'], values['tn'], - values['fp'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, auc_value) - return auc_value - - auc_value = distribution_strategy_context.get_tower_context().merge_call( - aggregate_auc, values) + def compute_auc_value(_, values): + return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'], + 'value') + + auc_value = _aggregate_across_towers( + metrics_collections, compute_auc_value, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') @@ -1046,16 +1067,14 @@ def mean_per_class_accuracy(labels, update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) - def aggregate_mean_accuracy(_, count, total): + def compute_mean_accuracy(_, count, total): per_class_accuracy = _safe_div(count, total, None) mean_accuracy_v = math_ops.reduce_mean( per_class_accuracy, name='mean_accuracy') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_accuracy_v) return mean_accuracy_v - mean_accuracy_v = distribution_strategy_context.get_tower_context( - ).merge_call(aggregate_mean_accuracy, count, total) + mean_accuracy_v = _aggregate_across_towers( + metrics_collections, compute_mean_accuracy, count, total) update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: @@ -1128,7 +1147,7 @@ def mean_iou(labels, total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes, weights) - def compute_mean_iou(total_cm, name): + def compute_mean_iou(_, total_cm): """Compute the mean intersection-over-union via the confusion matrix.""" sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) @@ -1152,17 +1171,12 @@ def mean_iou(labels, # If the number of valid entries is 0 (no 
classes) we return 0. result = array_ops.where( math_ops.greater(num_valid_entries, 0), - math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0) + math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0) return result - def mean_iou_across_towers(_, v): - mean_iou_v = compute_mean_iou(v, 'mean_iou') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_iou_v) - return mean_iou_v - - mean_iou_v = distribution_strategy_context.get_tower_context().merge_call( - mean_iou_across_towers, total_cm) + # TODO(priyag): Use outside_compilation if in TPU context. + mean_iou_v = _aggregate_across_towers( + metrics_collections, compute_mean_iou, total_cm) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1371,14 +1385,10 @@ def mean_tensor(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - def aggregate_across_towers(_, t, c): - mean_t = _safe_div(t, c, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) - return mean_t + compute_mean = lambda _, t, c: _safe_div(t, c, 'value') - mean_t = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, total, count) + mean_t = _aggregate_across_towers( + metrics_collections, compute_mean, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: @@ -2004,13 +2014,10 @@ def precision(labels, math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) def once_across_towers(_, true_p, false_p): - p = compute_precision(true_p, false_p, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, p) - return p + return compute_precision(true_p, false_p, 'value') - p = distribution_strategy_context.get_tower_context().merge_call( - once_across_towers, true_p, false_p) + p = _aggregate_across_towers(metrics_collections, once_across_towers, + true_p, false_p) update_op = 
compute_precision(true_positives_update_op, false_positives_update_op, 'update_op') @@ -2088,13 +2095,10 @@ def precision_at_thresholds(labels, return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) def precision_across_towers(_, values): - prec = compute_precision(values['tp'], values['fp'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, prec) - return prec + return compute_precision(values['tp'], values['fp'], 'value') - prec = distribution_strategy_context.get_tower_context().merge_call( - precision_across_towers, values) + prec = _aggregate_across_towers( + metrics_collections, precision_across_towers, values) update_op = compute_precision(update_ops['tp'], update_ops['fp'], 'update_op') @@ -2184,13 +2188,10 @@ def recall(labels, math_ops.div(true_p, true_p + false_n), 0, name) def once_across_towers(_, true_p, false_n): - rec = compute_recall(true_p, false_n, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) - return rec + return compute_recall(true_p, false_n, 'value') - rec = distribution_strategy_context.get_tower_context().merge_call( - once_across_towers, true_p, false_n) + rec = _aggregate_across_towers( + metrics_collections, once_across_towers, true_p, false_n) update_op = compute_recall(true_positives_update_op, false_negatives_update_op, 'update_op') @@ -2622,14 +2623,11 @@ def recall_at_top_k(labels, class_id=class_id, weights=weights) - def aggregate_across_towers(_, tp, fn): - metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) - return metric + def compute_recall(_, tp, fn): + return math_ops.div(tp, math_ops.add(tp, fn), name=scope) - metric = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, tp, fn) + metric = _aggregate_across_towers( + metrics_collections, compute_recall, tp, fn) update = math_ops.div( tp_update, 
math_ops.add(tp_update, fn_update), name='update') @@ -2704,13 +2702,10 @@ def recall_at_thresholds(labels, return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) def recall_across_towers(_, values): - rec = compute_recall(values['tp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) - return rec + return compute_recall(values['tp'], values['fn'], 'value') - rec = distribution_strategy_context.get_tower_context().merge_call( - recall_across_towers, values) + rec = _aggregate_across_towers( + metrics_collections, recall_across_towers, values) update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') if updates_collections: @@ -2778,14 +2773,9 @@ def root_mean_squared_error(labels, mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, None, name or 'root_mean_squared_error') - def once_across_towers(_, mse): - rmse = math_ops.sqrt(mse) - if metrics_collections: - ops.add_to_collections(metrics_collections, rmse) - return rmse - rmse = distribution_strategy_context.get_tower_context().merge_call( - once_across_towers, mse) + once_across_towers = lambda _, mse: math_ops.sqrt(mse) + rmse = _aggregate_across_towers(metrics_collections, once_across_towers, mse) update_rmse_op = math_ops.sqrt(update_mse_op) if updates_collections: @@ -2880,15 +2870,12 @@ def sensitivity_at_specificity(labels, return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, name) - def aggregate_across_towers(_, values): - sensitivity = compute_sensitivity_at_specificity( + def sensitivity_across_towers(_, values): + return compute_sensitivity_at_specificity( values['tp'], values['tn'], values['fp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, sensitivity) - return sensitivity - sensitivity = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, values) + sensitivity = _aggregate_across_towers( + 
metrics_collections, sensitivity_across_towers, values) update_op = compute_sensitivity_at_specificity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], @@ -3157,14 +3144,11 @@ def _streaming_sparse_average_precision_at_top_k(labels, total_update = state_ops.assign_add(total_var, batch_total, name='update') # Divide total by max to get mean, for both vars and the update ops. - def aggregate_across_towers(_, total_var, max_var): - mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_average_precision) - return mean_average_precision + def precision_across_towers(_, total_var, max_var): + return _safe_scalar_div(total_var, max_var, name='mean') - mean_average_precision = distribution_strategy_context.get_tower_context( - ).merge_call(aggregate_across_towers, total_var, max_var) + mean_average_precision = _aggregate_across_towers( + metrics_collections, precision_across_towers, total_var, max_var) update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: @@ -3443,14 +3427,11 @@ def precision_at_top_k(labels, class_id=class_id, weights=weights) - def aggregate_across_towers(_, tp, fp): - metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) - return metric + def precision_across_towers(_, tp, fp): + return math_ops.div(tp, math_ops.add(tp, fp), name=scope) - metric = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, tp, fp) + metric = _aggregate_across_towers( + metrics_collections, precision_across_towers, tp, fp) update = math_ops.div( tp_update, math_ops.add(tp_update, fp_update), name='update') @@ -3681,15 +3662,12 @@ def specificity_at_sensitivity(labels, return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, name) - def aggregate_across_towers(_, values): - specificity = 
compute_specificity_at_sensitivity( + def specificity_across_towers(_, values): + return compute_specificity_at_sensitivity( values['tp'], values['tn'], values['fp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, specificity) - return specificity - specificity = distribution_strategy_context.get_tower_context().merge_call( - aggregate_across_towers, values) + specificity = _aggregate_across_towers( + metrics_collections, specificity_across_towers, values) update_op = compute_specificity_at_sensitivity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], |