diff options
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 101 |
1 files changed, 79 insertions, 22 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 3d6000f223..4c12c83a4c 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -24,6 +24,7 @@ from collections import Iterable from collections import OrderedDict import copy import csv +import io import json import math import os @@ -606,24 +607,28 @@ class EarlyStopping(Callback): """Stop training when a monitored quantity has stopped improving. Arguments: - monitor: quantity to be monitored. - min_delta: minimum change in the monitored quantity + monitor: Quantity to be monitored. + min_delta: Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement. - patience: number of epochs with no improvement + patience: Number of epochs with no improvement after which training will be stopped. verbose: verbosity mode. - mode: one of {auto, min, max}. In `min` mode, + mode: One of `{"auto", "min", "max"}`. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. - baseline: baseline value for the monitored quantity. + baseline: Baseline value for the monitored quantity. Training will stop if the model doesn't show improvement over the baseline. + restore_best_weights: Whether to restore model weights from + the epoch with the best value of the monitored quantity. + If False, the model weights obtained at the last step of + training are used. """ def __init__(self, @@ -632,7 +637,8 @@ class EarlyStopping(Callback): patience=0, verbose=0, mode='auto', - baseline=None): + baseline=None, + restore_best_weights=False): super(EarlyStopping, self).__init__() self.monitor = monitor @@ -642,6 +648,8 @@ class EarlyStopping(Callback): self.min_delta = abs(min_delta) self.wait = 0 self.stopped_epoch = 0 + self.restore_best_weights = restore_best_weights + self.best_weights = None if mode not in ['auto', 'min', 'max']: logging.warning('EarlyStopping mode %s is unknown, ' @@ -673,25 +681,37 @@ class EarlyStopping(Callback): self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): - current = logs.get(self.monitor) + current = self.get_monitor_value(logs) if current is None: - logging.warning('Early stopping conditioned on metric `%s` ' - 'which is not available. Available metrics are: %s', - self.monitor, ','.join(list(logs.keys()))) return if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 + if self.restore_best_weights: + self.best_weights = self.model.get_weights() else: self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True + if self.restore_best_weights: + if self.verbose > 0: + print('Restoring model weights from the end of the best epoch.') + self.model.set_weights(self.best_weights) def on_train_end(self, logs=None): if self.stopped_epoch > 0 and self.verbose > 0: print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) + def get_monitor_value(self, logs): + logs = logs or {} + monitor_value = logs.get(self.monitor) + if monitor_value is None: + logging.warning('Early stopping conditioned on metric `%s` ' + 'which is not available. Available metrics are: %s', + self.monitor, ','.join(list(logs.keys()))) + return monitor_value + @tf_export('keras.callbacks.RemoteMonitor') class RemoteMonitor(Callback): @@ -839,6 +859,12 @@ class TensorBoard(Callback): `embeddings_layer_names`. Numpy array (if the model has a single input) or list of Numpy arrays (if the model has multiple inputs). Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding) + update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, + writes the losses and metrics to TensorBoard after each batch. + The same applies for `'epoch'`. If using an integer, let's say `1000`, + the callback will write the metrics and losses to TensorBoard every + 1000 samples. Note that writing too frequently to TensorBoard + can slow down your training. Raises: ValueError: If histogram_freq is set and no validation data is provided. @@ -862,7 +888,8 @@ class TensorBoard(Callback): embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None, - embeddings_data=None): + embeddings_data=None, + update_freq='epoch'): super(TensorBoard, self).__init__() self.log_dir = log_dir self.histogram_freq = histogram_freq @@ -882,6 +909,12 @@ class TensorBoard(Callback): self.embeddings_layer_names = embeddings_layer_names self.embeddings_metadata = embeddings_metadata self.embeddings_data = embeddings_data + if update_freq == 'batch': + self.update_freq = 1 + else: + self.update_freq = update_freq + self._samples_seen = 0 + self._samples_seen_at_last_write = 0 def _init_writer(self): """Sets file writer.""" @@ -1045,13 +1078,17 @@ class TensorBoard(Callback): # use v2 summary ops with self.writer.as_default(), summary_ops_v2.always_record_summaries(): for name, value in logs.items(): - summary_ops_v2.scalar(name, value.item(), step=step) + if isinstance(value, np.ndarray): + value = value.item() + summary_ops_v2.scalar(name, value, step=step) else: # use FileWriter from v1 summary for name, value in logs.items(): + if isinstance(value, np.ndarray): + value = value.item() summary = tf_summary.Summary() summary_value = summary.value.add() - summary_value.simple_value = value.item() + summary_value.simple_value = value summary_value.tag = name self.writer.add_summary(summary, step) self.writer.flush() @@ -1076,10 +1113,14 @@ class TensorBoard(Callback): """Writes scalar summaries for metrics on every training batch.""" # Don't output batch_size and batch number as Tensorboard summaries logs = logs or {} - batch_logs = {('batch_' + k): v - for k, v in logs.items() - if k not in ['batch', 'size', 'num_steps']} - self._write_custom_summaries(self._total_batches_seen, batch_logs) + self._samples_seen += logs.get('size', 1) + samples_seen_since = self._samples_seen - self._samples_seen_at_last_write + if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq: + batch_logs = {('batch_' + k): v + for k, v in logs.items() + if k not in ['batch', 'size', 'num_steps']} + self._write_custom_summaries(self._total_batches_seen, batch_logs) + self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 def on_epoch_begin(self, epoch, logs=None): @@ -1103,7 +1144,11 @@ class TensorBoard(Callback): logs = {('epoch_' + k): v for k, v in logs.items() if k not in ['batch', 'size', 'num_steps']} - self._write_custom_summaries(epoch, logs) + if self.update_freq == 'epoch': + step = epoch + else: + step = self._samples_seen + self._write_custom_summaries(step, logs) # pop the histogram summary op after each epoch if self.histogram_freq: @@ -1309,7 +1354,12 @@ class CSVLogger(Callback): self.writer = None self.keys = None self.append_header = True - self.file_flags = 'b' if six.PY2 and os.name == 'nt' else '' + if six.PY2: + self.file_flags = 'b' + self._open_args = {} + else: + self.file_flags = '' + self._open_args = {'newline': '\n'} super(CSVLogger, self).__init__() def on_train_begin(self, logs=None): @@ -1317,9 +1367,12 @@ class CSVLogger(Callback): if os.path.exists(self.filename): with open(self.filename, 'r' + self.file_flags) as f: self.append_header = not bool(len(f.readline())) - self.csv_file = open(self.filename, 'a' + self.file_flags) + mode = 'a' else: - self.csv_file = open(self.filename, 'w' + self.file_flags) + mode = 'w' + self.csv_file = io.open(self.filename, + mode + self.file_flags, + **self._open_args) def on_epoch_end(self, epoch, logs=None): logs = logs or {} @@ -1345,9 +1398,13 @@ class CSVLogger(Callback): class CustomDialect(csv.excel): delimiter = self.sep + fieldnames = ['epoch'] + self.keys + if six.PY2: + fieldnames = [unicode(x) for x in fieldnames] + self.writer = csv.DictWriter( self.csv_file, - fieldnames=['epoch'] + self.keys, + fieldnames=fieldnames, dialect=CustomDialect) if self.append_header: self.writer.writeheader() |