aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/metrics_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/metrics_test.py')
-rw-r--r--tensorflow/python/kernel_tests/metrics_test.py29
1 files changed, 7 insertions, 22 deletions
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 3358b78efd..e0e752147c 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -3628,7 +3628,8 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions=array_ops.ones([10, 1]),
labels=array_ops.ones([10, 1]),
num_classes=2)
- _assert_metric_variables(self, ('mean_accuracy/total_confusion_matrix:0',))
+ _assert_metric_variables(self, ('mean_accuracy/count:0',
+ 'mean_accuracy/total:0'))
def testMetricsCollections(self):
my_collection_name = '__metrics__'
@@ -3797,23 +3798,6 @@ class MeanPerClassAccuracyTest(test.TestCase):
desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
self.assertAlmostEqual(desired_output, mean_accuracy.eval())
- def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
- predictions = array_ops.concat([
- constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
- ], 0)
- labels = array_ops.concat([
- constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
- ], 0)
- num_classes = 2
- with self.test_session() as sess:
- mean_accuracy, update_op = metrics.mean_per_class_accuracy(
- labels, predictions, num_classes)
- sess.run(variables.local_variables_initializer())
- confusion_matrix = update_op.eval()
- self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
- desired_mean_accuracy = np.mean([3. / 3., 5. / 7.])
- self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
-
def testAllCorrect(self):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
@@ -3822,7 +3806,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
- self.assertEqual(40, update_op.eval()[0])
+ self.assertEqual(1.0, update_op.eval()[0])
self.assertEqual(1.0, mean_accuracy.eval())
def testAllWrong(self):
@@ -3833,7 +3817,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
- self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
+ self.assertAllEqual([0.0, 0.0], update_op.eval())
self.assertEqual(0., mean_accuracy.eval())
def testResultsWithSomeMissing(self):
@@ -3852,8 +3836,9 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
- self.assertAllEqual([[2, 0], [2, 4]], update_op.eval())
- desired_mean_accuracy = np.mean([2. / 2., 4. / 6.])
+ desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
+ self.assertAllEqual(desired_accuracy, update_op.eval())
+ desired_mean_accuracy = np.mean(desired_accuracy)
self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())