diff options
author | 2017-06-09 12:26:31 -0700 | |
---|---|---|
committer | 2017-06-09 12:31:52 -0700 | |
commit | 59c4cfee63b1b5d40b337cf3a14b2c8cd9d4c1b1 (patch) | |
tree | 78285a763e95f63ca902103749cdf2db353fb23f | |
parent | 92cfd0105db9bb659e2b85548cbea24452f6aa60 (diff) |
Fix true_negatives and false_negatives to handle unsqueezed dimensions (as is already done for true_positives and false_positives).
PiperOrigin-RevId: 158548096
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 148 | ||||
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 7 |
3 files changed, 96 insertions, 66 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index c2211961df..6c773d9a7f 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -208,9 +208,10 @@ def streaming_true_negatives(predictions, labels, weights=None, with variable_scope.variable_scope( name, 'true_negatives', (predictions, labels, weights)): - predictions = math_ops.cast(predictions, dtype=dtypes.bool) - labels = math_ops.cast(labels, dtype=dtypes.bool) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index f93b1945a6..6496cecfbd 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -663,22 +663,29 @@ class StreamingTruePositivesTest(test.TestCase): _assert_local_variables(self, ('true_positives/count:0',)) def testUnweighted(self): - for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) - tp, tp_update_op = metrics.streaming_true_positives(predictions, labels) + for expand_predictions in [True, False]: + for expand_labels in [True, False]: + for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): + predictions = math_ops.cast(constant_op.constant( + ((1, 0, 1, 0), + (0, 1, 1, 1), + (0, 0, 0, 0))), dtype=dtype) + if expand_predictions: + predictions = array_ops.expand_dims(predictions, 2) + labels = math_ops.cast(constant_op.constant( + ((0, 1, 1, 0), + (1, 0, 0, 0), + (0, 0, 0, 0))), dtype=dtype) + if expand_labels: + labels = array_ops.expand_dims(labels, 2) + tp, tp_update_op = metrics.streaming_true_positives(predictions, + labels) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertEqual(0, tp.eval()) - self.assertEqual(1, tp_update_op.eval()) - self.assertEqual(1, tp.eval()) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(0, tp.eval()) + self.assertEqual(1, tp_update_op.eval()) + self.assertEqual(1, tp.eval()) def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): @@ -712,22 +719,29 @@ class StreamingFalseNegativesTest(test.TestCase): _assert_local_variables(self, ('false_negatives/count:0',)) def testUnweighted(self): - for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) - fn, fn_update_op = metrics.streaming_false_negatives(predictions, labels) + for expand_predictions in [True, False]: + for expand_labels in [True, False]: + for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): + predictions = math_ops.cast(constant_op.constant( + ((1, 0, 1, 0), + (0, 1, 1, 1), + (0, 0, 0, 0))), dtype=dtype) + if expand_predictions: + predictions = array_ops.expand_dims(predictions, 2) + labels = math_ops.cast(constant_op.constant( + ((0, 1, 1, 0), + (1, 0, 0, 0), + (0, 0, 0, 0))), dtype=dtype) + if expand_labels: + labels = array_ops.expand_dims(labels, 2) + fn, fn_update_op = metrics.streaming_false_negatives(predictions, + labels) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertEqual(0, fn.eval()) - self.assertEqual(2, fn_update_op.eval()) - self.assertEqual(2, fn.eval()) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(0, fn.eval()) + self.assertEqual(2, fn_update_op.eval()) + self.assertEqual(2, fn.eval()) def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): @@ -761,22 +775,29 @@ class StreamingFalsePositivesTest(test.TestCase): _assert_local_variables(self, ('false_positives/count:0',)) def testUnweighted(self): - for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) - fp, fp_update_op = metrics.streaming_false_positives(predictions, labels) + for expand_predictions in [True, False]: + for expand_labels in [True, False]: + for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): + predictions = math_ops.cast(constant_op.constant( + ((1, 0, 1, 0), + (0, 1, 1, 1), + (0, 0, 0, 0))), dtype=dtype) + if expand_predictions: + predictions = array_ops.expand_dims(predictions, 2) + labels = math_ops.cast(constant_op.constant( + ((0, 1, 1, 0), + (1, 0, 0, 0), + (0, 0, 0, 0))), dtype=dtype) + if expand_labels: + labels = array_ops.expand_dims(labels, 2) + fp, fp_update_op = metrics.streaming_false_positives(predictions, + labels) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertEqual(0, fp.eval()) - self.assertEqual(4, fp_update_op.eval()) - self.assertEqual(4, fp.eval()) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(0, fp.eval()) + self.assertEqual(4, fp_update_op.eval()) + self.assertEqual(4, fp.eval()) def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): @@ -814,22 +835,29 @@ class StreamingTrueNegativesTest(test.TestCase): _assert_local_variables(self, ('true_negatives/count:0',)) def testUnweighted(self): - for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) - tn, tn_update_op = metrics.streaming_true_negatives(predictions, labels) + for expand_predictions in [True, False]: + for expand_labels in [True, False]: + for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): + predictions = math_ops.cast(constant_op.constant( + ((1, 0, 1, 0), + (0, 1, 1, 1), + (0, 0, 0, 0))), dtype=dtype) + if expand_predictions: + predictions = array_ops.expand_dims(predictions, 2) + labels = math_ops.cast(constant_op.constant( + ((0, 1, 1, 0), + (1, 0, 0, 0), + (0, 0, 0, 0))), dtype=dtype) + if expand_labels: + labels = array_ops.expand_dims(labels, 2) + tn, tn_update_op = metrics.streaming_true_negatives(predictions, + labels) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - self.assertEqual(0, tn.eval()) - self.assertEqual(5, tn_update_op.eval()) - self.assertEqual(5, tn.eval()) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(0, tn.eval()) + self.assertEqual(5, tn_update_op.eval()) + self.assertEqual(5, tn.eval()) def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 46cc1f4dda..fab4c5cb0f 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -1502,9 +1502,10 @@ def false_negatives(labels, predictions, weights=None, with variable_scope.variable_scope( name, 'false_negatives', (predictions, labels, weights)): - labels = math_ops.cast(labels, dtype=dtypes.bool) - predictions = math_ops.cast(predictions, dtype=dtypes.bool) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) is_false_negative = math_ops.logical_and(math_ops.equal(labels, True), math_ops.equal(predictions, False)) return _count_condition(is_false_negative, weights, metrics_collections, |