author    Igor Saprykin <isaprykin@google.com>          2018-05-18 16:33:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-18 16:36:57 -0700
commit 40f53c774e914b9166a5bc8476e290da4a121c82 (patch)
tree   c21c2af99a3e0cf7aedf5615ffb0e50d5085ac04
parent f4cb5978667ccf6396e4a779e3a482766959e5dd (diff)
Automated g4 rollback of changelist 197070234
PiperOrigin-RevId: 197218170
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                19
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py  438
-rw-r--r--  tensorflow/python/BUILD                                     1
-rw-r--r--  tensorflow/python/framework/test_util.py                    8
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                     296
5 files changed, 102 insertions, 660 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index aeeaa0b400..64a77bbed1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -547,22 +547,3 @@ cuda_py_test(
"no_pip",
],
)
-
-cuda_py_test(
- name = "metrics_v1_test",
- srcs = ["metrics_v1_test.py"],
- additional_deps = [
- ":combinations",
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/eager:test",
- ],
- tags = [
- "multi_and_single_gpu",
- "no_pip",
- ],
-)
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
deleted file mode 100644
index 6c6bf14309..0000000000
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ /dev/null
@@ -1,438 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for V1 metrics."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.distribute.python import combinations
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import test
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import metrics
-from tensorflow.python.ops import variables
-
-
-def _labeled_dataset_fn():
- # First four batches of x: labels, predictions -> (labels == predictions)
- # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False
- # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False
- # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
- # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
- return dataset_ops.Dataset.range(1000).map(
- lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
-
-
-def _boolean_dataset_fn():
- # First four batches of labels, predictions: {TP, FP, TN, FN}
- # with a threshold of 0.5:
- # T, T -> TP; F, T -> FP; T, F -> FN
- # F, F -> TN; T, T -> TP; F, T -> FP
- # T, F -> FN; F, F -> TN; T, T -> TP
- # F, T -> FP; T, F -> FN; F, F -> TN
- return dataset_ops.Dataset.from_tensor_slices({
- "labels": [True, False, True, False],
- "predictions": [True, True, False, False]}).repeat().batch(3)
-
-
-def _threshold_dataset_fn():
- # First four batches of labels, predictions: {TP, FP, TN, FN}
- # with a threshold of 0.5:
- # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN
- # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP
- # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP
- # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
- return dataset_ops.Dataset.from_tensor_slices({
- "labels": [True, False, True, False],
- "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
-
-
-def _regression_dataset_fn():
- return dataset_ops.Dataset.from_tensor_slices({
- "labels": [1., .5, 1., 0.],
- "predictions": [1., .75, .25, 0.]}).repeat()
-
-
-def all_combinations():
- return combinations.combine(
- distribution=[combinations.default_strategy,
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus],
- mode=["graph"])
-
-
-# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
-# metrics.precision_at_k
-class MetricsV1Test(test.TestCase, parameterized.TestCase):
-
- def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
- with ops.Graph().as_default(), distribution.scope():
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
- value, update = distribution.call_for_each_tower(
- metric_fn, iterator.get_next())
- update = distribution.group(update)
- self.evaluate(variables.local_variables_initializer())
- # TODO(josh11b): Once we switch to using a global batch size for input,
- # replace "distribution.num_towers" with "1".
- batches_per_update = distribution.num_towers
-
- # Update variables using the first `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
- 0.001, msg="After first update")
-
- # Update variables using the second `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(2 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After second update")
-
- if batches_per_update == 1: # Consume 4 input batches
- self.evaluate(update)
- self.assertAllClose(expected_fn(3 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After third update")
- self.evaluate(update)
- self.assertAllClose(expected_fn(4 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After fourth update")
-
- @combinations.generate(all_combinations())
- def testMean(self, distribution):
- def _dataset_fn():
- return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
-
- def _expected_fn(num_batches):
- # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
- return num_batches * 2 - 0.5
-
- self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testAccuracy(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.accuracy(labels, predictions)
-
- def _expected_fn(num_batches):
- return [3./4, 3./8, 3./12, 4./16][num_batches - 1]
-
- self._test_metric(
- distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testMeanPerClassAccuracy(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.mean_per_class_accuracy(
- labels, predictions, num_classes=5)
-
- def _expected_fn(num_batches):
- mean = lambda x: sum(x) / len(x)
- return [mean([1., 1., 1., 0., 0.]),
- mean([0.5, 0.5, 0.5, 0., 0.]),
- mean([1./3, 1./3, 0.5, 0., 0.]),
- mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]
-
- self._test_metric(
- distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testMeanIOU(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.mean_iou(
- labels, predictions, num_classes=5)
-
- def _expected_fn(num_batches):
- mean = lambda x: sum(x) / len(x)
- return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch
- mean([1./4, 1./4, 1./3, 0., 0.]),
- mean([1./6, 1./6, 1./5, 0., 0.]),
- mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]
-
- self._test_metric(
- distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testMeanTensor(self, distribution):
- def _dataset_fn():
- dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
- # Want to produce a fixed, known shape, so drop remainder when batching.
- dataset = dataset.apply(batching.batch_and_drop_remainder(4))
- return dataset
-
- def _expected_fn(num_batches):
- # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
- # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
- # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
- # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
- first = 2. * num_batches - 2.
- return [first, first + 1., first + 2., first + 3.]
-
- self._test_metric(
- distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testAUCROC(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
- summation_method="careful_interpolation")
-
- def _expected_fn(num_batches):
- return [0.5, 7./9, 0.8, 0.75][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testAUCPR(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
- summation_method="careful_interpolation")
-
- def _expected_fn(num_batches):
- return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testFalseNegatives(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.false_negatives(labels, predictions)
-
- def _expected_fn(num_batches):
- return [1., 1., 2., 3.][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testFalseNegativesAtThresholds(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.false_negatives_at_thresholds(labels, predictions, [.5])
-
- def _expected_fn(num_batches):
- return [[1.], [1.], [2.], [3.]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testTrueNegatives(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.true_negatives(labels, predictions)
-
- def _expected_fn(num_batches):
- return [0., 1., 2., 3.][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testTrueNegativesAtThresholds(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.true_negatives_at_thresholds(labels, predictions, [.5])
-
- def _expected_fn(num_batches):
- return [[0.], [1.], [2.], [3.]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testFalsePositives(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.false_positives(labels, predictions)
-
- def _expected_fn(num_batches):
- return [1., 2., 2., 3.][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testFalsePositivesAtThresholds(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.false_positives_at_thresholds(labels, predictions, [.5])
-
- def _expected_fn(num_batches):
- return [[1.], [2.], [2.], [3.]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testTruePositives(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.true_positives(labels, predictions)
-
- def _expected_fn(num_batches):
- return [1., 2., 3., 3.][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testTruePositivesAtThresholds(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.true_positives_at_thresholds(labels, predictions, [.5])
-
- def _expected_fn(num_batches):
- return [[1.], [2.], [3.], [3.]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testPrecision(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.precision(labels, predictions)
-
- def _expected_fn(num_batches):
- return [0.5, 0.5, 0.6, 0.5][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testPrecisionAtThreshold(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.precision_at_thresholds(labels, predictions, [0.5])
-
- def _expected_fn(num_batches):
- return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testRecall(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.recall(labels, predictions)
-
- def _expected_fn(num_batches):
- return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
-
- self._test_metric(
- distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testRecallAtThreshold(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.recall_at_thresholds(labels, predictions, [0.5])
-
- def _expected_fn(num_batches):
- return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testMeanSquaredError(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.mean_squared_error(labels, predictions)
-
- def _expected_fn(num_batches):
- return [0., 1./32, 0.208333, 0.15625][num_batches - 1]
-
- self._test_metric(
- distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testRootMeanSquaredError(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.root_mean_squared_error(labels, predictions)
-
- def _expected_fn(num_batches):
- return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
-
- self._test_metric(
- distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testSensitivityAtSpecificity(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
-
- def _expected_fn(num_batches):
- return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
- @combinations.generate(all_combinations())
- def testSpecificityAtSensitivity(self, distribution):
- def _metric_fn(x):
- labels = x["labels"]
- predictions = x["predictions"]
- return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
-
- def _expected_fn(num_batches):
- return [0., 1./3, 0.5, 0.5][num_batches - 1]
-
- self._test_metric(
- distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
-
-
-if __name__ == "__main__":
- test.main()
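The deleted test exercised the standard V1 metrics contract: each `tf.metrics.*` function returns a `(value, update_op)` pair whose state lives in local variables. A minimal sketch of that contract outside any `DistributionStrategy`, assuming TensorFlow 1.x APIs and reusing the first batch from `_labeled_dataset_fn` above:

import tensorflow as tf

# One batch of the labeled data used above: labels are x % 5, predictions x % 3.
labels = tf.constant([0, 1, 2, 3])
predictions = tf.constant([0, 1, 2, 0])

# Each V1 metric returns (value, update_op); the running totals are
# local variables, so they must be explicitly initialized.
value, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # reset total/count to zero
  sess.run(update_op)                         # fold this batch into the state
  print(sess.run(value))                      # 0.75, matching testAccuracy's first expectation

The distributed test above is the same loop, except the dataset is sharded across towers and `update_op` is grouped across them, which is why `_test_metric` scales its expectations by `distribution.num_towers`.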
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index cb722485f8..f714d1fb21 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2461,7 +2461,6 @@ py_library(
":check_ops",
":confusion_matrix",
":control_flow_ops",
- ":distribute",
":framework",
":framework_for_generated_wrappers",
":math_ops",
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index bf382a2cbf..97cd22e47a 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1328,11 +1328,11 @@ class TensorFlowTestCase(googletest.TestCase):
b,
rtol=rtol,
atol=atol,
- msg=("Mismatched value: a%s is different from b%s. %s" %
- (path_str, path_str, msg)))
+ msg="Mismatched value: a%s is different from b%s." % (path_str,
+ path_str))
except TypeError as e:
- msg = ("Error: a%s has %s, but b%s has %s. %s" %
- (path_str, type(a), path_str, type(b), msg))
+ msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
+ path_str, type(b))
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
raise
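The hunk above only rewords the mismatch and type-error messages that `assertAllClose` emits for nested structures (dropping the caller-supplied `msg` from the inline text); the recursive path reporting is unchanged. A minimal usage sketch of the nested comparison this code path serves, assuming TensorFlow 1.x test utilities:

import tensorflow as tf

class NestedCloseTest(tf.test.TestCase):

  def testNestedStructures(self):
    a = {"w": [1.0, 2.0], "b": 0.5}
    b = {"w": [1.0, 2.0 + 1e-8], "b": 0.5}
    # assertAllClose recurses into dicts and lists; on a mismatch the
    # error message names the offending path, e.g. a["w"][1].
    self.assertAllClose(a, b, rtol=1e-6, atol=1e-6)

if __name__ == "__main__":
  tf.test.main()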
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 244e28d306..47eea6ef6b 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,54 +34,21 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
def metric_variable(shape, dtype, validate_shape=True, name=None):
- """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
-
- If running in a `DistributionStrategy` context, the variable will be
- "tower local". This means:
-
- * The returned object will be a container with separate variables
- per replica/tower of the model.
-
- * When writing to the variable, e.g. using `assign_add` in a metric
- update, the update will be applied to the variable local to the
- replica/tower.
-
- * To get a metric's result value, we need to sum the variable values
- across the replicas/towers before computing the final answer.
- Furthermore, the final answer should be computed once instead of
- in every replica/tower. Both of these are accomplished by
- running the computation of the final result value inside
- `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
- Inside the `merge_call()`, ops are only added to the graph once
- and access to a tower-local variable in a computation returns
- the sum across all replicas/towers.
-
- Args:
- shape: Shape of the created variable.
- dtype: Type of the created variable.
- validate_shape: (Optional) Whether shape validation is enabled for
- the created variable.
- name: (Optional) String name of the created variable.
-
- Returns:
- A (non-trainable) variable initialized to zero, or if inside a
- `DistributionStrategy` scope a tower-local variable container.
- """
- with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
- # Note that "tower local" implies trainable=False.
- return variable_scope.variable(
- lambda: array_ops.zeros(shape, dtype),
- collections=[
- ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
- ],
- validate_shape=validate_shape,
- name=name)
+ """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
+
+ return variable_scope.variable(
+ lambda: array_ops.zeros(shape, dtype),
+ trainable=False,
+ collections=[
+ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+ ],
+ validate_shape=validate_shape,
+ name=name)
def _remove_squeezable_dimensions(predictions, labels, weights):
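After the rollback, `metric_variable` is again a plain local variable with no tower-local container. A minimal standalone sketch of the equivalent construction (assuming TF 1.x; `metric_variable_sketch` is an illustrative name, not the library function):

import tensorflow as tf

def metric_variable_sketch(shape, dtype, validate_shape=True, name=None):
  # Zero-initialized and non-trainable, registered in both LOCAL_VARIABLES
  # and METRIC_VARIABLES so tf.local_variables_initializer() resets it
  # between evaluation runs.
  return tf.Variable(
      lambda: tf.zeros(shape, dtype),
      trainable=False,
      collections=[tf.GraphKeys.LOCAL_VARIABLES,
                   tf.GraphKeys.METRIC_VARIABLES],
      validate_shape=validate_shape,
      name=name)

Note that `trainable=False` is now passed explicitly; in the rolled-back version it was implied by the tower-local variable scope.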
@@ -366,16 +333,12 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
-
- mean_t = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _safe_div(total, count, 'value')
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
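Both `mean_t` and `update_op` above go through `_safe_div`, which guards against an empty accumulator. A hypothetical sketch of the behavior implied by these call sites (the real helper lives in metrics_impl.py; the name and exact form here are illustrative):

import tensorflow as tf

def safe_div_sketch(numerator, denominator, name):
  # Returns numerator / denominator, and 0 wherever denominator == 0,
  # so an un-updated mean reads as 0 instead of NaN. The inner where()
  # substitutes a dummy denominator of 1 so the division itself never
  # produces NaN/Inf values.
  zero_denom = tf.equal(denominator, 0)
  return tf.where(
      zero_denom,
      tf.zeros_like(numerator),
      tf.div(numerator,
             tf.where(zero_denom, tf.ones_like(denominator), denominator)),
      name=name)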
@@ -609,17 +572,6 @@ def _confusion_matrix_at_thresholds(labels,
return values, update_ops
-def _aggregate_variable(v, collections):
-
- def f(distribution, value):
- value = distribution.fetch(value)
- if collections:
- ops.add_to_collections(collections, value)
- return value
-
- return distribute_lib.get_tower_context().merge_call(f, v)
-
-
@tf_export('metrics.auc')
def auc(labels,
predictions,
@@ -805,18 +757,14 @@ def auc(labels,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- def aggregate_auc(_, values):
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
- return auc_value
-
- auc_value = distribute_lib.get_tower_context().merge_call(
- aggregate_auc, values)
+ auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+ values['fp'], 'value')
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, auc_value)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1044,18 +992,15 @@ def mean_per_class_accuracy(labels,
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- def aggregate_mean_accuracy(_, count, total):
- per_class_accuracy = _safe_div(count, total, None)
- mean_accuracy_v = math_ops.reduce_mean(
- per_class_accuracy, name='mean_accuracy')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
- return mean_accuracy_v
-
- mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
- aggregate_mean_accuracy, count, total)
+ per_class_accuracy = _safe_div(count, total, None)
+ mean_accuracy_v = math_ops.reduce_mean(
+ per_class_accuracy, name='mean_accuracy')
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_accuracy_v)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1126,7 +1071,7 @@ def mean_iou(labels,
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(total_cm, name):
+ def compute_mean_iou(name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1153,14 +1098,10 @@ def mean_iou(labels,
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
return result
- def mean_iou_across_towers(_, v):
- mean_iou_v = compute_mean_iou(v, 'mean_iou')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
- return mean_iou_v
+ mean_iou_v = compute_mean_iou('mean_iou')
- mean_iou_v = distribute_lib.get_tower_context().merge_call(
- mean_iou_across_towers, total_cm)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_iou_v)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1369,16 +1310,12 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ mean_t = _safe_div(total, count, 'value')
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
- mean_t = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1476,9 +1413,12 @@ def _count_condition(values,
weights = math_ops.to_float(weights)
values = math_ops.multiply(values, weights)
- value_tensor = _aggregate_variable(count, metrics_collections)
-
+ value_tensor = array_ops.identity(count)
update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, value_tensor)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1585,12 +1525,13 @@ def false_negatives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fn',))
- fn_value = _aggregate_variable(values['fn'], metrics_collections)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, values['fn'])
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fn'])
- return fn_value, update_ops['fn']
+ return values['fn'], update_ops['fn']
@tf_export('metrics.false_positives')
@@ -1694,12 +1635,13 @@ def false_positives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fp',))
- fp_value = _aggregate_variable(values['fp'], metrics_collections)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, values['fp'])
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fp'])
- return fp_value, update_ops['fp']
+ return values['fp'], update_ops['fp']
@tf_export('metrics.true_negatives')
@@ -1803,12 +1745,13 @@ def true_negatives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tn',))
- tn_value = _aggregate_variable(values['tn'], metrics_collections)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, values['tn'])
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tn'])
- return tn_value, update_ops['tn']
+ return values['tn'], update_ops['tn']
@tf_export('metrics.true_positives')
@@ -1912,12 +1855,13 @@ def true_positives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tp',))
- tp_value = _aggregate_variable(values['tp'], metrics_collections)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, values['tp'])
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tp'])
- return tp_value, update_ops['tp']
+ return values['tp'], update_ops['tp']
@tf_export('metrics.precision')
@@ -2001,17 +1945,13 @@ def precision(labels,
return array_ops.where(
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
- def once_across_towers(_, true_p, false_p):
- p = compute_precision(true_p, false_p, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
- return p
-
- p = distribute_lib.get_tower_context().merge_call(
- once_across_towers, true_p, false_p)
-
+ p = compute_precision(true_p, false_p, 'value')
update_op = compute_precision(true_positives_update_op,
false_positives_update_op, 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, p)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2085,17 +2025,13 @@ def precision_at_thresholds(labels,
def compute_precision(tp, fp, name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
- def precision_across_towers(_, values):
- prec = compute_precision(values['tp'], values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
- return prec
-
- prec = distribute_lib.get_tower_context().merge_call(
- precision_across_towers, values)
-
+ prec = compute_precision(values['tp'], values['fp'], 'value')
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, prec)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2114,7 +2050,7 @@ def recall(labels,
The `recall` function creates two local variables, `true_positives`
and `false_negatives`, that are used to compute the recall. This value is
ultimately returned as `recall`, an idempotent operation that simply divides
- `true_positives` by the sum of `true_positives` and `false_negatives`.
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
For estimation of the metric over a stream of data, the function creates an
`update_op` that updates these variables and returns the `recall`. `update_op`
@@ -2181,17 +2117,13 @@ def recall(labels,
math_ops.greater(true_p + false_n, 0),
math_ops.div(true_p, true_p + false_n), 0, name)
- def once_across_towers(_, true_p, false_n):
- rec = compute_recall(true_p, false_n, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
-
- rec = distribute_lib.get_tower_context().merge_call(
- once_across_towers, true_p, false_n)
-
+ rec = compute_recall(true_p, false_n, 'value')
update_op = compute_recall(true_positives_update_op,
false_negatives_update_op, 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2620,17 +2552,11 @@ def recall_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fn):
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
-
- metric = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, tp, fn)
-
+ metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
@@ -2701,16 +2627,12 @@ def recall_at_thresholds(labels,
def compute_recall(tp, fn, name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
- def recall_across_towers(_, values):
- rec = compute_recall(values['tp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ rec = compute_recall(values['tp'], values['fn'], 'value')
+ update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
- rec = distribute_lib.get_tower_context().merge_call(
- recall_across_towers, values)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
- update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2776,16 +2698,13 @@ def root_mean_squared_error(labels,
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
- def once_across_towers(_, mse):
- rmse = math_ops.sqrt(mse)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
- return rmse
-
- rmse = distribute_lib.get_tower_context().merge_call(
- once_across_towers, mse)
+ rmse = math_ops.sqrt(mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rmse)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_rmse_op)
@@ -2878,19 +2797,15 @@ def sensitivity_at_specificity(labels,
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- sensitivity = compute_sensitivity_at_specificity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
- return sensitivity
-
- sensitivity = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, values)
-
+ sensitivity = compute_sensitivity_at_specificity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, sensitivity)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -3155,16 +3070,11 @@ def _streaming_sparse_average_precision_at_top_k(labels,
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- def aggregate_across_towers(_, total_var, max_var):
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
- return mean_average_precision
-
- mean_average_precision = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total_var, max_var)
-
+ mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
update = _safe_scalar_div(total_update, max_update, name=scope)
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_average_precision)
if updates_collections:
ops.add_to_collections(updates_collections, update)
@@ -3441,17 +3351,11 @@ def precision_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fp):
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
-
- metric = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, tp, fp)
-
+ metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
@@ -3679,19 +3583,15 @@ def specificity_at_sensitivity(labels,
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- specificity = compute_specificity_at_sensitivity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
- return specificity
-
- specificity = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, values)
-
+ specificity = compute_specificity_at_sensitivity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, specificity)
+
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
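A common thread in the hunks above: with the merge_call aggregation rolled back, each metric adds its value tensor to `metrics_collections` directly at construction time instead of inside a cross-tower callback. A short usage sketch of how callers consume those collections, assuming TF 1.x and the collection names chosen here for illustration:

import tensorflow as tf

labels = tf.constant([True, False, True, False])
predictions = tf.constant([1.0, 0.75, 0.25, 0.0])

value, update_op = tf.metrics.recall(
    labels, predictions,
    metrics_collections=["eval_metrics"],
    updates_collections=["eval_updates"])

# The collections now hold the very tensors the function returned,
# so evaluation loops can fetch them by collection name alone.
assert tf.get_collection("eval_metrics") == [value]
assert tf.get_collection("eval_updates") == [update_op]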