diff options
author | 2016-06-22 19:27:14 -0800 | |
---|---|---|
committer | 2016-06-22 20:33:29 -0700 | |
commit | cf7230b8e6bc5c8ae80d64e8f9eda03cae437da1 (patch) | |
tree | 51502e327d7c67bae9978621db41bff8816c000c /tensorflow/contrib/metrics/python/ops | |
parent | 8ec747f1cd9ceb2c585addc5529c2946c4a09c54 (diff) |
Modify metric_ops.remove_squeezable_dimensions to handle statically unknown ranks.
Add tests to assert all update_op tensors returned from metric_ops produce the proper metric value, and fix streaming_auc to do so.
Change: 125639012
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops')
4 files changed, 201 insertions, 179 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py index d310e55170..73c515d76b 100644 --- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py +++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -25,11 +26,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops -"""Confusion matrix related metrics.""" - - def confusion_matrix(predictions, labels, num_classes=None, name=None): - """Computes the confusion matrix from predictions and labels + """Computes the confusion matrix from predictions and labels. Calculate the Confusion Matrix for a pair of prediction and label 1-D int arrays. @@ -68,9 +66,10 @@ def confusion_matrix(predictions, labels, num_classes=None, name=None): """ with ops.op_scope([predictions, labels, num_classes], name, 'confusion_matrix') as name: - predictions = ops.convert_to_tensor( - predictions, name='predictions', dtype=dtypes.int64) - labels = ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64) + predictions, labels = metric_ops.remove_squeezable_dimensions( + ops.convert_to_tensor( + predictions, name='predictions', dtype=dtypes.int64), + ops.convert_to_tensor(labels, name='labels', dtype=dtypes.int64)) if num_classes is None: num_classes = math_ops.maximum(math_ops.reduce_max(predictions), diff --git a/tensorflow/contrib/metrics/python/ops/histogram_ops.py b/tensorflow/contrib/metrics/python/ops/histogram_ops.py index 60c26149ea..11b9602a2e 100644 --- a/tensorflow/contrib/metrics/python/ops/histogram_ops.py +++ b/tensorflow/contrib/metrics/python/ops/histogram_ops.py @@ -23,6 +23,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -78,6 +79,8 @@ def auc_using_histogram(boolean_labels, collections = [ops.GraphKeys.LOCAL_VARIABLES] with variable_scope.variable_op_scope( [boolean_labels, scores, score_range], name, 'auc_using_histogram'): + scores, boolean_labels = metric_ops.remove_squeezable_dimensions( + scores, boolean_labels) score_range = ops.convert_to_tensor(score_range, name='score_range') boolean_labels, scores = _check_labels_and_scores( boolean_labels, scores, check_shape) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index e9ed4a1799..92ee12dff2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops @@ -75,14 +77,48 @@ def _create_local(name, shape=None, collections=None): collections=collections) -def _remove_squeezable_dimensions(predictions, labels): - predictions_rank = predictions.get_shape().ndims - labels_rank = labels.get_shape().ndims - if not (labels_rank is None or predictions_rank is None): - if labels_rank == (predictions_rank + 1): +def remove_squeezable_dimensions(predictions, labels): + """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + predictions: Predicted values, a `Tensor` of arbitrary dimensions. + labels: Label values, a `Tensor` whose dimensions match `predictions`. + + Returns: + Tuple of `predictions` and `labels`, possibly with last dim squeezed. + """ + predictions = ops.convert_to_tensor(predictions) + labels = ops.convert_to_tensor(labels) + predictions_shape = predictions.get_shape() + predictions_rank = predictions_shape.ndims + labels_shape = labels.get_shape() + labels_rank = labels_shape.ndims + if (labels_rank is not None) and (predictions_rank is not None): + # Use static rank. + rank_diff = predictions_rank - labels_rank + if rank_diff == -1: labels = array_ops.squeeze(labels, [-1]) - elif predictions_rank == (labels_rank + 1): + elif rank_diff == 1: predictions = array_ops.squeeze(predictions, [-1]) + return predictions, labels + + # Use dynamic rank. + rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) + if (predictions_rank is None) or ( + predictions_shape.dims[-1].is_compatible_with(1)): + predictions = control_flow_ops.cond( + math_ops.equal(1, rank_diff), + lambda: array_ops.squeeze(predictions, [-1]), + lambda: predictions) + if (labels_rank is None) or ( + labels_shape.dims[-1].is_compatible_with(1)): + labels = control_flow_ops.cond( + math_ops.equal(-1, rank_diff), + lambda: array_ops.squeeze(labels, [-1]), + lambda: labels) return predictions, labels @@ -357,7 +393,7 @@ def streaming_accuracy(predictions, labels, weights=None, if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) return streaming_mean(is_correct, weights, metrics_collections, @@ -412,7 +448,7 @@ def streaming_precision(predictions, labels, ignore_mask=None, with variable_scope.variable_op_scope( [predictions, labels], name, 'precision'): - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) true_positives, true_positives_update_op = _streaming_true_positives( @@ -489,7 +525,7 @@ def streaming_recall(predictions, labels, ignore_mask=None, or tuple. """ with variable_scope.variable_op_scope([predictions, labels], name, 'recall'): - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) true_positives, true_positives_update_op = _streaming_true_positives( @@ -567,7 +603,7 @@ def _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask=None): or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) num_thresholds = len(thresholds) @@ -712,33 +748,24 @@ def streaming_auc(predictions, labels, ignore_mask=None, num_thresholds=200, for i in range(num_thresholds-2)] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] - (true_positives, false_negatives, true_negatives, false_positives, - true_positives_compute_op, false_negatives_compute_op, - true_negatives_compute_op, false_positives_compute_op) = _tp_fn_tn_fp( - predictions, labels, thresholds, ignore_mask) + (tp, fn, tn, fp, tp_update_op, fn_update_op, tn_update_op, + fp_update_op) = _tp_fn_tn_fp(predictions, labels, thresholds, ignore_mask) - epsilon = 1.0e-6 - assert array_ops.squeeze( - false_positives).get_shape().as_list()[0] == num_thresholds # Add epsilons to avoid dividing by 0. - false_positive_rate = math_ops.div( - false_positives, - false_positives + true_negatives + epsilon) - recall = math_ops.div(true_positives + epsilon, - true_positives + false_negatives + epsilon) + epsilon = 1.0e-6 + assert array_ops.squeeze(fp).get_shape().as_list()[0] == num_thresholds - def compute_auc(name): + def compute_auc(tp, fn, tn, fp, name): + fp_rate = math_ops.div(fp, fp + tn + epsilon) + recall = math_ops.div(tp + epsilon, tp + fn + epsilon) return math_ops.reduce_sum(math_ops.mul( - false_positive_rate[:num_thresholds - 1] - false_positive_rate[1:], + fp_rate[:num_thresholds - 1] - fp_rate[1:], (recall[:num_thresholds - 1] + recall[1:]) / 2.), name=name) # sum up the areas of all the trapeziums - auc = compute_auc('value') - with ops.control_dependencies([true_positives_compute_op, - false_negatives_compute_op, - true_negatives_compute_op, - false_positives_compute_op]): - update_op = compute_auc('update_op') + auc = compute_auc(tp, fn, tn, fp, 'value') + update_op = compute_auc( + tp_update_op, fn_update_op, tn_update_op, fp_update_op, 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, auc) @@ -1367,7 +1394,7 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) absolute_errors = math_ops.abs(predictions - labels) return streaming_mean(absolute_errors, weights, metrics_collections, @@ -1419,10 +1446,10 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions, normalizer = _remove_squeezable_dimensions( + predictions, normalizer = remove_squeezable_dimensions( predictions, normalizer) predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) relative_errors = math_ops.select( @@ -1477,7 +1504,7 @@ def streaming_mean_squared_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) squared_error = math_ops.square(labels - predictions) return streaming_mean(squared_error, weights, metrics_collections, @@ -1528,7 +1555,7 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, `predictions` or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) value_tensor, update_op = streaming_mean_squared_error( predictions, labels, weights, None, None, @@ -1593,7 +1620,7 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, ignore_mask is of the wrong size or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels = _remove_squeezable_dimensions(predictions, labels) + predictions, labels = remove_squeezable_dimensions(predictions, labels) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.mul(predictions, labels) radial_diffs = math_ops.reduce_sum(radial_diffs, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 93352940f9..ea2adc2fbf 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,7 +24,7 @@ import math import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf - +from tensorflow.contrib.metrics.python.ops import metric_ops NAN = float('nan') @@ -136,6 +137,108 @@ def _binary_3d_label_to_sparse(labels): tf.constant(v.shape, tf.int64)) +class RemoveSqueezableDimensionsTest(tf.test.TestCase): + + def testRemoveSqueezableDimensions(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=False, + labels_have_static_shape=False, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_extraLabelDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=False, + labels_have_static_shape=False, labels_have_extra_dim=True) + + def testRemoveSqueezableDimensions_staticLabel(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=False, + labels_have_static_shape=True, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_staticLabel_extraLabelDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=False, + labels_have_static_shape=True, labels_have_extra_dim=True) + + def testRemoveSqueezableDimensions_extraPredictionDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=True, + labels_have_static_shape=False, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_extraPredictionDim_staticLabel(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=False, predictions_have_extra_dim=True, + labels_have_static_shape=True, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_staticPrediction(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=False, + labels_have_static_shape=False, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_staticPrediction_extraLabelDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=False, + labels_have_static_shape=False, labels_have_extra_dim=True) + + def testRemoveSqueezableDimensions_static(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=False, + labels_have_static_shape=True, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_static_extraLabelDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=False, + labels_have_static_shape=True, labels_have_extra_dim=True) + + def testRemoveSqueezableDimensions_staticPrediction_extraPredictionDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=True, + labels_have_static_shape=False, labels_have_extra_dim=False) + + def testRemoveSqueezableDimensions_static_extraPredictionDim(self): + self._testRemoveSqueezableDimensions( + predictions_have_static_shape=True, predictions_have_extra_dim=True, + labels_have_static_shape=True, labels_have_extra_dim=False) + + # TODO(ptucker): Replace this with parameterized test. + def _testRemoveSqueezableDimensions( + self, + predictions_have_static_shape, + predictions_have_extra_dim, + labels_have_static_shape, + labels_have_extra_dim): + assert not (predictions_have_extra_dim and labels_have_extra_dim) + predictions_value = (0, 1, 1, 0, 0, 1, 0) + labels_value = (0, 0, 1, 1, 0, 0, 0) + + input_predictions_value = ( + [[p] for p in predictions_value] if predictions_have_extra_dim else + predictions_value) + input_labels_value = ( + [[l] for l in labels_value] if labels_have_extra_dim else labels_value) + + with tf.Graph().as_default() as g: + feed_dict = {} + if predictions_have_static_shape: + predictions = tf.constant(input_predictions_value, dtype=tf.int32) + else: + predictions = tf.placeholder(dtype=tf.int32, name='predictions') + feed_dict[predictions] = input_predictions_value + if labels_have_static_shape: + labels = tf.constant(input_labels_value, dtype=tf.int32) + else: + labels = tf.placeholder(dtype=tf.int32, name='labels') + feed_dict[labels] = input_labels_value + + squeezed_predictions, squeezed_labels = ( + metric_ops.remove_squeezable_dimensions(predictions, labels)) + with self.test_session(g): + tf.initialize_local_variables().run() + self.assertAllClose( + predictions_value, squeezed_predictions.eval(feed_dict=feed_dict)) + self.assertAllClose( + labels_value, squeezed_labels.eval(feed_dict=feed_dict)) + + class StreamingMeanTest(tf.test.TestCase): def setUp(self): @@ -291,8 +394,9 @@ class StreamingAccuracyTest(tf.test.TestCase): predictions, labels) sess.run(tf.initialize_local_variables()) - for _ in range(4): + for _ in xrange(3): sess.run(update_op) + self.assertEqual(0.5, sess.run(update_op)) self.assertEqual(0.5, accuracy.eval()) def testEffectivelyEquivalentSizes(self): @@ -336,8 +440,9 @@ class StreamingAccuracyTest(tf.test.TestCase): predictions, labels, weights) sess.run(tf.initialize_local_variables()) - for _ in range(4): + for _ in xrange(3): sess.run(update_op) + self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, accuracy.eval()) @@ -381,19 +486,6 @@ class StreamingPrecisionTest(tf.test.TestCase): for _ in range(10): self.assertEqual(initial_precision, precision.eval()) - def testEffectivelyEquivalentShapes(self): - inputs = np.random.randint(0, 2, size=(100, 1)) - - predictions = tf.constant(inputs, shape=(100, 1)) - labels = tf.constant(inputs, shape=(100,)) - precision, update_op = tf.contrib.metrics.streaming_precision( - predictions, labels) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertAlmostEqual(1, sess.run(update_op)) - self.assertAlmostEqual(1, precision.eval()) - def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -483,18 +575,6 @@ class StreamingRecallTest(tf.test.TestCase): for _ in range(10): self.assertEqual(initial_recall, recall.eval()) - def testEffectivelyEquivalentShapes(self): - np_inputs = np.random.randint(0, 2, size=(100, 1)) - - predictions = tf.constant(np_inputs, shape=(100,)) - labels = tf.constant(np_inputs, shape=(100, 1)) - recall, update_op = tf.contrib.metrics.streaming_recall(predictions, labels) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - sess.run(update_op) - self.assertEqual(1, recall.eval()) - def testAllCorrect(self): np_inputs = np.random.randint(0, 2, size=(100, 1)) @@ -580,19 +660,6 @@ class StreamingAUCTest(tf.test.TestCase): for _ in range(10): self.assertAlmostEqual(initial_auc, auc.eval(), 5) - def testEffectivelyEquivalentShapes(self): - inputs = np.random.randint(0, 2, size=(100, 1)) - - with self.test_session() as sess: - predictions = tf.constant(inputs, dtype=tf.float32, shape=(100,)) - labels = tf.constant(inputs, shape=(100, 1)) - auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels) - - sess.run(tf.initialize_local_variables()) - sess.run(update_op) - - self.assertEqual(1, auc.eval()) - def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -602,7 +669,7 @@ class StreamingAUCTest(tf.test.TestCase): auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(1, sess.run(update_op)) self.assertEqual(1, auc.eval()) @@ -613,7 +680,7 @@ class StreamingAUCTest(tf.test.TestCase): auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(0.5, sess.run(update_op)) self.assertAlmostEqual(0.5, auc.eval()) @@ -626,7 +693,7 @@ class StreamingAUCTest(tf.test.TestCase): auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(0, sess.run(update_op)) self.assertAlmostEqual(0, auc.eval()) @@ -637,7 +704,7 @@ class StreamingAUCTest(tf.test.TestCase): auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(1, sess.run(update_op), 6) self.assertAlmostEqual(1, auc.eval(), 6) @@ -769,24 +836,6 @@ class StreamingPrecisionRecallThresholdsTest(tf.test.TestCase): self.assertAllClose(initial_prec, prec.eval()) self.assertAllClose(initial_rec, rec.eval()) - def testEffectivelyEquivalentShapes(self): - inputs = np.random.randint(0, 2, size=(100, 1)) - - with self.test_session() as sess: - predictions = tf.constant(inputs, dtype=tf.float32, shape=(100,)) - labels = tf.constant(inputs, shape=(100, 1)) - thresholds = [0.5] - prec, prec_op = tf.contrib.metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds) - rec, rec_op = tf.contrib.metrics.streaming_recall_at_thresholds( - predictions, labels, thresholds) - - sess.run(tf.initialize_local_variables()) - sess.run([prec_op, rec_op]) - - self.assertEqual(1, prec.eval()) - self.assertEqual(1, rec.eval()) - def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1015,7 +1064,7 @@ class StreamingRecallAtKTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0.25, sess.run(update_op)) self.assertEqual(0.25, recall.eval()) def testSingleUpdateAllPresentKIs2(self): @@ -1028,7 +1077,7 @@ class StreamingRecallAtKTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0.5, sess.run(update_op)) self.assertEqual(0.5, recall.eval()) def testSingleUpdateAllPresentKIs3(self): @@ -1041,7 +1090,7 @@ class StreamingRecallAtKTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) def testSingleUpdateSomeMissingKIs2(self): @@ -1056,7 +1105,7 @@ class StreamingRecallAtKTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(1.0, sess.run(update_op)) self.assertEqual(1.0, recall.eval()) @@ -1701,17 +1750,6 @@ class StreamingMeanAbsoluteErrorTest(tf.test.TestCase): updates_collections=[my_collection_name]) self.assertListEqual(tf.get_collection(my_collection_name), [update_op]) - def testEffectivelyEquivalentShapes(self): - predictions = tf.ones((10, 3, 1)) - labels = tf.ones((10, 3,)) - error, update_op = tf.contrib.metrics.streaming_mean_absolute_error( - predictions, labels) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertEqual(0.0, update_op.eval()) - self.assertEqual(0.0, error.eval()) - def testValueTensorIsIdempotent(self): predictions = tf.random_normal((10, 3), seed=1) labels = tf.random_normal((10, 3), seed=2) @@ -1740,7 +1778,7 @@ class StreamingMeanAbsoluteErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(3, sess.run(update_op)) self.assertEqual(3, error.eval()) @@ -1786,18 +1824,6 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase): for _ in range(10): self.assertEqual(initial_error, error.eval()) - def testEffectivelyEquivalentShapes(self): - predictions = tf.ones((10, 3, 1)) - labels = tf.ones((10, 3,)) - normalizer = tf.ones((10, 3, 1)) - error, update_op = tf.contrib.metrics.streaming_mean_relative_error( - predictions, labels, normalizer) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertEqual(0.0, update_op.eval()) - self.assertEqual(0.0, error.eval()) - def testSingleUpdateNormalizedByLabels(self): np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32) np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32) @@ -1813,7 +1839,7 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(expected_error, sess.run(update_op)) self.assertEqual(expected_error, error.eval()) def testSingleUpdateNormalizedByZeros(self): @@ -1827,7 +1853,7 @@ class StreamingMeanRelativeErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0.0, sess.run(update_op)) self.assertEqual(0.0, error.eval()) @@ -1852,17 +1878,6 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase): updates_collections=[my_collection_name]) self.assertListEqual(tf.get_collection(my_collection_name), [update_op]) - def testEffectivelyEquivalentShapes(self): - predictions = tf.ones((10, 3, 1)) - labels = tf.ones((10, 3,)) - error, update_op = tf.contrib.metrics.streaming_mean_squared_error( - predictions, labels) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertEqual(0.0, update_op.eval()) - self.assertEqual(0.0, error.eval()) - def testValueTensorIsIdempotent(self): predictions = tf.random_normal((10, 3), seed=1) labels = tf.random_normal((10, 3), seed=2) @@ -1890,7 +1905,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) def testSingleUpdateWithError(self): @@ -1902,7 +1917,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(6, sess.run(update_op)) self.assertEqual(6, error.eval()) def testSingleUpdateWithErrorAndWeights(self): @@ -1915,7 +1930,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(13, sess.run(update_op)) self.assertEqual(13, error.eval()) def testMultipleBatchesOfSizeOne(self): @@ -1937,7 +1952,7 @@ class StreamingMeanSquaredErrorTest(tf.test.TestCase): sess.run(tf.initialize_local_variables()) sess.run(update_op) - sess.run(update_op) + self.assertAlmostEqual(208 / 6.0, sess.run(update_op), 5) self.assertAlmostEqual(208 / 6.0, error.eval(), 5) @@ -2028,17 +2043,6 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase): updates_collections=[my_collection_name]) self.assertListEqual(tf.get_collection(my_collection_name), [update_op]) - def testEffectivelyEquivalentShapes(self): - predictions = tf.ones((10, 3,)) - labels = tf.ones((10, 3, 1)) - error, update_op = tf.contrib.metrics.streaming_root_mean_squared_error( - predictions, labels) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertEqual(0.0, update_op.eval()) - self.assertEqual(0.0, error.eval()) - def testValueTensorIsIdempotent(self): predictions = tf.random_normal((10, 3), seed=1) labels = tf.random_normal((10, 3), seed=2) @@ -2066,7 +2070,7 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase): predictions, labels) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, rmse.eval()) @@ -2092,7 +2096,7 @@ class StreamingRootMeanSquaredErrorTest(tf.test.TestCase): predictions, labels, weights) sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(math.sqrt(13), sess.run(update_op)) self.assertAlmostEqual(math.sqrt(13), rmse.eval(), 5) @@ -2120,17 +2124,6 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase): updates_collections=[my_collection_name]) self.assertListEqual(tf.get_collection(my_collection_name), [update_op]) - def testEffectivelyEquivalentShapes(self): - predictions = tf.nn.l2_normalize(tf.ones((10, 3,)), dim=1) - labels = tf.nn.l2_normalize(tf.ones((10, 3, 1)), dim=1) - error, update_op = tf.contrib.metrics.streaming_mean_cosine_distance( - predictions, labels, dim=1) - - with self.test_session() as sess: - sess.run(tf.initialize_local_variables()) - self.assertAlmostEqual(0.0, update_op.eval(), 5) - self.assertAlmostEqual(0.0, error.eval(), 5) - def testValueTensorIsIdempotent(self): predictions = tf.random_normal((10, 3), seed=1) labels = tf.random_normal((10, 3), seed=2) @@ -2162,7 +2155,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) def testSingleUpdateWithError1(self): @@ -2181,7 +2174,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(1, sess.run(update_op), 5) self.assertAlmostEqual(1, error.eval(), 5) def testSingleUpdateWithError2(self): @@ -2201,7 +2194,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertAlmostEqual(1.0, sess.run(update_op), 5) self.assertAlmostEqual(1.0, error.eval(), 5) def testSingleUpdateWithErrorAndMissing1(self): @@ -2221,7 +2214,7 @@ class StreamingMeanCosineDistanceTest(tf.test.TestCase): with self.test_session() as sess: sess.run(tf.initialize_local_variables()) - sess.run(update_op) + self.assertEqual(0, sess.run(update_op)) self.assertEqual(0, error.eval()) def testSingleUpdateWithErrorAndMissing2(self): |