diff options
Diffstat (limited to 'tensorflow/python/keras/metrics_test.py')
-rw-r--r-- | tensorflow/python/keras/metrics_test.py | 26 |
1 file changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 4195ea18ad..5f5565d4d5 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase): y_pred = K.variable(np.random.random((6, 7))) self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,)) + # Test correctness if the shape of y_true is (num_samples,) + y_true = K.variable([1., 0., 0., 0.]) + y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) + print(K.eval(metric(y_true, y_pred))) + self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) + + # Test correctness if the shape of y_true is (num_samples, 1) + y_true = K.variable([[1.], [0.], [0.], [0.]]) + y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) + print(K.eval(metric(y_true, y_pred))) + self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) + def test_sparse_categorical_accuracy_float(self): with self.cached_session(): metric = metrics.sparse_categorical_accuracy @@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase): def test_sparse_top_k_categorical_accuracy(self): with self.cached_session(): + # Test correctness if the shape of y_true is (num_samples, 1) y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) y_true = K.variable(np.array([[1], [0]])) result = K.eval( @@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase): metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) 
+ # Test correctness if the shape of y_true is (num_samples,) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([1, 0])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + self.assertEqual(result, 1) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + self.assertEqual(result, 0.5) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + self.assertEqual(result, 0.) + def test_top_k_categorical_accuracy(self): with self.cached_session(): y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) |