Diffstat (limited to 'tensorflow/python/kernel_tests/metrics_test.py')
 tensorflow/python/kernel_tests/metrics_test.py | 29 +++++++----------------------
 1 file changed, 7 insertions(+), 22 deletions(-)
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 3358b78efd..e0e752147c 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -3628,7 +3628,8 @@ class MeanPerClassAccuracyTest(test.TestCase):
         predictions=array_ops.ones([10, 1]),
         labels=array_ops.ones([10, 1]),
         num_classes=2)
-    _assert_metric_variables(self, ('mean_accuracy/total_confusion_matrix:0',))
+    _assert_metric_variables(self, ('mean_accuracy/count:0',
+                                    'mean_accuracy/total:0'))

   def testMetricsCollections(self):
     my_collection_name = '__metrics__'
@@ -3797,23 +3798,6 @@ class MeanPerClassAccuracyTest(test.TestCase):
       desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
       self.assertAlmostEqual(desired_output, mean_accuracy.eval())

-  def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
-    predictions = array_ops.concat([
-        constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
-    ], 0)
-    labels = array_ops.concat([
-        constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
-    ], 0)
-    num_classes = 2
-    with self.test_session() as sess:
-      mean_accuracy, update_op = metrics.mean_per_class_accuracy(
-          labels, predictions, num_classes)
-      sess.run(variables.local_variables_initializer())
-      confusion_matrix = update_op.eval()
-      self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
-      desired_mean_accuracy = np.mean([3. / 3., 5. / 7.])
-      self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
-
   def testAllCorrect(self):
     predictions = array_ops.zeros([40])
     labels = array_ops.zeros([40])
@@ -3822,7 +3806,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
           labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
-      self.assertEqual(40, update_op.eval()[0])
+      self.assertEqual(1.0, update_op.eval()[0])
       self.assertEqual(1.0, mean_accuracy.eval())

   def testAllWrong(self):
@@ -3833,7 +3817,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
          labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
-      self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
+      self.assertAllEqual([0.0, 0.0], update_op.eval())
       self.assertEqual(0., mean_accuracy.eval())

   def testResultsWithSomeMissing(self):
@@ -3852,8 +3836,9 @@ class MeanPerClassAccuracyTest(test.TestCase):
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
           labels, predictions, num_classes, weights=weights)
       sess.run(variables.local_variables_initializer())
-      self.assertAllEqual([[2, 0], [2, 4]], update_op.eval())
-      desired_mean_accuracy = np.mean([2. / 2., 4. / 6.])
+      desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
+      self.assertAllEqual(desired_accuracy, update_op.eval())
+      desired_mean_accuracy = np.mean(desired_accuracy)
       self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
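The diff above changes what mean_per_class_accuracy's update_op evaluates to: instead of an accumulated confusion matrix, it now yields the per-class accuracy vector, and the metric variables become mean_accuracy/total and mean_accuracy/count. Below is a minimal standalone sketch of the new behavior, assuming a TF 1.x build (graph mode, tf.metrics API) that includes this change; the input values are chosen to mirror the ratios asserted in testResultsWithSomeMissing, here without weights.

import numpy as np
import tensorflow as tf

# Class 0 has 2 labels, both predicted correctly; class 1 has 6 labels,
# 4 predicted correctly.
labels = tf.constant([0, 0, 1, 1, 1, 1, 1, 1], dtype=tf.int64)
predictions = tf.constant([0, 0, 0, 0, 1, 1, 1, 1], dtype=tf.int64)

mean_accuracy, update_op = tf.metrics.mean_per_class_accuracy(
    labels, predictions, num_classes=2)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  # update_op now evaluates to the per-class accuracy vector rather than
  # a confusion matrix: [2/2, 4/6] = [1.0, 0.6667...].
  per_class = sess.run(update_op)
  print(per_class)
  # mean_accuracy is the unweighted mean over classes: ~0.8333.
  print(sess.run(mean_accuracy))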