path: root/tensorflow/python
author     TensorFlower Gardener <gardener@tensorflow.org>  2018-10-04 10:38:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-04 10:38:58 -0700
commit     b5b75662b0a82493f474434b3861006a304eebe2 (patch)
tree       5ef5da55f3916806445607697a80d3ba920914d1 /tensorflow/python
parent     8622f05a62948d8966be8962a6a33e0a8b5a116d (diff)
parent     039ddaa6c0af4be4291383564db5a964d0035c1d (diff)
Merge pull request #22392 from yanboliang:metrics
PiperOrigin-RevId: 215760505
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/keras/metrics.py      | 14
-rw-r--r--  tensorflow/python/keras/metrics_test.py | 26
2 files changed, 34 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index f4e8419eb0..d217244e2f 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -651,7 +651,9 @@ def categorical_accuracy(y_true, y_pred):
@tf_export('keras.metrics.sparse_categorical_accuracy')
def sparse_categorical_accuracy(y_true, y_pred):
- y_true = math_ops.reduce_max(y_true, axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
y_pred = math_ops.argmax(y_pred, axis=-1)
# If the expected labels are float, we need to cast the int returned by
@@ -670,11 +672,11 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
@tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(
- nn.in_top_k(y_pred,
- math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
- k),
- axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
+
+ return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1)
# Aliases
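
For context (not part of the commit): a minimal sketch of how the updated sparse_categorical_accuracy behaves after this change, using the public tf.keras API and the same values as the new test cases below. It assumes a TensorFlow build that includes this merge.

import tensorflow as tf

K = tf.keras.backend

# Labels given as a flat vector of class indices, shape (num_samples,).
y_true_flat = K.variable([1., 0., 0., 0.])
# Labels given as a column vector, shape (num_samples, 1); the metric now
# squeezes the trailing dimension instead of taking reduce_max over it.
y_true_col = K.variable([[1.], [0.], [0.], [0.]])
y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])

# Both label shapes yield the same per-sample accuracies: [0., 1., 1., 1.]
print(K.eval(tf.keras.metrics.sparse_categorical_accuracy(y_true_flat, y_pred)))
print(K.eval(tf.keras.metrics.sparse_categorical_accuracy(y_true_col, y_pred)))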
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 4195ea18ad..5f5565d4d5 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase):
y_pred = K.variable(np.random.random((6, 7)))
self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_true = K.variable([1., 0., 0., 0.])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
+ # Test correctness if the shape of y_true is (num_samples, 1)
+ y_true = K.variable([[1.], [0.], [0.], [0.]])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
def test_sparse_categorical_accuracy_float(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
@@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase):
def test_sparse_top_k_categorical_accuracy(self):
with self.cached_session():
+ # Test correctness if the shape of y_true is (num_samples, 1)
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase):
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([1, 0]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ self.assertEqual(result, 1)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ self.assertEqual(result, 0.5)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ self.assertEqual(result, 0.)
+
def test_top_k_categorical_accuracy(self):
with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
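
Likewise, a small sketch (again not part of the commit) of sparse_top_k_categorical_accuracy with labels of shape (num_samples,), mirroring the expected values in the test hunk above; it assumes the public tf.keras path resolves to this implementation.

import numpy as np
import tensorflow as tf

K = tf.keras.backend

y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([1, 0]))  # shape (num_samples,), no trailing axis

# k=3: every label is within the top 3 of 3 classes -> 1.0
print(K.eval(tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)))
# k=2: only the first sample's label (index 1) is in its top 2 -> 0.5
print(K.eval(tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)))
# k=1: neither label is the argmax of its row -> 0.0
print(K.eval(tf.keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)))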