aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 17:39:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 19:36:00 -0700
commit4e78d0d41fb8d5d40d088d3ba2cdc531059733e8 (patch)
treec5eb4b9cae0bcfa2d19c97f107aad178dbee84a2
parent4498d9782f413bf146ff3e3ad69ea8d40e3cd0b7 (diff)
For Tensorboard callback, enable histogram summaries to be computed for
validation data supplied by a generator PiperOrigin-RevId: 204043732
-rw-r--r--tensorflow/python/keras/callbacks.py34
-rw-r--r--tensorflow/python/keras/callbacks_test.py65
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py24
-rw-r--r--tensorflow/python/keras/engine/training_eager.py14
-rw-r--r--tensorflow/python/keras/engine/training_generator.py20
5 files changed, 116 insertions, 41 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index d01c0cd2e2..5d66db232a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -753,6 +753,7 @@ class TensorBoard(Callback):
self.model = model
self.sess = K.get_session()
+ # only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
for weight in layer.weights:
@@ -807,20 +808,34 @@ class TensorBoard(Callback):
def _fetch_callback(self, summary):
self.writer.add_summary(
- summary, self._epoch + self._current_batch / self._batches_per_epoch)
- self._current_batch += 1
+ summary,
+ self._epoch + self._current_val_batch / self._validation_batches)
+ self._current_val_batch += 1
+
+ def on_train_begin(self, logs=None):
+ """Checks if histogram summaries can be run."""
+
+ if self.histogram_freq:
+ if 'validation_steps' in self.params:
+ self._validation_batches = self.params['validation_steps']
+ elif self.validation_data:
+ self._validation_batches = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ else:
+ raise ValueError('If printing histograms, validation data must be '
+ 'provided.')
+ if self._validation_batches == 0:
+ raise ValueError(
+ 'If printing histograms, validation data must have length > 0.')
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model test_function callbacks, reset batch count."""
- if not self.validation_data and self.histogram_freq:
- raise ValueError('If printing histograms, validation_data must be '
- 'provided, and cannot be a generator.')
+ # check if histogram summary should be run for this epoch
if self.histogram_freq and epoch % self.histogram_freq == 0:
self._epoch = epoch
- self._current_batch = 0
- self._batches_per_epoch = math.ceil(
- self.validation_data[0].shape[0] / self.batch_size)
+ self._current_val_batch = 0
+ # add the histogram summary op if it should run this epoch
if self.merged not in self.model.test_function.fetches:
self.model.test_function.fetches.append(self.merged)
self.model.test_function.fetch_callbacks[
@@ -831,7 +846,8 @@ class TensorBoard(Callback):
logs = logs or {}
- if self.histogram_freq and self.histogram_freq > 1:
+ # pop the histogram summary op after each epoch
+ if self.histogram_freq:
if self.merged in self.model.test_function.fetches:
self.model.test_function.fetches.remove(self.merged)
if self.merged in self.model.test_function.fetch_callbacks:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 4a5772f402..244d48591c 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -889,21 +889,6 @@ class KerasCallbacksTest(test.TestCase):
for cb in cbs:
cb.on_train_end()
- # fit generator with validation data generator should raise ValueError if
- # histogram_freq > 0
- cbs = callbacks_factory(histogram_freq=1)
- with self.assertRaises(ValueError):
- model.fit_generator(
- data_generator(True),
- len(x_train),
- epochs=2,
- validation_data=data_generator(False),
- validation_steps=1,
- callbacks=cbs)
-
- for cb in cbs:
- cb.on_train_end()
-
# Make sure file writer cache is clear to avoid failures during cleanup.
writer_cache.FileWriterCache.clear()
@@ -1052,6 +1037,56 @@ class KerasCallbacksTest(test.TestCase):
self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5])
+ def test_Tensorboard_histogram_summaries_with_generator(self):
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
+ model.add(keras.layers.Dense(10, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ cbks = [tsb]
+
+ # fit with validation generator
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=cbks,
+ verbose=0)
+
+ with self.assertRaises(ValueError):
+ # fit with validation generator but no
+ # validation_steps
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ callbacks=cbks,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(tmpdir))
+
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e82f5c0332..adefffab11 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -124,12 +124,10 @@ def fit_loop(model,
callback_metrics = copy.copy(out_labels) + [
'val_' + n for n in out_labels
]
- if callbacks is not None and any(
- [isinstance(callback, cbks.TensorBoard) for callback in callbacks]):
- # need to create the test_function before start of the first epoch
- # because TensorBoard callback on_epoch_begin adds summary to the
- # list of fetches of the test_function
- model._make_test_function()
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
else:
callback_metrics = copy.copy(out_labels)
@@ -162,7 +160,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -170,11 +168,17 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
cbk.validation_data = val_ins
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index e8838cd3bc..c78684c9f4 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -989,7 +989,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -997,9 +997,11 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
if not val_inputs:
cbk.validation_data = []
@@ -1009,6 +1011,10 @@ def fit_loop(model,
cbk.validation_data = val_inputs + val_targets + val_sample_weights
else:
cbk.validation_data = val_inputs + val_targets
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index d81b384f0e..432cf2bddd 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -96,14 +96,25 @@ def fit_generator(model,
else:
callback_model = model
callbacks.set_model(callback_model)
- callbacks.set_params({
+
+ callback_params = {
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
- })
- callbacks.on_train_begin()
+ }
+ if do_validation:
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
+ # determine the number of validation batches given a generator
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ elif isinstance(validation_data, Sequence):
+ callback_params.update({'validation_steps': len(validation_data)})
+ callbacks.set_params(callback_params)
enqueuer = None
val_enqueuer = None
@@ -149,6 +160,9 @@ def fit_generator(model,
output_generator = generator
callback_model.stop_training = False
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs: