aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/metrics_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/metrics_test.py')
-rw-r--r--tensorflow/python/keras/metrics_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 4195ea18ad..5f5565d4d5 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase):
y_pred = K.variable(np.random.random((6, 7)))
self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_true = K.variable([1., 0., 0., 0.])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
+ # Test correctness if the shape of y_true is (num_samples, 1)
+ y_true = K.variable([[1.], [0.], [0.], [0.]])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
def test_sparse_categorical_accuracy_float(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
@@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase):
def test_sparse_top_k_categorical_accuracy(self):
with self.cached_session():
+ # Test correctness if the shape of y_true is (num_samples, 1)
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase):
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([1, 0]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ self.assertEqual(result, 1)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ self.assertEqual(result, 0.5)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ self.assertEqual(result, 0.)
+
def test_top_k_categorical_accuracy(self):
with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))