aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-20 10:23:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 10:26:25 -0700
commit4921064dd535d84aa031f8116e583b151dd46e97 (patch)
treed25482ee5ffc2e79a64a509523ff7078a75b0510
parentc023f46956f8a867d0dc77f1ee742564a3622e68 (diff)
Update Keras TensorBoard callback to log metrics at the batch-level
PiperOrigin-RevId: 205416192
-rw-r--r--tensorflow/python/keras/callbacks.py46
-rw-r--r--tensorflow/python/keras/callbacks_test.py68
2 files changed, 103 insertions, 11 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 0857a3279f..d1b9dc27bd 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -740,6 +740,7 @@ class TensorBoard(Callback):
self.write_images = write_images
self.batch_size = batch_size
self._current_batch = 0
+ self._total_batches_seen = 0
# abstracted writer class to be able to stub for testing
self._writer_class = tf_summary.FileWriter
self.embeddings_freq = embeddings_freq
@@ -883,6 +884,24 @@ class TensorBoard(Callback):
self._epoch + self._current_val_batch / self._validation_batches)
self._current_val_batch += 1
+ def _write_custom_summaries(self, step, logs=None):
+ """Writes metrics out as custom scalar summaries.
+
+ Arguments:
+      step: the global step to use for TensorBoard.
+ logs: dict. Keys are scalar summary names, values are
+ NumPy scalars.
+
+ """
+ logs = logs or {}
+ for name, value in logs.items():
+ summary = tf_summary.Summary()
+ summary_value = summary.value.add()
+ summary_value.simple_value = value.item()
+ summary_value.tag = name
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
def on_train_begin(self, logs=None):
"""Checks if histogram summaries can be run."""
@@ -899,6 +918,16 @@ class TensorBoard(Callback):
raise ValueError(
'If printing histograms, validation data must have length > 0.')
+ def on_batch_end(self, batch, logs=None):
+ """Writes scalar summaries for metrics on every training batch."""
+    # Don't output batch_size and batch number as TensorBoard summaries
+ logs = logs or {}
+ batch_logs = {('batch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(self._total_batches_seen, batch_logs)
+ self._total_batches_seen += 1
+
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model test_function callbacks, reset batch count."""
@@ -915,7 +944,12 @@ class TensorBoard(Callback):
def on_epoch_end(self, epoch, logs=None):
"""Checks if summary ops should run next epoch, logs scalar summaries."""
- logs = logs or {}
+    # Don't output batch_size and
+    # batch number as TensorBoard summaries
+ logs = {('epoch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
if self.histogram_freq:
@@ -964,16 +998,6 @@ class TensorBoard(Callback):
i += self.batch_size
- for name, value in logs.items():
- if name in ['batch', 'size']:
- continue
- summary = tf_summary.Summary()
- summary_value = summary.value.add()
- summary_value.simple_value = value.item()
- summary_value.tag = name
- self.writer.add_summary(summary, epoch)
- self.writer.flush()
-
def on_train_end(self, logs=None):
self.writer.close()
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 45598cafd3..7d830078ce 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -1096,6 +1096,74 @@ class KerasCallbacksTest(test.TestCase):
assert os.path.exists(temp_dir)
+ def test_Tensorboard_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.batches_logged = []
+ self.summary_values = []
+ self.summary_tags = []
+
+ def add_summary(self, summary, step):
+ self.summary_values.append(summary.value[0].simple_value)
+ self.summary_tags.append(summary.value[0].tag)
+ self.batches_logged.append(step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ # log every batch
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ for batch in range(5):
+ tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
+ self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
+ self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
+ self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
+
+ def test_Tensorboard_epoch_and_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+
+ def add_summary(self, summary, step):
+ if 'batch_' in summary.value[0].tag:
+ self.batch_summary = (step, summary)
+ elif 'epoch_' in summary.value[0].tag:
+ self.epoch_summary = (step, summary)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
+ tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
+ batch_step, batch_summary = tb_cbk.writer.batch_summary
+ self.assertEqual(batch_step, 0)
+ self.assertEqual(batch_summary.value[0].simple_value, 5.0)
+ epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
+ self.assertEqual(epoch_step, 0)
+ self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
+
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')