diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 10:38:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 10:38:58 -0700 |
commit | b5b75662b0a82493f474434b3861006a304eebe2 (patch) | |
tree | 5ef5da55f3916806445607697a80d3ba920914d1 /tensorflow/python/keras | |
parent | 8622f05a62948d8966be8962a6a33e0a8b5a116d (diff) | |
parent | 039ddaa6c0af4be4291383564db5a964d0035c1d (diff) |
Merge pull request #22392 from yanboliang:metrics
PiperOrigin-RevId: 215760505
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/metrics.py | 14 | ||||
-rw-r--r-- | tensorflow/python/keras/metrics_test.py | 26 |
2 files changed, 34 insertions, 6 deletions
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index f4e8419eb0..d217244e2f 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -651,7 +651,9 @@ def categorical_accuracy(y_true, y_pred): @tf_export('keras.metrics.sparse_categorical_accuracy') def sparse_categorical_accuracy(y_true, y_pred): - y_true = math_ops.reduce_max(y_true, axis=-1) + # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))): + y_true = array_ops.squeeze(y_true, [-1]) y_pred = math_ops.argmax(y_pred, axis=-1) # If the expected labels are float, we need to cast the int returned by @@ -670,11 +672,11 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5): @tf_export('keras.metrics.sparse_top_k_categorical_accuracy') def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): - return K.mean( - nn.in_top_k(y_pred, - math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'), - k), - axis=-1) + # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))): + y_true = array_ops.squeeze(y_true, [-1]) + + return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1) # Aliases diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 4195ea18ad..5f5565d4d5 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase): y_pred = K.variable(np.random.random((6, 7))) self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,)) + # Test correctness if the shape of y_true is (num_samples,) + y_true = K.variable([1., 0., 0., 0.]) + y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) + print(K.eval(metric(y_true, y_pred))) + self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) + + # Test correctness if the shape of y_true is (num_samples, 1) + y_true = K.variable([[1.], [0.], [0.], [0.]]) + y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) + print(K.eval(metric(y_true, y_pred))) + self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) + def test_sparse_categorical_accuracy_float(self): with self.cached_session(): metric = metrics.sparse_categorical_accuracy @@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase): def test_sparse_top_k_categorical_accuracy(self): with self.cached_session(): + # Test correctness if the shape of y_true is (num_samples, 1) y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) y_true = K.variable(np.array([[1], [0]])) result = K.eval( @@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase): metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) + # Test correctness if the shape of y_true is (num_samples,) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([1, 0])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + self.assertEqual(result, 1) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + self.assertEqual(result, 0.5) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + self.assertEqual(result, 0.) + def test_top_k_categorical_accuracy(self): with self.cached_session(): y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) |