aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/callbacks_test.py')
-rw-r--r--tensorflow/python/keras/callbacks_test.py118
1 files changed, 111 insertions, 7 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 467bc4cdc4..bb85347033 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -313,6 +313,42 @@ class KerasCallbacksTest(test.TestCase):
hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
assert len(hist.epoch) >= patience
+ def test_EarlyStopping_final_weights_when_restoring_model_weights(self):
+
+ class DummyModel(object):
+
+ def __init__(self):
+ self.stop_training = False
+ self.weights = -1
+
+ def get_weights(self):
+ return self.weights
+
+ def set_weights(self, weights):
+ self.weights = weights
+
+ def set_weight_to_epoch(self, epoch):
+ self.weights = epoch
+
+ early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',
+ patience=2,
+ restore_best_weights=True)
+ early_stop.model = DummyModel()
+ losses = [0.2, 0.15, 0.1, 0.11, 0.12]
+ # The best configuration is in the epoch 2 (loss = 0.1000).
+ epochs_trained = 0
+ early_stop.on_train_begin()
+ for epoch in range(len(losses)):
+ epochs_trained += 1
+ early_stop.model.set_weight_to_epoch(epoch=epoch)
+ early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
+ if early_stop.model.stop_training:
+ break
+ # The best configuration is in epoch 2 (loss = 0.1000),
+ # and while patience = 2, we're restoring the best weights,
+ # so we end up at the epoch with the best weights, i.e. epoch 2
+ self.assertEqual(early_stop.model.get_weights(), 2)
+
def test_RemoteMonitor(self):
if requests is None:
return
@@ -534,11 +570,15 @@ class KerasCallbacksTest(test.TestCase):
batch_size=BATCH_SIZE,
validation_data=(x_test, y_test),
callbacks=cbks,
- epochs=1,
+ epochs=2,
verbose=0)
with open(filepath) as csvfile:
- output = ' '.join(csvfile.readlines())
+ list_lines = csvfile.readlines()
+ for line in list_lines:
+ assert line.count(sep) == 4
+ assert len(list_lines) == 5
+ output = ' '.join(list_lines)
assert len(re.findall('epoch', output)) == 1
os.remove(filepath)
@@ -1115,11 +1155,11 @@ class KerasCallbacksTest(test.TestCase):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- tb_cbk = keras.callbacks.TensorBoard(temp_dir)
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch')
tb_cbk.writer = FileWriterStub(temp_dir)
for batch in range(5):
- tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
+ tb_cbk.on_batch_end(batch, {'acc': batch})
self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
@@ -1147,14 +1187,17 @@ class KerasCallbacksTest(test.TestCase):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- tb_cbk = keras.callbacks.TensorBoard(temp_dir)
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch')
tb_cbk.writer = FileWriterStub(temp_dir)
- tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
- tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
+ tb_cbk.on_batch_end(0, {'acc': 5.0})
batch_step, batch_summary = tb_cbk.writer.batch_summary
self.assertEqual(batch_step, 0)
self.assertEqual(batch_summary.value[0].simple_value, 5.0)
+
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='epoch')
+ tb_cbk.writer = FileWriterStub(temp_dir)
+ tb_cbk.on_epoch_end(0, {'acc': 10.0})
epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
self.assertEqual(epoch_step, 0)
self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
@@ -1192,6 +1235,66 @@ class KerasCallbacksTest(test.TestCase):
self.assertTrue(os.path.exists(temp_dir))
+ def test_TensorBoard_update_freq(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.batch_summaries = []
+ self.epoch_summaries = []
+
+ def add_summary(self, summary, step):
+ if 'batch_' in summary.value[0].tag:
+ self.batch_summaries.append((step, summary))
+ elif 'epoch_' in summary.value[0].tag:
+ self.epoch_summaries.append((step, summary))
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+
+ # Epoch mode
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='epoch')
+ tb_cbk.writer = FileWriterStub(temp_dir)
+
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
+ self.assertEqual(tb_cbk.writer.batch_summaries, [])
+ tb_cbk.on_epoch_end(0, {'acc': 10.0, 'size': 1})
+ self.assertEqual(len(tb_cbk.writer.epoch_summaries), 1)
+
+ # Batch mode
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq='batch')
+ tb_cbk.writer = FileWriterStub(temp_dir)
+
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
+ self.assertFalse(tb_cbk.writer.epoch_summaries)
+
+ # Integer mode
+ tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq=20)
+ tb_cbk.writer = FileWriterStub(temp_dir)
+
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
+ self.assertFalse(tb_cbk.writer.batch_summaries)
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
+ tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
+ tb_cbk.on_batch_end(0, {'acc': 10.0, 'size': 10})
+ self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
+ self.assertFalse(tb_cbk.writer.epoch_summaries)
+
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')
@@ -1226,6 +1329,7 @@ class KerasCallbacksTest(test.TestCase):
def test_fit_generator_with_callback(self):
class TestCallback(keras.callbacks.Callback):
+
def set_model(self, model):
# Check the model operations for the optimizer operations that
# the _make_train_function adds under a named scope for the