diff options
Diffstat (limited to 'tensorflow/python/keras/metrics_test.py')
-rw-r--r-- | tensorflow/python/keras/metrics_test.py | 285 |
1 files changed, 249 insertions, 36 deletions
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 15e793f5fc..d583379708 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -18,67 +18,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np -from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras import metrics +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils class KerasMetricsTest(test.TestCase): def test_metrics(self): with self.test_session(): - y_a = keras.backend.variable(np.random.random((6, 7))) - y_b = keras.backend.variable(np.random.random((6, 7))) - for metric in [keras.metrics.binary_accuracy, - keras.metrics.categorical_accuracy]: + y_a = K.variable(np.random.random((6, 7))) + y_b = K.variable(np.random.random((6, 7))) + for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]: output = metric(y_a, y_b) - self.assertEqual(keras.backend.eval(output).shape, (6,)) + self.assertEqual(K.eval(output).shape, (6,)) def test_sparse_categorical_accuracy(self): with self.test_session(): - metric = keras.metrics.sparse_categorical_accuracy - y_a = keras.backend.variable(np.random.randint(0, 7, (6,))) - y_b = keras.backend.variable(np.random.random((6, 7))) - self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,)) + metric = metrics.sparse_categorical_accuracy + y_a = K.variable(np.random.randint(0, 7, (6,))) + y_b = K.variable(np.random.random((6, 7))) + self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,)) def test_sparse_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[1], [0]])) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[1], [0]])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]])) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]])) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_stateful_metrics(self): with self.test_session(): np.random.seed(1334) - class BinaryTruePositives(keras.layers.Layer): + class BinaryTruePositives(layers.Layer): """Stateful Metric to count the total true positives over all batches. Assumes predictions and targets of shape `(samples, 1)`. @@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase): def __init__(self, name='true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) - self.true_positives = keras.backend.variable(value=0, dtype='int32') + self.true_positives = K.variable(value=0, dtype='int32') self.stateful = True def reset_states(self): - keras.backend.set_value(self.true_positives, 0) + K.set_value(self.true_positives, 0) def __call__(self, y_true, y_pred): """Computes the number of true positives in a batch. @@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase): return current_true_pos + true_pos metric_fn = BinaryTruePositives() - config = keras.metrics.serialize(metric_fn) - metric_fn = keras.metrics.deserialize( + config = metrics.serialize(metric_fn) + metric_fn = metrics.deserialize( config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) # Test on simple model - inputs = keras.Input(shape=(2,)) - outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) - model = keras.Model(inputs, outputs) + inputs = layers.Input(shape=(2,)) + outputs = layers.Dense(1, activation='sigmoid')(inputs) + model = Model(inputs, outputs) model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['acc', metric_fn]) @@ -184,6 +189,214 @@ class KerasMetricsTest(test.TestCase): self.assertAllClose( val_outs[2], history.history['val_true_positives'][-1], atol=1e-5) + @test_util.run_in_graph_and_eager_modes + def test_mean(self): + m = metrics.Mean(name='my_mean') + + # check config + self.assertEqual(m.name, 'my_mean') + self.assertTrue(m.stateful) + self.assertEqual(m.dtype, dtypes.float32) + self.assertEqual(len(m.variables), 2) + self.evaluate(variables.global_variables_initializer()) + + # check initial state + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + # check __call__() + self.assertEqual(self.evaluate(m(100)), 100) + self.assertEqual(self.evaluate(m.total), 100) + self.assertEqual(self.evaluate(m.count), 1) + + # check update_state() and result() + state accumulation + tensor input + update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) + self.evaluate(update_op) + self.assertAlmostEqual(self.evaluate(m.result()), 106 / 3, 2) + self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 + self.assertEqual(self.evaluate(m.count), 3) + + # check reset_states() + m.reset_states() + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + @test_util.run_in_graph_and_eager_modes + def test_mean_with_sample_weight(self): + m = metrics.Mean(dtype=dtypes.float64) + self.assertEqual(m.dtype, dtypes.float64) + self.evaluate(variables.global_variables_initializer()) + + # check scalar weight + result_t = m(100, sample_weight=0.5) + self.assertEqual(self.evaluate(result_t), 50 / 0.5) + self.assertEqual(self.evaluate(m.total), 50) + self.assertEqual(self.evaluate(m.count), 0.5) + + # check weights not scalar and weights rank matches values rank + result_t = m([1, 5], sample_weight=[1, 0.2]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 52 / 1.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2 + + # check weights broadcast + result_t = m([1, 2], sample_weight=0.5) + self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5 + + # check weights squeeze + result_t = m([1, 5], sample_weight=[[1], [0.2]]) + self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2) + self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2 + + # check weights expand + result_t = m([[1], [5]], sample_weight=[1, 0.2]) + self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2) + self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2 + + def test_mean_graph_with_placeholder(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + v = array_ops.placeholder(dtypes.float32) + w = array_ops.placeholder(dtypes.float32) + sess.run(variables.global_variables_initializer()) + + # check __call__() + result_t = m(v, sample_weight=w) + result = sess.run(result_t, feed_dict=({v: 100, w: 0.5})) + self.assertEqual(sess.run(m.total), 50) + self.assertEqual(sess.run(m.count), 0.5) + self.assertEqual(result, 50 / 0.5) + + # check update_state() and result() + result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]})) + self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2 + self.assertAlmostEqual(result, 52 / 1.7, 2) + + @test_util.run_in_graph_and_eager_modes + def test_save_restore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + m = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=m) + self.evaluate(variables.global_variables_initializer()) + + # update state + self.evaluate(m(100.)) + self.evaluate(m(200.)) + + # save checkpoint and then add an update + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(m(1000.)) + + # restore to the same checkpoint mean object + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(m(300.)) + self.assertEqual(200., self.evaluate(m.result())) + + # restore to a different checkpoint mean object + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertEqual(200., self.evaluate(restore_mean.result())) + self.assertEqual(3, self.evaluate(restore_mean.count)) + + @test_util.run_in_graph_and_eager_modes + def test_binary_accuracy(self): + acc_obj = metrics.BinaryAccuracy(name='my acc') + + # check config + self.assertEqual(acc_obj.name, 'my acc') + self.assertTrue(acc_obj.stateful) + self.assertEqual(len(acc_obj.variables), 2) + self.assertEqual(acc_obj.dtype, dtypes.float32) + self.evaluate(variables.global_variables_initializer()) + + # verify that correct value is returned + update_op = acc_obj.update_state([[1], [0]], [[1], [0]]) + self.evaluate(update_op) + result = self.evaluate(acc_obj.result()) + self.assertEqual(result, 1) # 2/2 + + # check y_pred squeeze + update_op = acc_obj.update_state([[1], [1]], [[[1]], [[0]]]) + self.evaluate(update_op) + result = self.evaluate(acc_obj.result()) + self.assertAlmostEqual(result, 0.75, 2) # 3/4 + + # check y_true squeeze + result_t = acc_obj([[[1]], [[1]]], [[1], [0]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.67, 2) # 4/6 + + # check with sample_weight + result_t = acc_obj([[1], [1]], [[1], [0]], [[0.5], [0.2]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7 + + # check incompatible shapes + with self.assertRaisesRegexp(ValueError, + r'Shapes \(1,\) and \(2,\) are incompatible'): + acc_obj.update_state([1, 1], [1]) + + @test_util.run_in_graph_and_eager_modes + def test_binary_accuracy_threshold(self): + acc_obj = metrics.BinaryAccuracy(threshold=0.7) + self.evaluate(variables.global_variables_initializer()) + result_t = acc_obj([[1], [1], [0], [0]], [[0.9], [0.6], [0.4], [0.8]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.5, 2) + + @test_util.run_in_graph_and_eager_modes + def test_invalid_result(self): + + class InvalidResult(metrics.Metric): + + def __init__(self, name='invalid-result', dtype=dtypes.float64): + super(InvalidResult, self).__init__(name=name, dtype=dtype) + + def update_state(self, *args, **kwargs): + pass + + def result(self): + return 1 + + invalid_result_obj = InvalidResult() + with self.assertRaisesRegexp( + TypeError, + 'Metric invalid-result\'s result must be a Tensor or Operation, given:' + ): + invalid_result_obj.result() + + @test_util.run_in_graph_and_eager_modes + def test_invalid_update(self): + + class InvalidUpdate(metrics.Metric): + + def __init__(self, name='invalid-update', dtype=dtypes.float64): + super(InvalidUpdate, self).__init__(name=name, dtype=dtype) + + def update_state(self, *args, **kwargs): + return [1] + + def result(self): + pass + + invalid_update_obj = InvalidUpdate() + with self.assertRaisesRegexp( + TypeError, + 'Metric invalid-update\'s update must be a Tensor or Operation, given:' + ): + invalid_update_obj.update_state() + if __name__ == '__main__': test.main() |