aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops_test.py')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py54
1 files changed, 53 insertions, 1 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f24bec7f11..6e038481e3 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -5856,7 +5856,7 @@ class StreamingMeanIOUTest(test.TestCase):
sess.run(variables.local_variables_initializer())
for _ in range(5):
sess.run(update_op)
- desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0, 0.])
+ desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0])
self.assertAlmostEqual(desired_output, miou.eval())
def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
@@ -5938,6 +5938,58 @@ class StreamingMeanIOUTest(test.TestCase):
desired_miou = np.mean([2. / 4., 4. / 6.])
self.assertAlmostEqual(desired_miou, miou.eval())
+ def testMissingClassInLabels(self):
+ labels = constant_op.constant([
+ [[0, 0, 1, 1, 0, 0],
+ [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant([
+ [[0, 0, 2, 1, 1, 0],
+ [0, 1, 2, 2, 0, 1]],
+ [[0, 0, 2, 1, 1, 1],
+ [1, 1, 2, 0, 0, 0]]])
+ num_classes = 3
+ with self.test_session() as sess:
+ miou, update_op = metrics.streaming_mean_iou(
+ predictions, labels, num_classes)
+ sess.run(variables.local_variables_initializer())
+ self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
+ self.assertAlmostEqual(
+ 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
+ miou.eval())
+
+ def testMissingClassOverallSmall(self):
+ labels = constant_op.constant([0])
+ predictions = constant_op.constant([0])
+ num_classes = 2
+ with self.test_session() as sess:
+ miou, update_op = metrics.streaming_mean_iou(
+ predictions, labels, num_classes)
+ sess.run(variables.local_variables_initializer())
+ self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
+ self.assertAlmostEqual(1, miou.eval())
+
+ def testMissingClassOverallLarge(self):
+ labels = constant_op.constant([
+ [[0, 0, 1, 1, 0, 0],
+ [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant([
+ [[0, 0, 1, 1, 0, 0],
+ [1, 1, 0, 0, 1, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [1, 1, 1, 0, 0, 0]]])
+ num_classes = 3
+ with self.test_session() as sess:
+ miou, update_op = metrics.streaming_mean_iou(
+ predictions, labels, num_classes)
+ sess.run(variables.local_variables_initializer())
+ self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
+ self.assertAlmostEqual(
+ 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval())
+
class StreamingConcatTest(test.TestCase):