diff options
author | 2018-05-18 16:33:19 -0700 | |
---|---|---|
committer | 2018-05-18 16:36:57 -0700 | |
commit | 40f53c774e914b9166a5bc8476e290da4a121c82 (patch) | |
tree | c21c2af99a3e0cf7aedf5615ffb0e50d5085ac04 | |
parent | f4cb5978667ccf6396e4a779e3a482766959e5dd (diff) |
Automated g4 rollback of changelist 197070234
PiperOrigin-RevId: 197218170
-rw-r--r-- | tensorflow/contrib/distribute/python/BUILD | 19 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/python/metrics_v1_test.py | 438 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 296 |
5 files changed, 102 insertions, 660 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index aeeaa0b400..64a77bbed1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -547,22 +547,3 @@ cuda_py_test( "no_pip", ], ) - -cuda_py_test( - name = "metrics_v1_test", - srcs = ["metrics_v1_test.py"], - additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py deleted file mode 100644 index 6c6bf14309..0000000000 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for V1 metrics.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.distribute.python import combinations -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import test -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import metrics -from tensorflow.python.ops import variables - - -def _labeled_dataset_fn(): - # First four batches of x: labels, predictions -> (labels == predictions) - # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False - # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False - # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False - # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True - return dataset_ops.Dataset.range(1000).map( - lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) - - -def _boolean_dataset_fn(): - # First four batches of labels, predictions: {TP, FP, TN, FN} - # with a threshold of 0.5: - # T, T -> TP; F, T -> FP; T, F -> FN - # F, F -> TN; T, T -> TP; F, T -> FP - # T, F -> FN; F, F -> TN; T, T -> TP - # F, T -> FP; T, F -> FN; F, F -> TN - return dataset_ops.Dataset.from_tensor_slices({ - "labels": [True, False, True, False], - "predictions": [True, True, False, False]}).repeat().batch(3) - - -def _threshold_dataset_fn(): - # First four batches of labels, predictions: {TP, FP, TN, FN} - # with a threshold of 0.5: - # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN - # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP - # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP - # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN - return dataset_ops.Dataset.from_tensor_slices({ - "labels": [True, False, True, False], - "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) - - -def _regression_dataset_fn(): - return dataset_ops.Dataset.from_tensor_slices({ - "labels": [1., .5, 1., 0.], - "predictions": [1., .75, .25, 0.]}).repeat() - - -def all_combinations(): - return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - mode=["graph"]) - - -# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, -# metrics.precision_at_k -class MetricsV1Test(test.TestCase, parameterized.TestCase): - - def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): - with ops.Graph().as_default(), distribution.scope(): - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - value, update = distribution.call_for_each_tower( - metric_fn, iterator.get_next()) - update = distribution.group(update) - self.evaluate(variables.local_variables_initializer()) - # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_towers" with "1". - batches_per_update = distribution.num_towers - - # Update variables using the first `num_towers` batches. - self.evaluate(update) - self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), - 0.001, msg="After first update") - - # Update variables using the second `num_towers` batches. - self.evaluate(update) - self.assertAllClose(expected_fn(2 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After second update") - - if batches_per_update == 1: # Consume 4 input batches - self.evaluate(update) - self.assertAllClose(expected_fn(3 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After third update") - self.evaluate(update) - self.assertAllClose(expected_fn(4 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After fourth update") - - @combinations.generate(all_combinations()) - def testMean(self, distribution): - def _dataset_fn(): - return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) - - def _expected_fn(num_batches): - # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. - return num_batches * 2 - 0.5 - - self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) - - @combinations.generate(all_combinations()) - def testAccuracy(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.accuracy(labels, predictions) - - def _expected_fn(num_batches): - return [3./4, 3./8, 3./12, 4./16][num_batches - 1] - - self._test_metric( - distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testMeanPerClassAccuracy(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.mean_per_class_accuracy( - labels, predictions, num_classes=5) - - def _expected_fn(num_batches): - mean = lambda x: sum(x) / len(x) - return [mean([1., 1., 1., 0., 0.]), - mean([0.5, 0.5, 0.5, 0., 0.]), - mean([1./3, 1./3, 0.5, 0., 0.]), - mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] - - self._test_metric( - distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testMeanIOU(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.mean_iou( - labels, predictions, num_classes=5) - - def _expected_fn(num_batches): - mean = lambda x: sum(x) / len(x) - return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch - mean([1./4, 1./4, 1./3, 0., 0.]), - mean([1./6, 1./6, 1./5, 0., 0.]), - mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] - - self._test_metric( - distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testMeanTensor(self, distribution): - def _dataset_fn(): - dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) - # Want to produce a fixed, known shape, so drop remainder when batching. - dataset = dataset.apply(batching.batch_and_drop_remainder(4)) - return dataset - - def _expected_fn(num_batches): - # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 - # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 - # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches - # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 - first = 2. * num_batches - 2. - return [first, first + 1., first + 2., first + 3.] - - self._test_metric( - distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) - - @combinations.generate(all_combinations()) - def testAUCROC(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", - summation_method="careful_interpolation") - - def _expected_fn(num_batches): - return [0.5, 7./9, 0.8, 0.75][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testAUCPR(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", - summation_method="careful_interpolation") - - def _expected_fn(num_batches): - return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testFalseNegatives(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.false_negatives(labels, predictions) - - def _expected_fn(num_batches): - return [1., 1., 2., 3.][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testFalseNegativesAtThresholds(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) - - def _expected_fn(num_batches): - return [[1.], [1.], [2.], [3.]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testTrueNegatives(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.true_negatives(labels, predictions) - - def _expected_fn(num_batches): - return [0., 1., 2., 3.][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testTrueNegativesAtThresholds(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) - - def _expected_fn(num_batches): - return [[0.], [1.], [2.], [3.]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testFalsePositives(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.false_positives(labels, predictions) - - def _expected_fn(num_batches): - return [1., 2., 2., 3.][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testFalsePositivesAtThresholds(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.false_positives_at_thresholds(labels, predictions, [.5]) - - def _expected_fn(num_batches): - return [[1.], [2.], [2.], [3.]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testTruePositives(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.true_positives(labels, predictions) - - def _expected_fn(num_batches): - return [1., 2., 3., 3.][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testTruePositivesAtThresholds(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.true_positives_at_thresholds(labels, predictions, [.5]) - - def _expected_fn(num_batches): - return [[1.], [2.], [3.], [3.]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testPrecision(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.precision(labels, predictions) - - def _expected_fn(num_batches): - return [0.5, 0.5, 0.6, 0.5][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testPrecisionAtThreshold(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.precision_at_thresholds(labels, predictions, [0.5]) - - def _expected_fn(num_batches): - return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testRecall(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.recall(labels, predictions) - - def _expected_fn(num_batches): - return [0.5, 2./3, 0.6, 0.5][num_batches - 1] - - self._test_metric( - distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testRecallAtThreshold(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.recall_at_thresholds(labels, predictions, [0.5]) - - def _expected_fn(num_batches): - return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testMeanSquaredError(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.mean_squared_error(labels, predictions) - - def _expected_fn(num_batches): - return [0., 1./32, 0.208333, 0.15625][num_batches - 1] - - self._test_metric( - distribution, _regression_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testRootMeanSquaredError(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.root_mean_squared_error(labels, predictions) - - def _expected_fn(num_batches): - return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] - - self._test_metric( - distribution, _regression_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testSensitivityAtSpecificity(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.sensitivity_at_specificity(labels, predictions, 0.8) - - def _expected_fn(num_batches): - return [0.5, 2./3, 0.6, 0.5][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - @combinations.generate(all_combinations()) - def testSpecificityAtSensitivity(self, distribution): - def _metric_fn(x): - labels = x["labels"] - predictions = x["predictions"] - return metrics.specificity_at_sensitivity(labels, predictions, 0.95) - - def _expected_fn(num_batches): - return [0., 1./3, 0.5, 0.5][num_batches - 1] - - self._test_metric( - distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cb722485f8..f714d1fb21 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2461,7 +2461,6 @@ py_library( ":check_ops", ":confusion_matrix", ":control_flow_ops", - ":distribute", ":framework", ":framework_for_generated_wrappers", ":math_ops", diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index bf382a2cbf..97cd22e47a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1328,11 +1328,11 @@ class TensorFlowTestCase(googletest.TestCase): b, rtol=rtol, atol=atol, - msg=("Mismatched value: a%s is different from b%s. %s" % - (path_str, path_str, msg))) + msg="Mismatched value: a%s is different from b%s." % (path_str, + path_str)) except TypeError as e: - msg = ("Error: a%s has %s, but b%s has %s. %s" % - (path_str, type(a), path_str, type(b), msg)) + msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), + path_str, type(b)) e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 244e28d306..47eea6ef6b 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,54 +34,21 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export def metric_variable(shape, dtype, validate_shape=True, name=None): - """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. - - If running in a `DistributionStrategy` context, the variable will be - "tower local". This means: - - * The returned object will be a container with separate variables - per replica/tower of the model. - - * When writing to the variable, e.g. using `assign_add` in a metric - update, the update will be applied to the variable local to the - replica/tower. - - * To get a metric's result value, we need to sum the variable values - across the replicas/towers before computing the final answer. - Furthermore, the final answer should be computed once instead of - in every replica/tower. Both of these are accomplished by - running the computation of the final result value inside - `tf.contrib.distribute.get_tower_context().merge_call(fn)`. - Inside the `merge_call()`, ops are only added to the graph once - and access to a tower-local variable in a computation returns - the sum across all replicas/towers. - - Args: - shape: Shape of the created variable. - dtype: Type of the created variable. - validate_shape: (Optional) Whether shape validation is enabled for - the created variable. - name: (Optional) String name of the created variable. - - Returns: - A (non-trainable) variable initialized to zero, or if inside a - `DistributionStrategy` scope a tower-local variable container. - """ - with distribute_lib.get_tower_context().tower_local_var_scope('sum'): - # Note that "tower local" implies trainable=False. - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" + + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + trainable=False, + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): @@ -366,16 +333,12 @@ def mean(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - def aggregate_across_towers(_, t, c): - mean_t = _safe_div(t, c, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) - return mean_t - - mean_t = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total, count) + mean_t = _safe_div(total, count, 'value') update_op = _safe_div(update_total_op, update_count_op, 'update_op') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -609,17 +572,6 @@ def _confusion_matrix_at_thresholds(labels, return values, update_ops -def _aggregate_variable(v, collections): - - def f(distribution, value): - value = distribution.fetch(value) - if collections: - ops.add_to_collections(collections, value) - return value - - return distribute_lib.get_tower_context().merge_call(f, v) - - @tf_export('metrics.auc') def auc(labels, predictions, @@ -805,18 +757,14 @@ def auc(labels, raise ValueError('Invalid summation_method: %s' % summation_method) # sum up the areas of all the trapeziums - def aggregate_auc(_, values): - auc_value = compute_auc(values['tp'], values['fn'], values['tn'], - values['fp'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, auc_value) - return auc_value - - auc_value = distribute_lib.get_tower_context().merge_call( - aggregate_auc, values) + auc_value = compute_auc(values['tp'], values['fn'], values['tn'], + values['fp'], 'value') update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') + if metrics_collections: + ops.add_to_collections(metrics_collections, auc_value) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1044,18 +992,15 @@ def mean_per_class_accuracy(labels, update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) - def aggregate_mean_accuracy(_, count, total): - per_class_accuracy = _safe_div(count, total, None) - mean_accuracy_v = math_ops.reduce_mean( - per_class_accuracy, name='mean_accuracy') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_accuracy_v) - return mean_accuracy_v - - mean_accuracy_v = distribute_lib.get_tower_context().merge_call( - aggregate_mean_accuracy, count, total) + per_class_accuracy = _safe_div(count, total, None) + mean_accuracy_v = math_ops.reduce_mean( + per_class_accuracy, name='mean_accuracy') update_op = _safe_div(update_count_op, update_total_op, name='update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_accuracy_v) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1126,7 +1071,7 @@ def mean_iou(labels, total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes, weights) - def compute_mean_iou(total_cm, name): + def compute_mean_iou(name): """Compute the mean intersection-over-union via the confusion matrix.""" sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) @@ -1153,14 +1098,10 @@ def mean_iou(labels, math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0) return result - def mean_iou_across_towers(_, v): - mean_iou_v = compute_mean_iou(v, 'mean_iou') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_iou_v) - return mean_iou_v + mean_iou_v = compute_mean_iou('mean_iou') - mean_iou_v = distribute_lib.get_tower_context().merge_call( - mean_iou_across_towers, total_cm) + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_iou_v) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1369,16 +1310,12 @@ def mean_tensor(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - def aggregate_across_towers(_, t, c): - mean_t = _safe_div(t, c, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) - return mean_t + mean_t = _safe_div(total, count, 'value') + update_op = _safe_div(update_total_op, update_count_op, 'update_op') - mean_t = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total, count) + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) - update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1476,9 +1413,12 @@ def _count_condition(values, weights = math_ops.to_float(weights) values = math_ops.multiply(values, weights) - value_tensor = _aggregate_variable(count, metrics_collections) - + value_tensor = array_ops.identity(count) update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) + + if metrics_collections: + ops.add_to_collections(metrics_collections, value_tensor) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1585,12 +1525,13 @@ def false_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fn',)) - fn_value = _aggregate_variable(values['fn'], metrics_collections) + if metrics_collections: + ops.add_to_collections(metrics_collections, values['fn']) if updates_collections: ops.add_to_collections(updates_collections, update_ops['fn']) - return fn_value, update_ops['fn'] + return values['fn'], update_ops['fn'] @tf_export('metrics.false_positives') @@ -1694,12 +1635,13 @@ def false_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fp',)) - fp_value = _aggregate_variable(values['fp'], metrics_collections) + if metrics_collections: + ops.add_to_collections(metrics_collections, values['fp']) if updates_collections: ops.add_to_collections(updates_collections, update_ops['fp']) - return fp_value, update_ops['fp'] + return values['fp'], update_ops['fp'] @tf_export('metrics.true_negatives') @@ -1803,12 +1745,13 @@ def true_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tn',)) - tn_value = _aggregate_variable(values['tn'], metrics_collections) + if metrics_collections: + ops.add_to_collections(metrics_collections, values['tn']) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tn']) - return tn_value, update_ops['tn'] + return values['tn'], update_ops['tn'] @tf_export('metrics.true_positives') @@ -1912,12 +1855,13 @@ def true_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tp',)) - tp_value = _aggregate_variable(values['tp'], metrics_collections) + if metrics_collections: + ops.add_to_collections(metrics_collections, values['tp']) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tp']) - return tp_value, update_ops['tp'] + return values['tp'], update_ops['tp'] @tf_export('metrics.precision') @@ -2001,17 +1945,13 @@ def precision(labels, return array_ops.where( math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) - def once_across_towers(_, true_p, false_p): - p = compute_precision(true_p, false_p, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, p) - return p - - p = distribute_lib.get_tower_context().merge_call( - once_across_towers, true_p, false_p) - + p = compute_precision(true_p, false_p, 'value') update_op = compute_precision(true_positives_update_op, false_positives_update_op, 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, p) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2085,17 +2025,13 @@ def precision_at_thresholds(labels, def compute_precision(tp, fp, name): return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) - def precision_across_towers(_, values): - prec = compute_precision(values['tp'], values['fp'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, prec) - return prec - - prec = distribute_lib.get_tower_context().merge_call( - precision_across_towers, values) - + prec = compute_precision(values['tp'], values['fp'], 'value') update_op = compute_precision(update_ops['tp'], update_ops['fp'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, prec) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2114,7 +2050,7 @@ def recall(labels, The `recall` function creates two local variables, `true_positives` and `false_negatives`, that are used to compute the recall. This value is ultimately returned as `recall`, an idempotent operation that simply divides - `true_positives` by the sum of `true_positives` and `false_negatives`. + `true_positives` by the sum of `true_positives` and `false_negatives`. For estimation of the metric over a stream of data, the function creates an `update_op` that updates these variables and returns the `recall`. `update_op` @@ -2181,17 +2117,13 @@ def recall(labels, math_ops.greater(true_p + false_n, 0), math_ops.div(true_p, true_p + false_n), 0, name) - def once_across_towers(_, true_p, false_n): - rec = compute_recall(true_p, false_n, 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) - return rec - - rec = distribute_lib.get_tower_context().merge_call( - once_across_towers, true_p, false_n) - + rec = compute_recall(true_p, false_n, 'value') update_op = compute_recall(true_positives_update_op, false_negatives_update_op, 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2620,17 +2552,11 @@ def recall_at_top_k(labels, class_id=class_id, weights=weights) - def aggregate_across_towers(_, tp, fn): - metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) - return metric - - metric = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, tp, fn) - + metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) update = math_ops.div( tp_update, math_ops.add(tp_update, fn_update), name='update') + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -2701,16 +2627,12 @@ def recall_at_thresholds(labels, def compute_recall(tp, fn, name): return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) - def recall_across_towers(_, values): - rec = compute_recall(values['tp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) - return rec + rec = compute_recall(values['tp'], values['fn'], 'value') + update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') - rec = distribute_lib.get_tower_context().merge_call( - recall_across_towers, values) + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) - update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2776,16 +2698,13 @@ def root_mean_squared_error(labels, mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, None, name or 'root_mean_squared_error') - def once_across_towers(_, mse): - rmse = math_ops.sqrt(mse) - if metrics_collections: - ops.add_to_collections(metrics_collections, rmse) - return rmse - - rmse = distribute_lib.get_tower_context().merge_call( - once_across_towers, mse) + rmse = math_ops.sqrt(mse) update_rmse_op = math_ops.sqrt(update_mse_op) + + if metrics_collections: + ops.add_to_collections(metrics_collections, rmse) + if updates_collections: ops.add_to_collections(updates_collections, update_rmse_op) @@ -2878,19 +2797,15 @@ def sensitivity_at_specificity(labels, return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, name) - def aggregate_across_towers(_, values): - sensitivity = compute_sensitivity_at_specificity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, sensitivity) - return sensitivity - - sensitivity = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, values) - + sensitivity = compute_sensitivity_at_specificity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') update_op = compute_sensitivity_at_specificity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, sensitivity) + if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -3155,16 +3070,11 @@ def _streaming_sparse_average_precision_at_top_k(labels, total_update = state_ops.assign_add(total_var, batch_total, name='update') # Divide total by max to get mean, for both vars and the update ops. - def aggregate_across_towers(_, total_var, max_var): - mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_average_precision) - return mean_average_precision - - mean_average_precision = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total_var, max_var) - + mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') update = _safe_scalar_div(total_update, max_update, name=scope) + + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_average_precision) if updates_collections: ops.add_to_collections(updates_collections, update) @@ -3441,17 +3351,11 @@ def precision_at_top_k(labels, class_id=class_id, weights=weights) - def aggregate_across_towers(_, tp, fp): - metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) - return metric - - metric = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, tp, fp) - + metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) update = math_ops.div( tp_update, math_ops.add(tp_update, fp_update), name='update') + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -3679,19 +3583,15 @@ def specificity_at_sensitivity(labels, return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, name) - def aggregate_across_towers(_, values): - specificity = compute_specificity_at_sensitivity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') - if metrics_collections: - ops.add_to_collections(metrics_collections, specificity) - return specificity - - specificity = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, values) - + specificity = compute_specificity_at_sensitivity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') update_op = compute_specificity_at_sensitivity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, specificity) + if updates_collections: ops.add_to_collections(updates_collections, update_op) |