aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 17:39:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 19:36:00 -0700
commit4e78d0d41fb8d5d40d088d3ba2cdc531059733e8 (patch)
treec5eb4b9cae0bcfa2d19c97f107aad178dbee84a2 /tensorflow/python/keras/callbacks.py
parent4498d9782f413bf146ff3e3ad69ea8d40e3cd0b7 (diff)
For Tensorboard callback, enable histogram summaries to be computed for
validation data supplied by a generator PiperOrigin-RevId: 204043732
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py34
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index d01c0cd2e2..5d66db232a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -753,6 +753,7 @@ class TensorBoard(Callback):
self.model = model
self.sess = K.get_session()
+ # only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
for weight in layer.weights:
@@ -807,20 +808,34 @@ class TensorBoard(Callback):
def _fetch_callback(self, summary):
self.writer.add_summary(
- summary, self._epoch + self._current_batch / self._batches_per_epoch)
- self._current_batch += 1
+ summary,
+ self._epoch + self._current_val_batch / self._validation_batches)
+ self._current_val_batch += 1
+
+ def on_train_begin(self, logs=None):
+ """Checks if histogram summaries can be run."""
+
+ if self.histogram_freq:
+ if 'validation_steps' in self.params:
+ self._validation_batches = self.params['validation_steps']
+ elif self.validation_data:
+ self._validation_batches = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ else:
+ raise ValueError('If printing histograms, validation data must be '
+ 'provided.')
+ if self._validation_batches == 0:
+ raise ValueError(
+ 'If printing histograms, validation data must have length > 0.')
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model test_function callbacks, reset batch count."""
- if not self.validation_data and self.histogram_freq:
- raise ValueError('If printing histograms, validation_data must be '
- 'provided, and cannot be a generator.')
+ # check if histogram summary should be run for this epoch
if self.histogram_freq and epoch % self.histogram_freq == 0:
self._epoch = epoch
- self._current_batch = 0
- self._batches_per_epoch = math.ceil(
- self.validation_data[0].shape[0] / self.batch_size)
+ self._current_val_batch = 0
+ # add the histogram summary op if it should run this epoch
if self.merged not in self.model.test_function.fetches:
self.model.test_function.fetches.append(self.merged)
self.model.test_function.fetch_callbacks[
@@ -831,7 +846,8 @@ class TensorBoard(Callback):
logs = logs or {}
- if self.histogram_freq and self.histogram_freq > 1:
+ # pop the histogram summary op after each epoch
+ if self.histogram_freq:
if self.merged in self.model.test_function.fetches:
self.model.test_function.fetches.remove(self.merged)
if self.merged in self.model.test_function.fetch_callbacks: