Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--  tensorflow/python/keras/callbacks.py | 101
1 file changed, 79 insertions(+), 22 deletions(-)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 3d6000f223..4c12c83a4c 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -24,6 +24,7 @@ from collections import Iterable
from collections import OrderedDict
import copy
import csv
+import io
import json
import math
import os
@@ -606,24 +607,28 @@ class EarlyStopping(Callback):
"""Stop training when a monitored quantity has stopped improving.
Arguments:
- monitor: quantity to be monitored.
- min_delta: minimum change in the monitored quantity
+ monitor: Quantity to be monitored.
+ min_delta: Minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta will count as no
improvement.
- patience: number of epochs with no improvement
+ patience: Number of epochs with no improvement
after which training will be stopped.
verbose: Verbosity mode.
- mode: one of {auto, min, max}. In `min` mode,
+ mode: One of `{"auto", "min", "max"}`. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
- baseline: baseline value for the monitored quantity.
+ baseline: Baseline value for the monitored quantity.
Training will stop if the model doesn't show improvement over the
baseline.
+ restore_best_weights: Whether to restore model weights from
+ the epoch with the best value of the monitored quantity.
+ If False, the model weights obtained at the last step of
+ training are used.
"""
def __init__(self,
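
For context, a minimal usage sketch of the new `restore_best_weights`
argument (the model and data names are placeholders, not part of this
diff):

import tensorflow as tf

# Stop once val_loss has failed to improve by at least 1e-3 for 5
# consecutive epochs, then roll back to the best epoch's weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=5,
    restore_best_weights=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=100,
          callbacks=[early_stop])
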
@@ -632,7 +637,8 @@ class EarlyStopping(Callback):
patience=0,
verbose=0,
mode='auto',
- baseline=None):
+ baseline=None,
+ restore_best_weights=False):
super(EarlyStopping, self).__init__()
self.monitor = monitor
@@ -642,6 +648,8 @@ class EarlyStopping(Callback):
self.min_delta = abs(min_delta)
self.wait = 0
self.stopped_epoch = 0
+ self.restore_best_weights = restore_best_weights
+ self.best_weights = None
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
@@ -673,25 +681,37 @@ class EarlyStopping(Callback):
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, epoch, logs=None):
- current = logs.get(self.monitor)
+ current = self.get_monitor_value(logs)
if current is None:
- logging.warning('Early stopping conditioned on metric `%s` '
- 'which is not available. Available metrics are: %s',
- self.monitor, ','.join(list(logs.keys())))
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
+ if self.restore_best_weights:
+ self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
+ if self.restore_best_weights:
+ if self.verbose > 0:
+ print('Restoring model weights from the end of the best epoch.')
+ self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
+ def get_monitor_value(self, logs):
+ logs = logs or {}
+ monitor_value = logs.get(self.monitor)
+ if monitor_value is None:
+ logging.warning('Early stopping conditioned on metric `%s` '
+ 'which is not available. Available metrics are: %s',
+ self.monitor, ','.join(list(logs.keys())))
+ return monitor_value
+
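
The refactored lookup can be exercised on its own; a standalone sketch
of the miss path (simplified, with a hypothetical `monitor` default):

import logging

def get_monitor_value(logs, monitor='val_loss'):
  # Mirrors EarlyStopping.get_monitor_value: warn and return None when
  # the monitored metric is absent from the epoch logs.
  logs = logs or {}
  monitor_value = logs.get(monitor)
  if monitor_value is None:
    logging.warning('Early stopping conditioned on metric `%s` '
                    'which is not available. Available metrics are: %s',
                    monitor, ','.join(list(logs.keys())))
  return monitor_value

print(get_monitor_value({'loss': 0.42}))  # warns, then prints None
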
@tf_export('keras.callbacks.RemoteMonitor')
class RemoteMonitor(Callback):
@@ -839,6 +859,12 @@ class TensorBoard(Callback):
`embeddings_layer_names`. Numpy array (if the model has a single
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
+ update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
+ writes the losses and metrics to TensorBoard after each batch; with
+ `'epoch'`, writes them after each epoch. If using an integer, say
+ `1000`, the callback writes the metrics and losses to TensorBoard
+ every 1000 samples. Note that writing too frequently to TensorBoard
+ can slow down your training.
Raises:
ValueError: If histogram_freq is set and no validation data is provided.
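
A short usage sketch of the new argument (the log directory, model,
and data are placeholders):

import tensorflow as tf

# Write batch-level summaries roughly every 1000 samples; 'batch'
# would write after every batch, 'epoch' (the default) once per epoch.
tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir='./logs',
    update_freq=1000)

model.fit(x_train, y_train, epochs=10, callbacks=[tensorboard])
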
@@ -862,7 +888,8 @@ class TensorBoard(Callback):
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
- embeddings_data=None):
+ embeddings_data=None,
+ update_freq='epoch'):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
@@ -882,6 +909,12 @@ class TensorBoard(Callback):
self.embeddings_layer_names = embeddings_layer_names
self.embeddings_metadata = embeddings_metadata
self.embeddings_data = embeddings_data
+ if update_freq == 'batch':
+ self.update_freq = 1
+ else:
+ self.update_freq = update_freq
+ self._samples_seen = 0
+ self._samples_seen_at_last_write = 0
def _init_writer(self):
"""Sets file writer."""
@@ -1045,13 +1078,17 @@ class TensorBoard(Callback):
# use v2 summary ops
with self.writer.as_default(), summary_ops_v2.always_record_summaries():
for name, value in logs.items():
- summary_ops_v2.scalar(name, value.item(), step=step)
+ if isinstance(value, np.ndarray):
+ value = value.item()
+ summary_ops_v2.scalar(name, value, step=step)
else:
# use FileWriter from v1 summary
for name, value in logs.items():
+ if isinstance(value, np.ndarray):
+ value = value.item()
summary = tf_summary.Summary()
summary_value = summary.value.add()
- summary_value.simple_value = value.item()
+ summary_value.simple_value = value
summary_value.tag = name
self.writer.add_summary(summary, step)
self.writer.flush()
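
The normalization above can be checked in isolation: both a NumPy 0-d
array and a plain Python float come out as a float.

import numpy as np

for value in (np.array(0.25), 0.25):
  if isinstance(value, np.ndarray):
    value = value.item()  # unwrap 0-d arrays to a Python scalar
  assert isinstance(value, float)
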
@@ -1076,10 +1113,14 @@ class TensorBoard(Callback):
"""Writes scalar summaries for metrics on every training batch."""
# Don't output batch_size and batch number as TensorBoard summaries
logs = logs or {}
- batch_logs = {('batch_' + k): v
- for k, v in logs.items()
- if k not in ['batch', 'size', 'num_steps']}
- self._write_custom_summaries(self._total_batches_seen, batch_logs)
+ self._samples_seen += logs.get('size', 1)
+ samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
+ if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
+ batch_logs = {('batch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size', 'num_steps']}
+ self._write_custom_summaries(self._total_batches_seen, batch_logs)
+ self._samples_seen_at_last_write = self._samples_seen
self._total_batches_seen += 1
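
To make the bookkeeping concrete, a standalone sketch (the batch size
and loop bounds are illustrative) of when a write fires for
`update_freq=1000`:

samples_seen = 0
samples_seen_at_last_write = 0
update_freq = 1000

for batch in range(100):
  samples_seen += 32  # stands in for logs.get('size', 1)
  if samples_seen - samples_seen_at_last_write >= update_freq:
    # Fires at batches 31, 63, 95 (1024, 2048, 3072 samples seen).
    print('write at batch %d, %d samples seen' % (batch, samples_seen))
    samples_seen_at_last_write = samples_seen
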
def on_epoch_begin(self, epoch, logs=None):
@@ -1103,7 +1144,11 @@ class TensorBoard(Callback):
logs = {('epoch_' + k): v
for k, v in logs.items()
if k not in ['batch', 'size', 'num_steps']}
- self._write_custom_summaries(epoch, logs)
+ if self.update_freq == 'epoch':
+ step = epoch
+ else:
+ step = self._samples_seen
+ self._write_custom_summaries(step, logs)
# pop the histogram summary op after each epoch
if self.histogram_freq:
@@ -1309,7 +1354,12 @@ class CSVLogger(Callback):
self.writer = None
self.keys = None
self.append_header = True
- self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
+ if six.PY2:
+ self.file_flags = 'b'
+ self._open_args = {}
+ else:
+ self.file_flags = ''
+ self._open_args = {'newline': '\n'}
super(CSVLogger, self).__init__()
def on_train_begin(self, logs=None):
@@ -1317,9 +1367,12 @@ class CSVLogger(Callback):
if os.path.exists(self.filename):
with open(self.filename, 'r' + self.file_flags) as f:
self.append_header = not bool(len(f.readline()))
- self.csv_file = open(self.filename, 'a' + self.file_flags)
+ mode = 'a'
else:
- self.csv_file = open(self.filename, 'w' + self.file_flags)
+ mode = 'w'
+ self.csv_file = io.open(self.filename,
+ mode + self.file_flags,
+ **self._open_args)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
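
A minimal usage sketch of `CSVLogger` exercising the append path above
(the file name, model, and data are placeholders); the callback's
`append` constructor argument controls which open mode is taken:

import tensorflow as tf

# With append=True an existing log is reopened in mode 'a' and the
# header row is skipped, matching the branch in on_train_begin above.
csv_logger = tf.keras.callbacks.CSVLogger('training.log', append=True)

model.fit(x_train, y_train, epochs=10, callbacks=[csv_logger])
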
@@ -1345,9 +1398,13 @@ class CSVLogger(Callback):
class CustomDialect(csv.excel):
delimiter = self.sep
+ fieldnames = ['epoch'] + self.keys
+ if six.PY2:
+ fieldnames = [unicode(x) for x in fieldnames]
+
self.writer = csv.DictWriter(
self.csv_file,
- fieldnames=['epoch'] + self.keys,
+ fieldnames=fieldnames,
dialect=CustomDialect)
if self.append_header:
self.writer.writeheader()