diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-10 17:39:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 19:36:00 -0700 |
commit | 4e78d0d41fb8d5d40d088d3ba2cdc531059733e8 (patch) | |
tree | c5eb4b9cae0bcfa2d19c97f107aad178dbee84a2 | |
parent | 4498d9782f413bf146ff3e3ad69ea8d40e3cd0b7 (diff) |
For Tensorboard callback, enable histogram summaries to be computed for
validation data supplied by a generator
PiperOrigin-RevId: 204043732
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 34 | ||||
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 65 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 24 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_eager.py | 14 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_generator.py | 20 |
5 files changed, 116 insertions, 41 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index d01c0cd2e2..5d66db232a 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -753,6 +753,7 @@ class TensorBoard(Callback): self.model = model self.sess = K.get_session() + # only make histogram summary op if it hasn't already been made if self.histogram_freq and self.merged is None: for layer in self.model.layers: for weight in layer.weights: @@ -807,20 +808,34 @@ class TensorBoard(Callback): def _fetch_callback(self, summary): self.writer.add_summary( - summary, self._epoch + self._current_batch / self._batches_per_epoch) - self._current_batch += 1 + summary, + self._epoch + self._current_val_batch / self._validation_batches) + self._current_val_batch += 1 + + def on_train_begin(self, logs=None): + """Checks if histogram summaries can be run.""" + + if self.histogram_freq: + if 'validation_steps' in self.params: + self._validation_batches = self.params['validation_steps'] + elif self.validation_data: + self._validation_batches = math.ceil( + self.validation_data[0].shape[0] / self.batch_size) + else: + raise ValueError('If printing histograms, validation data must be ' + 'provided.') + if self._validation_batches == 0: + raise ValueError( + 'If printing histograms, validation data must have length > 0.') def on_epoch_begin(self, epoch, logs=None): """Add histogram op to Model test_function callbacks, reset batch count.""" - if not self.validation_data and self.histogram_freq: - raise ValueError('If printing histograms, validation_data must be ' - 'provided, and cannot be a generator.') + # check if histogram summary should be run for this epoch if self.histogram_freq and epoch % self.histogram_freq == 0: self._epoch = epoch - self._current_batch = 0 - self._batches_per_epoch = math.ceil( - self.validation_data[0].shape[0] / self.batch_size) + self._current_val_batch = 0 + # add the histogram summary op if it should run this epoch if self.merged not in self.model.test_function.fetches: self.model.test_function.fetches.append(self.merged) self.model.test_function.fetch_callbacks[ @@ -831,7 +846,8 @@ class TensorBoard(Callback): logs = logs or {} - if self.histogram_freq and self.histogram_freq > 1: + # pop the histogram summary op after each epoch + if self.histogram_freq: if self.merged in self.model.test_function.fetches: self.model.test_function.fetches.remove(self.merged) if self.merged in self.model.test_function.fetch_callbacks: diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 4a5772f402..244d48591c 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -889,21 +889,6 @@ class KerasCallbacksTest(test.TestCase): for cb in cbs: cb.on_train_end() - # fit generator with validation data generator should raise ValueError if - # histogram_freq > 0 - cbs = callbacks_factory(histogram_freq=1) - with self.assertRaises(ValueError): - model.fit_generator( - data_generator(True), - len(x_train), - epochs=2, - validation_data=data_generator(False), - validation_steps=1, - callbacks=cbs) - - for cb in cbs: - cb.on_train_end() - # Make sure file writer cache is clear to avoid failures during cleanup. writer_cache.FileWriterCache.clear() @@ -1052,6 +1037,56 @@ class KerasCallbacksTest(test.TestCase): self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5]) + def test_Tensorboard_histogram_summaries_with_generator(self): + np.random.seed(1337) + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir) + + def generator(): + x = np.random.randn(10, 100).astype(np.float32) + y = np.random.randn(10, 10).astype(np.float32) + while True: + yield x, y + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_dim=100, activation='relu')) + model.add(keras.layers.Dense(10, activation='softmax')) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + tsb = keras.callbacks.TensorBoard( + log_dir=tmpdir, + histogram_freq=1, + write_images=True, + write_grads=True, + batch_size=5) + cbks = [tsb] + + # fit with validation generator + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=2, + validation_data=generator(), + validation_steps=2, + callbacks=cbks, + verbose=0) + + with self.assertRaises(ValueError): + # fit with validation generator but no + # validation_steps + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=2, + validation_data=generator(), + callbacks=cbks, + verbose=0) + + self.assertTrue(os.path.exists(tmpdir)) + @unittest.skipIf( os.name == 'nt', 'use_multiprocessing=True does not work on windows properly.') diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index e82f5c0332..adefffab11 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -124,12 +124,10 @@ def fit_loop(model, callback_metrics = copy.copy(out_labels) + [ 'val_' + n for n in out_labels ] - if callbacks is not None and any( - [isinstance(callback, cbks.TensorBoard) for callback in callbacks]): - # need to create the test_function before start of the first epoch - # because TensorBoard callback on_epoch_begin adds summary to the - # list of fetches of the test_function - model._make_test_function() + # need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + model._make_test_function() else: callback_metrics = copy.copy(out_labels) @@ -162,7 +160,7 @@ def fit_loop(model, callbacks.set_model(callback_model) - callbacks.set_params({ + callback_params = { 'batch_size': batch_size, 'epochs': epochs, 'steps': steps_per_epoch, @@ -170,11 +168,17 @@ def fit_loop(model, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False + } + if validation_steps: + callback_params.update({'validation_steps': validation_steps}) + callbacks.set_params(callback_params) + for cbk in callbacks: cbk.validation_data = val_ins + # validation_data must be set before on_train_begin() is called + # so that TensorboardCallback can validate its input + callbacks.on_train_begin() + callback_model.stop_training = False # To prevent a slowdown, we find beforehand the arrays that need conversion. feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index e8838cd3bc..c78684c9f4 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -989,7 +989,7 @@ def fit_loop(model, callbacks.set_model(callback_model) - callbacks.set_params({ + callback_params = { 'batch_size': batch_size, 'epochs': epochs, 'steps': steps_per_epoch, @@ -997,9 +997,11 @@ def fit_loop(model, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False + } + if validation_steps: + callback_params.update({'validation_steps': validation_steps}) + callbacks.set_params(callback_params) + for cbk in callbacks: if not val_inputs: cbk.validation_data = [] @@ -1009,6 +1011,10 @@ def fit_loop(model, cbk.validation_data = val_inputs + val_targets + val_sample_weights else: cbk.validation_data = val_inputs + val_targets + # validation_data must be set before on_train_begin() is called + # so that TensorboardCallback can validate its input + callbacks.on_train_begin() + callback_model.stop_training = False for epoch in range(initial_epoch, epochs): callbacks.on_epoch_begin(epoch) diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index d81b384f0e..432cf2bddd 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -96,14 +96,25 @@ def fit_generator(model, else: callback_model = model callbacks.set_model(callback_model) - callbacks.set_params({ + + callback_params = { 'epochs': epochs, 'steps': steps_per_epoch, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics, - }) - callbacks.on_train_begin() + } + if do_validation: + # need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + model._make_test_function() + # determine the number of validation batches given a generator + if validation_steps: + callback_params.update({'validation_steps': validation_steps}) + elif isinstance(validation_data, Sequence): + callback_params.update({'validation_steps': len(validation_data)}) + callbacks.set_params(callback_params) enqueuer = None val_enqueuer = None @@ -149,6 +160,9 @@ def fit_generator(model, output_generator = generator callback_model.stop_training = False + # validation_data must be set before on_train_begin() is called + # so that TensorboardCallback can validate its input + callbacks.on_train_begin() # Construct epoch logs. epoch_logs = {} while epoch < epochs: |