diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-20 10:23:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 10:26:25 -0700 |
commit | 4921064dd535d84aa031f8116e583b151dd46e97 (patch) | |
tree | d25482ee5ffc2e79a64a509523ff7078a75b0510 | |
parent | c023f46956f8a867d0dc77f1ee742564a3622e68 (diff) |
Update Keras TensorBoard callback to log metrics at the batch-level
PiperOrigin-RevId: 205416192
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 46 | ||||
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 68 |
2 files changed, 103 insertions, 11 deletions
```diff
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 0857a3279f..d1b9dc27bd 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -740,6 +740,7 @@ class TensorBoard(Callback):
     self.write_images = write_images
     self.batch_size = batch_size
     self._current_batch = 0
+    self._total_batches_seen = 0
     # abstracted writer class to be able to stub for testing
     self._writer_class = tf_summary.FileWriter
     self.embeddings_freq = embeddings_freq
@@ -883,6 +884,24 @@ class TensorBoard(Callback):
         self._epoch + self._current_val_batch / self._validation_batches)
     self._current_val_batch += 1
 
+  def _write_custom_summaries(self, step, logs=None):
+    """Writes metrics out as custom scalar summaries.
+
+    Arguments:
+        step: the global step to use for Tensorboard.
+        logs: dict. Keys are scalar summary names, values are
+            NumPy scalars.
+
+    """
+    logs = logs or {}
+    for name, value in logs.items():
+      summary = tf_summary.Summary()
+      summary_value = summary.value.add()
+      summary_value.simple_value = value.item()
+      summary_value.tag = name
+      self.writer.add_summary(summary, step)
+    self.writer.flush()
+
   def on_train_begin(self, logs=None):
     """Checks if histogram summaries can be run."""
 
@@ -899,6 +918,16 @@ class TensorBoard(Callback):
         raise ValueError(
             'If printing histograms, validation data must have length > 0.')
 
+  def on_batch_end(self, batch, logs=None):
+    """Writes scalar summaries for metrics on every training batch."""
+    # Don't output batch_size and batch number as Tensorboard summaries
+    logs = logs or {}
+    batch_logs = {('batch_' + k): v
+                  for k, v in logs.items()
+                  if k not in ['batch', 'size']}
+    self._write_custom_summaries(self._total_batches_seen, batch_logs)
+    self._total_batches_seen += 1
+
   def on_epoch_begin(self, epoch, logs=None):
     """Add histogram op to Model test_function callbacks, reset batch count."""
 
@@ -915,7 +944,12 @@ class TensorBoard(Callback):
   def on_epoch_end(self, epoch, logs=None):
     """Checks if summary ops should run next epoch, logs scalar summaries."""
 
-    logs = logs or {}
+    # don't output batch_size and
+    # batch number as Tensorboard summaries
+    logs = {('epoch_' + k): v
+            for k, v in logs.items()
+            if k not in ['batch', 'size']}
+    self._write_custom_summaries(epoch, logs)
 
     # pop the histogram summary op after each epoch
     if self.histogram_freq:
@@ -964,16 +998,6 @@ class TensorBoard(Callback):
 
           i += self.batch_size
 
-    for name, value in logs.items():
-      if name in ['batch', 'size']:
-        continue
-      summary = tf_summary.Summary()
-      summary_value = summary.value.add()
-      summary_value.simple_value = value.item()
-      summary_value.tag = name
-      self.writer.add_summary(summary, epoch)
-    self.writer.flush()
-
   def on_train_end(self, logs=None):
     self.writer.close()
 
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 45598cafd3..7d830078ce 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -1096,6 +1096,74 @@ class KerasCallbacksTest(test.TestCase):
 
       assert os.path.exists(temp_dir)
 
+  def test_Tensorboard_batch_logging(self):
+
+    class FileWriterStub(object):
+
+      def __init__(self, logdir, graph=None):
+        self.logdir = logdir
+        self.graph = graph
+        self.batches_logged = []
+        self.summary_values = []
+        self.summary_tags = []
+
+      def add_summary(self, summary, step):
+        self.summary_values.append(summary.value[0].simple_value)
+        self.summary_tags.append(summary.value[0].tag)
+        self.batches_logged.append(step)
+
+      def flush(self):
+        pass
+
+      def close(self):
+        pass
+
+    logdir = 'fake_dir'
+
+    # log every batch
+    tb_cbk = keras.callbacks.TensorBoard(logdir)
+    tb_cbk.writer = FileWriterStub(logdir)
+
+    for batch in range(5):
+      tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
+    self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
+    self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
+    self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
+
+  def test_Tensorboard_epoch_and_batch_logging(self):
+
+    class FileWriterStub(object):
+
+      def __init__(self, logdir, graph=None):
+        self.logdir = logdir
+        self.graph = graph
+
+      def add_summary(self, summary, step):
+        if 'batch_' in summary.value[0].tag:
+          self.batch_summary = (step, summary)
+        elif 'epoch_' in summary.value[0].tag:
+          self.epoch_summary = (step, summary)
+
+      def flush(self):
+        pass
+
+      def close(self):
+        pass
+
+    logdir = 'fake_dir'
+
+    tb_cbk = keras.callbacks.TensorBoard(logdir)
+    tb_cbk.writer = FileWriterStub(logdir)
+
+    tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
+    tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
+    batch_step, batch_summary = tb_cbk.writer.batch_summary
+    self.assertEqual(batch_step, 0)
+    self.assertEqual(batch_summary.value[0].simple_value, 5.0)
+    epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
+    self.assertEqual(epoch_step, 0)
+    self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
+
   def test_RemoteMonitorWithJsonPayload(self):
     if requests is None:
       self.skipTest('`requests` required to run this test')
```