about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-09 12:26:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-09 12:31:52 -0700
commit59c4cfee63b1b5d40b337cf3a14b2c8cd9d4c1b1 (patch)
tree78285a763e95f63ca902103749cdf2db353fb23f
parent92cfd0105db9bb659e2b85548cbea24452f6aa60 (diff)
Fix true_negatives and false_negatives to handle unsqueezed dimensions (as is already done for true_positives and false_positives).
PiperOrigin-RevId: 158548096
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py7
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py148
-rw-r--r--tensorflow/python/ops/metrics_impl.py7
3 files changed, 96 insertions, 66 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index c2211961df..6c773d9a7f 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -208,9 +208,10 @@ def streaming_true_negatives(predictions, labels, weights=None,
with variable_scope.variable_scope(
name, 'true_negatives', (predictions, labels, weights)):
- predictions = math_ops.cast(predictions, dtype=dtypes.bool)
- labels = math_ops.cast(labels, dtype=dtypes.bool)
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
+ weights=weights)
is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
math_ops.equal(predictions, False))
return _count_condition(is_true_negative, weights, metrics_collections,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f93b1945a6..6496cecfbd 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -663,22 +663,29 @@ class StreamingTruePositivesTest(test.TestCase):
_assert_local_variables(self, ('true_positives/count:0',))
def testUnweighted(self):
- for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
- tp, tp_update_op = metrics.streaming_true_positives(predictions, labels)
+ for expand_predictions in [True, False]:
+ for expand_labels in [True, False]:
+ for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
+ predictions = math_ops.cast(constant_op.constant(
+ ((1, 0, 1, 0),
+ (0, 1, 1, 1),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_predictions:
+ predictions = array_ops.expand_dims(predictions, 2)
+ labels = math_ops.cast(constant_op.constant(
+ ((0, 1, 1, 0),
+ (1, 0, 0, 0),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_labels:
+ labels = array_ops.expand_dims(labels, 2)
+ tp, tp_update_op = metrics.streaming_true_positives(predictions,
+ labels)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- self.assertEqual(0, tp.eval())
- self.assertEqual(1, tp_update_op.eval())
- self.assertEqual(1, tp.eval())
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(0, tp.eval())
+ self.assertEqual(1, tp_update_op.eval())
+ self.assertEqual(1, tp.eval())
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
@@ -712,22 +719,29 @@ class StreamingFalseNegativesTest(test.TestCase):
_assert_local_variables(self, ('false_negatives/count:0',))
def testUnweighted(self):
- for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
- fn, fn_update_op = metrics.streaming_false_negatives(predictions, labels)
+ for expand_predictions in [True, False]:
+ for expand_labels in [True, False]:
+ for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
+ predictions = math_ops.cast(constant_op.constant(
+ ((1, 0, 1, 0),
+ (0, 1, 1, 1),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_predictions:
+ predictions = array_ops.expand_dims(predictions, 2)
+ labels = math_ops.cast(constant_op.constant(
+ ((0, 1, 1, 0),
+ (1, 0, 0, 0),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_labels:
+ labels = array_ops.expand_dims(labels, 2)
+ fn, fn_update_op = metrics.streaming_false_negatives(predictions,
+ labels)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- self.assertEqual(0, fn.eval())
- self.assertEqual(2, fn_update_op.eval())
- self.assertEqual(2, fn.eval())
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(0, fn.eval())
+ self.assertEqual(2, fn_update_op.eval())
+ self.assertEqual(2, fn.eval())
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
@@ -761,22 +775,29 @@ class StreamingFalsePositivesTest(test.TestCase):
_assert_local_variables(self, ('false_positives/count:0',))
def testUnweighted(self):
- for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
- fp, fp_update_op = metrics.streaming_false_positives(predictions, labels)
+ for expand_predictions in [True, False]:
+ for expand_labels in [True, False]:
+ for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
+ predictions = math_ops.cast(constant_op.constant(
+ ((1, 0, 1, 0),
+ (0, 1, 1, 1),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_predictions:
+ predictions = array_ops.expand_dims(predictions, 2)
+ labels = math_ops.cast(constant_op.constant(
+ ((0, 1, 1, 0),
+ (1, 0, 0, 0),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_labels:
+ labels = array_ops.expand_dims(labels, 2)
+ fp, fp_update_op = metrics.streaming_false_positives(predictions,
+ labels)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- self.assertEqual(0, fp.eval())
- self.assertEqual(4, fp_update_op.eval())
- self.assertEqual(4, fp.eval())
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(0, fp.eval())
+ self.assertEqual(4, fp_update_op.eval())
+ self.assertEqual(4, fp.eval())
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
@@ -814,22 +835,29 @@ class StreamingTrueNegativesTest(test.TestCase):
_assert_local_variables(self, ('true_negatives/count:0',))
def testUnweighted(self):
- for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
- tn, tn_update_op = metrics.streaming_true_negatives(predictions, labels)
+ for expand_predictions in [True, False]:
+ for expand_labels in [True, False]:
+ for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
+ predictions = math_ops.cast(constant_op.constant(
+ ((1, 0, 1, 0),
+ (0, 1, 1, 1),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_predictions:
+ predictions = array_ops.expand_dims(predictions, 2)
+ labels = math_ops.cast(constant_op.constant(
+ ((0, 1, 1, 0),
+ (1, 0, 0, 0),
+ (0, 0, 0, 0))), dtype=dtype)
+ if expand_labels:
+ labels = array_ops.expand_dims(labels, 2)
+ tn, tn_update_op = metrics.streaming_true_negatives(predictions,
+ labels)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- self.assertEqual(0, tn.eval())
- self.assertEqual(5, tn_update_op.eval())
- self.assertEqual(5, tn.eval())
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(0, tn.eval())
+ self.assertEqual(5, tn_update_op.eval())
+ self.assertEqual(5, tn.eval())
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 46cc1f4dda..fab4c5cb0f 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -1502,9 +1502,10 @@ def false_negatives(labels, predictions, weights=None,
with variable_scope.variable_scope(
name, 'false_negatives', (predictions, labels, weights)):
- labels = math_ops.cast(labels, dtype=dtypes.bool)
- predictions = math_ops.cast(predictions, dtype=dtypes.bool)
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
+ weights=weights)
is_false_negative = math_ops.logical_and(math_ops.equal(labels, True),
math_ops.equal(predictions, False))
return _count_condition(is_false_negative, weights, metrics_collections,