diff options
Diffstat (limited to 'tensorflow/python/keras/callbacks_test.py')
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 118 |
1 files changed, 111 insertions, 7 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 467bc4cdc4..bb85347033 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -313,6 +313,42 @@ class KerasCallbacksTest(test.TestCase): hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience + def test_EarlyStopping_final_weights_when_restoring_model_weights(self): + + class DummyModel(object): + + def __init__(self): + self.stop_training = False + self.weights = -1 + + def get_weights(self): + return self.weights + + def set_weights(self, weights): + self.weights = weights + + def set_weight_to_epoch(self, epoch): + self.weights = epoch + + early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', + patience=2, + restore_best_weights=True) + early_stop.model = DummyModel() + losses = [0.2, 0.15, 0.1, 0.11, 0.12] + # The best configuration is in the epoch 2 (loss = 0.1000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]}) + if early_stop.model.stop_training: + break + # The best configuration is in epoch 2 (loss = 0.1000), + # and while patience = 2, we're restoring the best weights, + # so we end up at the epoch with the best weights, i.e. epoch 2 + self.assertEqual(early_stop.model.get_weights(), 2) + def test_RemoteMonitor(self): if requests is None: return @@ -534,11 +570,15 @@ class KerasCallbacksTest(test.TestCase): batch_size=BATCH_SIZE, validation_data=(x_test, y_test), callbacks=cbks, - epochs=1, + epochs=2, verbose=0) with open(filepath) as csvfile: - output = ' '.join(csvfile.readlines()) + list_lines = csvfile.readlines() + for line in list_lines: + assert line.count(sep) == 4 + assert len(list_lines) == 5 + output = ' '.join(list_lines) assert len(re.findall('epoch', output)) == 1 os.remove(filepath) @@ -1115,11 +1155,11 @@ class KerasCallbacksTest(test.TestCase): temp_dir = self.get_temp_dir() self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) - tb_cbk = keras.callbacks.TensorBoard(temp_dir) + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch') tb_cbk.writer = FileWriterStub(temp_dir) for batch in range(5): - tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)}) + tb_cbk.on_batch_end(batch, {'acc': batch}) self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4]) self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.]) self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5) @@ -1147,14 +1187,17 @@ class KerasCallbacksTest(test.TestCase): temp_dir = self.get_temp_dir() self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) - tb_cbk = keras.callbacks.TensorBoard(temp_dir) + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch') tb_cbk.writer = FileWriterStub(temp_dir) - tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)}) - tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)}) + tb_cbk.on_batch_end(0, {'acc': 5.0}) batch_step, batch_summary = tb_cbk.writer.batch_summary self.assertEqual(batch_step, 0) self.assertEqual(batch_summary.value[0].simple_value, 5.0) + + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='epoch') + tb_cbk.writer = FileWriterStub(temp_dir) + tb_cbk.on_epoch_end(0, {'acc': 10.0}) epoch_step, epoch_summary = tb_cbk.writer.epoch_summary self.assertEqual(epoch_step, 0) self.assertEqual(epoch_summary.value[0].simple_value, 10.0) @@ -1192,6 +1235,66 @@ class KerasCallbacksTest(test.TestCase): self.assertTrue(os.path.exists(temp_dir)) + def test_TensorBoard_update_freq(self): + + class FileWriterStub(object): + + def __init__(self, logdir, graph=None): + self.logdir = logdir + self.graph = graph + self.batch_summaries = [] + self.epoch_summaries = [] + + def add_summary(self, summary, step): + if 'batch_' in summary.value[0].tag: + self.batch_summaries.append((step, summary)) + elif 'epoch_' in summary.value[0].tag: + self.epoch_summaries.append((step, summary)) + + def flush(self): + pass + + def close(self): + pass + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + + # Epoch mode + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='epoch') + tb_cbk.writer = FileWriterStub(temp_dir) + + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1}) + self.assertEqual(tb_cbk.writer.batch_summaries, []) + tb_cbk.on_epoch_end(0, {'acc': 10.0, 'size': 1}) + self.assertEqual(len(tb_cbk.writer.epoch_summaries), 1) + + # Batch mode + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch') + tb_cbk.writer = FileWriterStub(temp_dir) + + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 1) + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 2) + self.assertFalse(tb_cbk.writer.epoch_summaries) + + # Integer mode + tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq=20) + tb_cbk.writer = FileWriterStub(temp_dir) + + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10}) + self.assertFalse(tb_cbk.writer.batch_summaries) + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 1) + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 1) + tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 2) + tb_cbk.on_batch_end(0, {'acc': 10.0, 'size': 10}) + self.assertEqual(len(tb_cbk.writer.batch_summaries), 2) + self.assertFalse(tb_cbk.writer.epoch_summaries) + def test_RemoteMonitorWithJsonPayload(self): if requests is None: self.skipTest('`requests` required to run this test') @@ -1226,6 +1329,7 @@ class KerasCallbacksTest(test.TestCase): def test_fit_generator_with_callback(self): class TestCallback(keras.callbacks.Callback): + def set_model(self, model): # Check the model operations for the optimizer operations that # the _make_train_function adds under a named scope for the |