aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-22 19:27:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-22 20:33:29 -0700
commitcf7230b8e6bc5c8ae80d64e8f9eda03cae437da1 (patch)
tree51502e327d7c67bae9978621db41bff8816c000c
parent8ec747f1cd9ceb2c585addc5529c2946c4a09c54 (diff)
Modify metric_ops.remove_squeezable_dimensions to handle statically unknown ranks.
Add tests to assert all update_op tensors returned from metric_ops produce the proper metric value, and fix streaming_auc to do so. Change: 125639012
-rw-r--r--tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py82
-rw-r--r--tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py13
-rw-r--r--tensorflow/contrib/metrics/python/ops/histogram_ops.py3
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py99
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py265
5 files changed, 242 insertions, 220 deletions
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
index 638010806a..5812b6547d 100644
--- a/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
+++ b/tensorflow/contrib/metrics/python/kernel_tests/confusion_matrix_ops_test.py
@@ -34,42 +34,42 @@ class ConfusionMatrixTest(tf.test.TestCase):
labels = np.arange(5, dtype=dtype)
truth = np.asarray(
- [[1, 0, 0, 0, 0],
- [0, 1, 0, 0, 0],
- [0, 0, 1, 0, 0],
- [0, 0, 0, 1, 0],
- [0, 0, 0, 0, 1]],
- dtype=dtype)
+ [[1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1]],
+ dtype=dtype)
self._testConfMatrix(
- predictions=predictions,
- labels=labels,
- truth=truth)
+ predictions=predictions,
+ labels=labels,
+ truth=truth)
- def testInt32Basic(self, dtype=np.int32):
- self._testBasic(dtype)
+ def testInt32Basic(self):
+ self._testBasic(dtype=np.int32)
- def testInt64Basic(self, dtype=np.int64):
- self._testBasic(dtype)
+ def testInt64Basic(self):
+ self._testBasic(dtype=np.int64)
def _testDiffentLabelsInPredictionAndTarget(self, dtype):
predictions = np.asarray([1, 2, 3], dtype=dtype)
labels = np.asarray([4, 5, 6], dtype=dtype)
truth = np.asarray(
- [[0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 1, 0, 0],
- [0, 0, 0, 0, 0, 1, 0],
- [0, 0, 0, 0, 0, 0, 1],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0]],
- dtype=dtype)
+ [[0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0]],
+ dtype=dtype)
self._testConfMatrix(
- predictions=predictions,
- labels=labels,
- truth=truth)
+ predictions=predictions,
+ labels=labels,
+ truth=truth)
def testInt32DifferentLabels(self, dtype=np.int32):
self._testDiffentLabelsInPredictionAndTarget(dtype)
@@ -82,19 +82,19 @@ class ConfusionMatrixTest(tf.test.TestCase):
labels = np.asarray([1, 1, 2, 3, 5, 1, 3, 6, 3, 1], dtype=dtype)
truth = np.asarray(
- [[0, 0, 0, 0, 0, 0, 0],
- [0, 2, 0, 1, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 1],
- [0, 0, 0, 2, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 1, 0],
- [0, 1, 0, 0, 0, 0, 0]],
- dtype=dtype)
+ [[0, 0, 0, 0, 0, 0, 0],
+ [0, 2, 0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 1],
+ [0, 0, 0, 2, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 0, 0, 0]],
+ dtype=dtype)
self._testConfMatrix(
- predictions=predictions,
- labels=labels,
- truth=truth)
+ predictions=predictions,
+ labels=labels,
+ truth=truth)
def testInt32MultipleLabels(self, dtype=np.int32):
self._testMultipleLabels(dtype)
@@ -106,22 +106,22 @@ class ConfusionMatrixTest(tf.test.TestCase):
predictions = np.asarray([[1, 2, 3]])
labels = np.asarray([1, 2, 3])
self.assertRaisesRegexp(
- ValueError, "are not compatible",
+ ValueError, "an not squeeze dim",
tf.contrib.metrics.confusion_matrix, predictions, labels)
predictions = np.asarray([1, 2, 3])
labels = np.asarray([[1, 2, 3]])
self.assertRaisesRegexp(
- ValueError, "are not compatible",
+ ValueError, "an not squeeze dim",
tf.contrib.metrics.confusion_matrix, predictions, labels)
def testInputDifferentSize(self):
predictions = np.asarray([1, 2, 3])
labels = np.asarray([1, 2])
self.assertRaisesRegexp(
- ValueError,
- "are not compatible",
- tf.contrib.metrics.confusion_matrix, predictions, labels)
+ ValueError, "are not compatible",
+ tf.contrib.metrics.confusion_matrix, predictions, labels)
+
-if __name__ == '__main__':
+if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
index d310e55170..73c515d76b 100644
--- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -25,11 +26,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
-"""Confusion matrix related metrics."""
-
-
def confusion_matrix(predictions, labels, num_classes=None, name=None):
- """Computes the confusion matrix from predictions and labels
+ """Computes the confusion matrix from predictions and labels.
Calculate the Confusion Matrix for a pair of prediction and
label 1-D int arrays.
@@ -68,9 +66,10 @@ def confusion_matrix(predictions, labels, num_classes=None, name=None):
"""
with ops.op_scope([predictions, labels, num_classes], name,
'confusion_matrix') as name:
- predictions = ops.convert_to_tensor(
- predictions, name='predictions', dtype=dtypes.int64)
- labels = ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64)
+ predictions, labels = metric_ops.remove_squeezable_dimensions(
+ ops.convert_to_tensor(
+ predictions, name='predictions', dtype=dtypes.int64),
+ ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64))
if num_classes is None:
num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
diff --git a/tensorflow/contrib/metrics/python/ops/histogram_ops.py b/tensorflow/contrib/metrics/python/ops/histogram_ops.py
index 60c26149ea..11b9602a2e 100644
--- a/tensorflow/contrib/metrics/python/ops/histogram_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/histogram_ops.py
@@ -23,6 +23,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -78,6 +79,8 @@ def auc_using_histogram(boolean_labels,
collections = [ops.GraphKeys.LOCAL_VARIABLES]
with variable_scope.variable_op_scope(
[boolean_labels, scores, score_range], name, 'auc_using_histogram'):
+ scores, boolean_labels = metric_ops.remove_squeezable_dimensions(
+ scores, boolean_labels)
score_range = ops.convert_to_tensor(score_range, name='score_range')
boolean_labels, scores = _check_labels_and_scores(
boolean_labels, scores, check_shape)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index e9ed4a1799..92ee12dff2 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
@@ -75,14 +77,48 @@ def _create_local(name, shape=None, collections=None):
collections=collections)
-def _remove_squeezable_dimensions(predictions, labels):
- predictions_rank = predictions.get_shape().ndims
- labels_rank = labels.get_shape().ndims
- if not (labels_rank is None or predictions_rank is None):
- if labels_rank == (predictions_rank + 1):
+def remove_squeezable_dimensions(predictions, labels):
+ """Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+ labels: Label values, a `Tensor` whose dimensions match `predictions`.
+
+ Returns:
+ Tuple of `predictions` and `labels`, possibly with last dim squeezed.
+ """
+ predictions = ops.convert_to_tensor(predictions)
+ labels = ops.convert_to_tensor(labels)
+ predictions_shape = predictions.get_shape()
+ predictions_rank = predictions_shape.ndims
+ labels_shape = labels.get_shape()
+ labels_rank = labels_shape.ndims
+ if (labels_rank is not None) and (predictions_rank is not None):
+ # Use static rank.
+ rank_diff = predictions_rank - labels_rank
+ if rank_diff == -1:
labels = array_ops.squeeze(labels, [-1])
- elif predictions_rank == (labels_rank + 1):
+ elif rank_diff == 1:
predictions = array_ops.squeeze(predictions, [-1])
+ return predictions, labels
+
+ # Use dynamic rank.
+ rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
+ if (predictions_rank is None) or (
+ predictions_shape.dims[-1].is_compatible_with(1)):
+ predictions = control_flow_ops.cond(
+ math_ops.equal(1, rank_diff),
+ lambda: array_ops.squeeze(predictions, [-1]),
+ lambda: predictions)
+ if (labels_rank is None) or (
+ labels_shape.dims[-1].is_compatible_with(1)):
+ labels = control_flow_ops.cond(
+ math_ops.equal(-1, rank_diff),
+ lambda: array_ops.squeeze(labels, [-1]),
+ lambda: labels)
return predictions, labels
@@ -357,7 +393,7 @@ def streaming_accuracy(predictions, labels, weights=None,
if either `metrics_collections` or `updates_collections` are not
a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
return streaming_mean(is_correct, weights, metrics_collections,
@@ -412,7 +448,7 @@ def streaming_precision(predictions, labels, ignore_mask=None,
with variable_scope.variable_op_scope(
[predictions, labels], name, 'precision'):
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = _streaming_true_positives(
@@ -489,7 +525,7 @@ def streaming_recall(predictions, labels, ignore_mask=None,
or tuple.
"""
with variable_scope.variable_op_scope([predictions, labels], name, 'recall'):
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
true_positives, true_positives_update_op = _streaming_true_positives(
@@ -567,7 +603,7 @@ def _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask=None):
or if either `metrics_collections` or `updates_collections` are not a list
or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
num_thresholds = len(thresholds)
@@ -712,33 +748,24 @@ def streaming_auc(predictions, labels, ignore_mask=None, num_thresholds=200,
for i in range(num_thresholds-2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
- (true_positives, false_negatives, true_negatives, false_positives,
- true_positives_compute_op, false_negatives_compute_op,
- true_negatives_compute_op, false_positives_compute_op) = _tp_fn_tn_fp(
- predictions, labels, thresholds, ignore_mask)
+ (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op,
+ fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask)
- epsilon = 1.0e-6
- assert array_ops.squeeze(
- false_positives).get_shape().as_list()[0] == num_thresholds
# Add epsilons to avoid dividing by 0.
- false_positive_rate = math_ops.div(
- false_positives,
- false_positives + true_negatives + epsilon)
- recall = math_ops.div(true_positives + epsilon,
- true_positives + false_negatives + epsilon)
+ epsilon = 1.0e-6
+ assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds
- def compute_auc(name):
+ def compute_auc(tp, fn, tn, fp, name):
+ fp_rate = math_ops.div(fp, fp + tn + epsilon)
+ recall = math_ops.div(tp + epsilon, tp + fn + epsilon)
return math_ops.reduce_sum(math_ops.mul(
- false_positive_rate[:num_thresholds - 1] - false_positive_rate[1:],
+ fp_rate[:num_thresholds - 1] - fp_rate[1:],
(recall[:num_thresholds - 1] + recall[1:]) / 2.), name=name)
# sum up the areas of all the trapeziums
- auc = compute_auc('value')
- with ops.control_dependencies([true_positives_compute_op,
- false_negatives_compute_op,
- true_negatives_compute_op,
- false_positives_compute_op]):
- update_op = compute_auc('update_op')
+ auc = compute_auc(tp, fn, tn, fp, 'value')
+ update_op = compute_auc(
+ tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, auc)
@@ -1367,7 +1394,7 @@ def streaming_mean_absolute_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
absolute_errors = math_ops.abs(predictions - labels)
return streaming_mean(absolute_errors, weights, metrics_collections,
@@ -1419,10 +1446,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- predictions, normalizer = _remove_squeezable_dimensions(
+ predictions, normalizer = remove_squeezable_dimensions(
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = math_ops.select(
@@ -1477,7 +1504,7 @@ def streaming_mean_squared_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
squared_error = math_ops.square(labels - predictions)
return streaming_mean(squared_error, weights, metrics_collections,
@@ -1528,7 +1555,7 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None,
`predictions` or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
value_tensor, update_op = streaming_mean_squared_error(
predictions, labels, weights, None, None,
@@ -1593,7 +1620,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None,
ignore_mask is of the wrong size or if either `metrics_collections` or
`updates_collections` are not a list or tuple.
"""
- predictions, labels = _remove_squeezable_dimensions(predictions, labels)
+ predictions, labels = remove_squeezable_dimensions(predictions, labels)
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.mul(predictions, labels)
radial_diffs = math_ops.reduce_sum(radial_diffs,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 93352940f9..ea2adc2fbf 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,7 +24,7 @@ import math
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-
+from tensorflow.contrib.metrics.python.ops import metric_ops
NAN = float('nan')
@@ -136,6 +137,108 @@ def _binary_3d_label_to_sparse(labels):
tf.constant(v.shape, tf.int64))
+class RemoveSqueezableDimensionsTest(tf.test.TestCase):
+
+ def testRemoveSqueezableDimensions(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=False,
+ labels_have_static_shape=False, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_extraLabelDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=False,
+ labels_have_static_shape=False, labels_have_extra_dim=True)
+
+ def testRemoveSqueezableDimensions_staticLabel(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=False,
+ labels_have_static_shape=True, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_staticLabel_extraLabelDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=False,
+ labels_have_static_shape=True, labels_have_extra_dim=True)
+
+ def testRemoveSqueezableDimensions_extraPredictionDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=True,
+ labels_have_static_shape=False, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_extraPredictionDim_staticLabel(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=False, predictions_have_extra_dim=True,
+ labels_have_static_shape=True, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_staticPrediction(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=False,
+ labels_have_static_shape=False, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_staticPrediction_extraLabelDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=False,
+ labels_have_static_shape=False, labels_have_extra_dim=True)
+
+ def testRemoveSqueezableDimensions_static(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=False,
+ labels_have_static_shape=True, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_static_extraLabelDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=False,
+ labels_have_static_shape=True, labels_have_extra_dim=True)
+
+ def testRemoveSqueezableDimensions_staticPrediction_extraPredictionDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=True,
+ labels_have_static_shape=False, labels_have_extra_dim=False)
+
+ def testRemoveSqueezableDimensions_static_extraPredictionDim(self):
+ self._testRemoveSqueezableDimensions(
+ predictions_have_static_shape=True, predictions_have_extra_dim=True,
+ labels_have_static_shape=True, labels_have_extra_dim=False)
+
+ # TODO(ptucker): Replace this with parameterized test.
+ def _testRemoveSqueezableDimensions(
+ self,
+ predictions_have_static_shape,
+ predictions_have_extra_dim,
+ labels_have_static_shape,
+ labels_have_extra_dim):
+ assert not (predictions_have_extra_dim and labels_have_extra_dim)
+ predictions_value = (0, 1, 1, 0, 0, 1, 0)
+ labels_value = (0, 0, 1, 1, 0, 0, 0)
+
+ input_predictions_value = (
+ [[p] for p in predictions_value] if predictions_have_extra_dim else
+ predictions_value)
+ input_labels_value = (
+ [[l] for l in labels_value] if labels_have_extra_dim else labels_value)
+
+ with tf.Graph().as_default() as g:
+ feed_dict = {}
+ if predictions_have_static_shape:
+ predictions = tf.constant(input_predictions_value, dtype=tf.int32)
+ else:
+ predictions = tf.placeholder(dtype=tf.int32, name='predictions')
+ feed_dict[predictions] = input_predictions_value
+ if labels_have_static_shape:
+ labels = tf.constant(input_labels_value, dtype=tf.int32)
+ else:
+ labels = tf.placeholder(dtype=tf.int32, name='labels')
+ feed_dict[labels] = input_labels_value
+
+ squeezed_predictions, squeezed_labels = (
+ metric_ops.remove_squeezable_dimensions(predictions, labels))
+ with self.test_session(g):
+ tf.initialize_local_variables().run()
+ self.assertAllClose(
+ predictions_value, squeezed_predictions.eval(feed_dict=feed_dict))
+ self.assertAllClose(
+ labels_value, squeezed_labels.eval(feed_dict=feed_dict))
+
+
class StreamingMeanTest(tf.test.TestCase):
def setUp(self):
@@ -291,8 +394,9 @@ class StreamingAccuracyTest(tf.test.TestCase):
predictions, labels)
sess.run(tf.initialize_local_variables())
- for _ in range(4):
+ for _ in xrange(3):
sess.run(update_op)
+ self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, accuracy.eval())
def testEffectivelyEquivalentSizes(self):
@@ -336,8 +440,9 @@ class StreamingAccuracyTest(tf.test.TestCase):
predictions, labels, weights)
sess.run(tf.initialize_local_variables())
- for _ in range(4):
+ for _ in xrange(3):
sess.run(update_op)
+ self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, accuracy.eval())
@@ -381,19 +486,6 @@ class StreamingPrecisionTest(tf.test.TestCase):
for _ in range(10):
self.assertEqual(initial_precision, precision.eval())
- def testEffectivelyEquivalentShapes(self):
- inputs = np.random.randint(0, 2, size=(100, 1))
-
- predictions = tf.constant(inputs, shape=(100, 1))
- labels = tf.constant(inputs, shape=(100,))
- precision, update_op = tf.contrib.metrics.streaming_precision(
- predictions, labels)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertAlmostEqual(1, sess.run(update_op))
- self.assertAlmostEqual(1, precision.eval())
-
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
@@ -483,18 +575,6 @@ class StreamingRecallTest(tf.test.TestCase):
for _ in range(10):
self.assertEqual(initial_recall, recall.eval())
- def testEffectivelyEquivalentShapes(self):
- np_inputs = np.random.randint(0, 2, size=(100, 1))
-
- predictions = tf.constant(np_inputs, shape=(100,))
- labels = tf.constant(np_inputs, shape=(100, 1))
- recall, update_op = tf.contrib.metrics.streaming_recall(predictions, labels)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- sess.run(update_op)
- self.assertEqual(1, recall.eval())
-
def testAllCorrect(self):
np_inputs = np.random.randint(0, 2, size=(100, 1))
@@ -580,19 +660,6 @@ class StreamingAUCTest(tf.test.TestCase):
for _ in range(10):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
- def testEffectivelyEquivalentShapes(self):
- inputs = np.random.randint(0, 2, size=(100, 1))
-
- with self.test_session() as sess:
- predictions = tf.constant(inputs, dtype=tf.float32, shape=(100,))
- labels = tf.constant(inputs, shape=(100, 1))
- auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels)
-
- sess.run(tf.initialize_local_variables())
- sess.run(update_op)
-
- self.assertEqual(1, auc.eval())
-
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
@@ -602,7 +669,7 @@ class StreamingAUCTest(tf.test.TestCase):
auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, auc.eval())
@@ -613,7 +680,7 @@ class StreamingAUCTest(tf.test.TestCase):
auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(0.5, sess.run(update_op))
self.assertAlmostEqual(0.5, auc.eval())
@@ -626,7 +693,7 @@ class StreamingAUCTest(tf.test.TestCase):
auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(0, sess.run(update_op))
self.assertAlmostEqual(0, auc.eval())
@@ -637,7 +704,7 @@ class StreamingAUCTest(tf.test.TestCase):
auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(1, sess.run(update_op), 6)
self.assertAlmostEqual(1, auc.eval(), 6)
@@ -769,24 +836,6 @@ class StreamingPrecisionRecallThresholdsTest(tf.test.TestCase):
self.assertAllClose(initial_prec, prec.eval())
self.assertAllClose(initial_rec, rec.eval())
- def testEffectivelyEquivalentShapes(self):
- inputs = np.random.randint(0, 2, size=(100, 1))
-
- with self.test_session() as sess:
- predictions = tf.constant(inputs, dtype=tf.float32, shape=(100,))
- labels = tf.constant(inputs, shape=(100, 1))
- thresholds = [0.5]
- prec, prec_op = tf.contrib.metrics.streaming_precision_at_thresholds(
- predictions, labels, thresholds)
- rec, rec_op = tf.contrib.metrics.streaming_recall_at_thresholds(
- predictions, labels, thresholds)
-
- sess.run(tf.initialize_local_variables())
- sess.run([prec_op, rec_op])
-
- self.assertEqual(1, prec.eval())
- self.assertEqual(1, rec.eval())
-
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
@@ -1015,7 +1064,7 @@ class StreamingRecallAtKTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0.25, sess.run(update_op))
self.assertEqual(0.25, recall.eval())
def testSingleUpdateAllPresentKIs2(self):
@@ -1028,7 +1077,7 @@ class StreamingRecallAtKTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, recall.eval())
def testSingleUpdateAllPresentKIs3(self):
@@ -1041,7 +1090,7 @@ class StreamingRecallAtKTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
def testSingleUpdateSomeMissingKIs2(self):
@@ -1056,7 +1105,7 @@ class StreamingRecallAtKTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -1701,17 +1750,6 @@ class StreamingMeanAbsoluteErrorTest(tf.test.TestCase):
updates_collections=[my_collection_name])
self.assertListEqual(tf.get_collection(my_collection_name), [update_op])
- def testEffectivelyEquivalentShapes(self):
- predictions = tf.ones((10, 3, 1))
- labels = tf.ones((10, 3,))
- error, update_op = tf.contrib.metrics.streaming_mean_absolute_error(
- predictions, labels)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertEqual(0.0, update_op.eval())
- self.assertEqual(0.0, error.eval())
-
def testValueTensorIsIdempotent(self):
predictions = tf.random_normal((10, 3), seed=1)
labels = tf.random_normal((10, 3), seed=2)
@@ -1740,7 +1778,7 @@ class StreamingMeanAbsoluteErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -1786,18 +1824,6 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase):
for _ in range(10):
self.assertEqual(initial_error, error.eval())
- def testEffectivelyEquivalentShapes(self):
- predictions = tf.ones((10, 3, 1))
- labels = tf.ones((10, 3,))
- normalizer = tf.ones((10, 3, 1))
- error, update_op = tf.contrib.metrics.streaming_mean_relative_error(
- predictions, labels, normalizer)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertEqual(0.0, update_op.eval())
- self.assertEqual(0.0, error.eval())
-
def testSingleUpdateNormalizedByLabels(self):
np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32)
@@ -1813,7 +1839,7 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
def testSingleUpdateNormalizedByZeros(self):
@@ -1827,7 +1853,7 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -1852,17 +1878,6 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase):
updates_collections=[my_collection_name])
self.assertListEqual(tf.get_collection(my_collection_name), [update_op])
- def testEffectivelyEquivalentShapes(self):
- predictions = tf.ones((10, 3, 1))
- labels = tf.ones((10, 3,))
- error, update_op = tf.contrib.metrics.streaming_mean_squared_error(
- predictions, labels)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertEqual(0.0, update_op.eval())
- self.assertEqual(0.0, error.eval())
-
def testValueTensorIsIdempotent(self):
predictions = tf.random_normal((10, 3), seed=1)
labels = tf.random_normal((10, 3), seed=2)
@@ -1890,7 +1905,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
def testSingleUpdateWithError(self):
@@ -1902,7 +1917,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
def testSingleUpdateWithErrorAndWeights(self):
@@ -1915,7 +1930,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
@@ -1937,7 +1952,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase):
sess.run(tf.initialize_local_variables())
sess.run(update_op)
- sess.run(update_op)
+ self.assertAlmostEqual(208 / 6.0, sess.run(update_op), 5)
self.assertAlmostEqual(208 / 6.0, error.eval(), 5)
@@ -2028,17 +2043,6 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase):
updates_collections=[my_collection_name])
self.assertListEqual(tf.get_collection(my_collection_name), [update_op])
- def testEffectivelyEquivalentShapes(self):
- predictions = tf.ones((10, 3,))
- labels = tf.ones((10, 3, 1))
- error, update_op = tf.contrib.metrics.streaming_root_mean_squared_error(
- predictions, labels)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertEqual(0.0, update_op.eval())
- self.assertEqual(0.0, error.eval())
-
def testValueTensorIsIdempotent(self):
predictions = tf.random_normal((10, 3), seed=1)
labels = tf.random_normal((10, 3), seed=2)
@@ -2066,7 +2070,7 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase):
predictions, labels)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, rmse.eval())
@@ -2092,7 +2096,7 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase):
predictions, labels, weights)
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(math.sqrt(13), sess.run(update_op))
self.assertAlmostEqual(math.sqrt(13), rmse.eval(), 5)
@@ -2120,17 +2124,6 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase):
updates_collections=[my_collection_name])
self.assertListEqual(tf.get_collection(my_collection_name), [update_op])
- def testEffectivelyEquivalentShapes(self):
- predictions = tf.nn.l2_normalize(tf.ones((10, 3,)), dim=1)
- labels = tf.nn.l2_normalize(tf.ones((10, 3, 1)), dim=1)
- error, update_op = tf.contrib.metrics.streaming_mean_cosine_distance(
- predictions, labels, dim=1)
-
- with self.test_session() as sess:
- sess.run(tf.initialize_local_variables())
- self.assertAlmostEqual(0.0, update_op.eval(), 5)
- self.assertAlmostEqual(0.0, error.eval(), 5)
-
def testValueTensorIsIdempotent(self):
predictions = tf.random_normal((10, 3), seed=1)
labels = tf.random_normal((10, 3), seed=2)
@@ -2162,7 +2155,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
def testSingleUpdateWithError1(self):
@@ -2181,7 +2174,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
def testSingleUpdateWithError2(self):
@@ -2201,7 +2194,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
def testSingleUpdateWithErrorAndMissing1(self):
@@ -2221,7 +2214,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase):
with self.test_session() as sess:
sess.run(tf.initialize_local_variables())
- sess.run(update_op)
+ self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
def testSingleUpdateWithErrorAndMissing2(self):