diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-10 17:39:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 19:36:00 -0700 |
commit | 4e78d0d41fb8d5d40d088d3ba2cdc531059733e8 (patch) | |
tree | c5eb4b9cae0bcfa2d19c97f107aad178dbee84a2 /tensorflow/python/keras/callbacks.py | |
parent | 4498d9782f413bf146ff3e3ad69ea8d40e3cd0b7 (diff) |
For Tensorboard callback, enable histogram summaries to be computed for
validation data supplied by a generator
PiperOrigin-RevId: 204043732
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index d01c0cd2e2..5d66db232a 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -753,6 +753,7 @@ class TensorBoard(Callback): self.model = model self.sess = K.get_session() + # only make histogram summary op if it hasn't already been made if self.histogram_freq and self.merged is None: for layer in self.model.layers: for weight in layer.weights: @@ -807,20 +808,34 @@ class TensorBoard(Callback): def _fetch_callback(self, summary): self.writer.add_summary( - summary, self._epoch + self._current_batch / self._batches_per_epoch) - self._current_batch += 1 + summary, + self._epoch + self._current_val_batch / self._validation_batches) + self._current_val_batch += 1 + + def on_train_begin(self, logs=None): + """Checks if histogram summaries can be run.""" + + if self.histogram_freq: + if 'validation_steps' in self.params: + self._validation_batches = self.params['validation_steps'] + elif self.validation_data: + self._validation_batches = math.ceil( + self.validation_data[0].shape[0] / self.batch_size) + else: + raise ValueError('If printing histograms, validation data must be ' + 'provided.') + if self._validation_batches == 0: + raise ValueError( + 'If printing histograms, validation data must have length > 0.') def on_epoch_begin(self, epoch, logs=None): """Add histogram op to Model test_function callbacks, reset batch count.""" - if not self.validation_data and self.histogram_freq: - raise ValueError('If printing histograms, validation_data must be ' - 'provided, and cannot be a generator.') + # check if histogram summary should be run for this epoch if self.histogram_freq and epoch % self.histogram_freq == 0: self._epoch = epoch - self._current_batch = 0 - self._batches_per_epoch = math.ceil( - self.validation_data[0].shape[0] / self.batch_size) + self._current_val_batch = 0 + # add the histogram summary op if it should run this epoch if self.merged not in self.model.test_function.fetches: self.model.test_function.fetches.append(self.merged) self.model.test_function.fetch_callbacks[ @@ -831,7 +846,8 @@ class TensorBoard(Callback): logs = logs or {} - if self.histogram_freq and self.histogram_freq > 1: + # pop the histogram summary op after each epoch + if self.histogram_freq: if self.merged in self.model.test_function.fetches: self.model.test_function.fetches.remove(self.merged) if self.merged in self.model.test_function.fetch_callbacks: |