aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-08-16 15:50:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 15:54:47 -0700
commit51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch)
treeef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow
parent985cbd7ff220795abc4a50839144c177924d469c (diff)
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py9
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py11
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py10
-rw-r--r--tensorflow/python/ops/metrics_impl.py202
5 files changed, 121 insertions, 113 deletions
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 2f3d6bdd3f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -68,6 +68,8 @@ def _regression_dataset_fn():
"predictions": [1., .75, .25, 0.]}).repeat()
+# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using
+# TowerLocalVariables on TPUs. Submit http://cl/208914352.
def all_combinations():
return combinations.combine(
distribution=[combinations.default_strategy,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 72d1c6b7dd..edd5c6d17a 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 86833ad851..68561b5bbf 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -88,13 +88,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
+ # TODO(priyag): Use max_iterations instead of an explicit counter.
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
- # TODO(priyag): Use max_iterations instead of an explicit counter.
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 50b5555ba5..77fc56de36 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -159,8 +159,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def iterate_on_tpu():
return training_loop.repeat(iterations, run_fn, initial_loop_values)
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop and TPU replicate context. This is useful in cases
+ # where we might need to exit these contexts and get back to the outer
+ # context to do some things, for e.g. create an op which should be
+ # evaluated only once at the end of the loop on the host. One such usage
+ # is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
replicate_inputs = [[]] * self.num_towers
replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
# Filter out any ops from the outputs, typically this would be the case
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9461a01515..763877c2d2 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -301,6 +301,40 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
return total_cm, update_op
+def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
+ """Aggregate metric value across towers."""
+ def fn(distribution, *a):
+ """Call `metric_value_fn` in the correct control flow context."""
+ if hasattr(distribution, '_outer_control_flow_context'):
+ # If there was an outer context captured before this method was called,
+ # then we enter that context to create the metric value op. If the
+ # caputred context is `None`, ops.control_dependencies(None) gives the
+ # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
+ # captured context.
+ # This special handling is needed because sometimes the metric is created
+ # inside a while_loop (and perhaps a TPU rewrite context). But we don't
+ # want the value op to be evaluated every step or on the TPU. So we
+ # create it outside so that it can be evaluated at the end on the host,
+ # once the update ops have been evaluted.
+
+ # pylint: disable=protected-access
+ if distribution._outer_control_flow_context is None:
+ with ops.control_dependencies(None):
+ metric_value = metric_value_fn(distribution, *a)
+ else:
+ distribution._outer_control_flow_context.Enter()
+ metric_value = metric_value_fn(distribution, *a)
+ distribution._outer_control_flow_context.Exit()
+ # pylint: enable=protected-access
+ else:
+ metric_value = metric_value_fn(distribution, *a)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric_value)
+ return metric_value
+
+ return distribution_strategy_context.get_tower_context().merge_call(fn, *args)
+
+
@tf_export('metrics.mean')
def mean(values,
weights=None,
@@ -368,14 +402,10 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
@@ -612,14 +642,8 @@ def _confusion_matrix_at_thresholds(labels,
def _aggregate_variable(v, collections):
-
- def f(distribution, value):
- value = distribution.read_var(value)
- if collections:
- ops.add_to_collections(collections, value)
- return value
-
- return distribution_strategy_context.get_tower_context().merge_call(f, v)
+ f = lambda distribution, value: distribution.read_var(value)
+ return _aggregate_across_towers(collections, f, v)
@tf_export('metrics.auc')
@@ -807,15 +831,12 @@ def auc(labels,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- def aggregate_auc(_, values):
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
- return auc_value
-
- auc_value = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_auc, values)
+ def compute_auc_value(_, values):
+ return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
+ 'value')
+
+ auc_value = _aggregate_across_towers(
+ metrics_collections, compute_auc_value, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1046,16 +1067,14 @@ def mean_per_class_accuracy(labels,
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- def aggregate_mean_accuracy(_, count, total):
+ def compute_mean_accuracy(_, count, total):
per_class_accuracy = _safe_div(count, total, None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1128,7 +1147,7 @@ def mean_iou(labels,
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(total_cm, name):
+ def compute_mean_iou(_, total_cm):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1152,17 +1171,12 @@ def mean_iou(labels,
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
math_ops.greater(num_valid_entries, 0),
- math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
+ math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
return result
- def mean_iou_across_towers(_, v):
- mean_iou_v = compute_mean_iou(v, 'mean_iou')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
- return mean_iou_v
-
- mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
- mean_iou_across_towers, total_cm)
+ # TODO(priyag): Use outside_compilation if in TPU context.
+ mean_iou_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_iou, total_cm)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1371,14 +1385,10 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
@@ -2004,13 +2014,10 @@ def precision(labels,
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
def once_across_towers(_, true_p, false_p):
- p = compute_precision(true_p, false_p, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
- return p
+ return compute_precision(true_p, false_p, 'value')
- p = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_p)
+ p = _aggregate_across_towers(metrics_collections, once_across_towers,
+ true_p, false_p)
update_op = compute_precision(true_positives_update_op,
false_positives_update_op, 'update_op')
@@ -2088,13 +2095,10 @@ def precision_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
def precision_across_towers(_, values):
- prec = compute_precision(values['tp'], values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
- return prec
+ return compute_precision(values['tp'], values['fp'], 'value')
- prec = distribution_strategy_context.get_tower_context().merge_call(
- precision_across_towers, values)
+ prec = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
'update_op')
@@ -2184,13 +2188,10 @@ def recall(labels,
math_ops.div(true_p, true_p + false_n), 0, name)
def once_across_towers(_, true_p, false_n):
- rec = compute_recall(true_p, false_n, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(true_p, false_n, 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_n)
+ rec = _aggregate_across_towers(
+ metrics_collections, once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
false_negatives_update_op, 'update_op')
@@ -2622,14 +2623,11 @@ def recall_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fn):
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def compute_recall(_, tp, fn):
+ return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fn)
+ metric = _aggregate_across_towers(
+ metrics_collections, compute_recall, tp, fn)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
@@ -2704,13 +2702,10 @@ def recall_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
def recall_across_towers(_, values):
- rec = compute_recall(values['tp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(values['tp'], values['fn'], 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- recall_across_towers, values)
+ rec = _aggregate_across_towers(
+ metrics_collections, recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
@@ -2778,14 +2773,9 @@ def root_mean_squared_error(labels,
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
- def once_across_towers(_, mse):
- rmse = math_ops.sqrt(mse)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
- return rmse
- rmse = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, mse)
+ once_across_towers = lambda _, mse: math_ops.sqrt(mse)
+ rmse = _aggregate_across_towers(metrics_collections, once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
if updates_collections:
@@ -2880,15 +2870,12 @@ def sensitivity_at_specificity(labels,
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- sensitivity = compute_sensitivity_at_specificity(
+ def sensitivity_across_towers(_, values):
+ return compute_sensitivity_at_specificity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
- return sensitivity
- sensitivity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ sensitivity = _aggregate_across_towers(
+ metrics_collections, sensitivity_across_towers, values)
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
@@ -3157,14 +3144,11 @@ def _streaming_sparse_average_precision_at_top_k(labels,
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- def aggregate_across_towers(_, total_var, max_var):
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
- return mean_average_precision
+ def precision_across_towers(_, total_var, max_var):
+ return _safe_scalar_div(total_var, max_var, name='mean')
- mean_average_precision = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_across_towers, total_var, max_var)
+ mean_average_precision = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3443,14 +3427,11 @@ def precision_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fp):
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def precision_across_towers(_, tp, fp):
+ return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fp)
+ metric = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, tp, fp)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
@@ -3681,15 +3662,12 @@ def specificity_at_sensitivity(labels,
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- specificity = compute_specificity_at_sensitivity(
+ def specificity_across_towers(_, values):
+ return compute_specificity_at_sensitivity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
- return specificity
- specificity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ specificity = _aggregate_across_towers(
+ metrics_collections, specificity_across_towers, values)
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],