Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--  tensorflow/python/keras/callbacks.py | 266
1 file changed, 215 insertions(+), 51 deletions(-)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 00a9c479fb..d1b9dc27bd 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -24,17 +24,23 @@ from collections import Iterable
 from collections import OrderedDict
 import csv
 import json
+import math
 import os
 import time
 
 import numpy as np
 import six
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.training_utils import standardize_input_data
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary as tf_summary
+from tensorflow.python.training import saver
 from tensorflow.python.util.tf_export import tf_export
@@ -696,7 +702,9 @@ class TensorBoard(Callback):
       write_images: whether to write model weights to visualize as
         image in TensorBoard.
       embeddings_freq: frequency (in epochs) at which selected embedding
-        layers will be saved.
+        layers will be saved. If set to 0, embeddings won't be computed.
+        Data to be visualized in TensorBoard's Embedding tab must be passed
+        as `embeddings_data`.
       embeddings_layer_names: a list of names of layers to keep an eye on.
         If None or an empty list, all of the embedding layers will be watched.
       embeddings_metadata: a dictionary which maps layer name to a file name
@@ -704,6 +712,10 @@
         [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
         about the metadata file format. If the same metadata file is
         used for all embedding layers, a string can be passed.
+      embeddings_data: data to be embedded at layers specified in
+        `embeddings_layer_names`. Numpy array (if the model has a single
+        input) or list of Numpy arrays (if the model has multiple inputs).
+        Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
   """
 
   # pylint: enable=line-too-long
 
@@ -714,7 +726,11 @@
                batch_size=32,
                write_graph=True,
                write_grads=False,
-               write_images=False):
+               write_images=False,
+               embeddings_freq=0,
+               embeddings_layer_names=None,
+               embeddings_metadata=None,
+               embeddings_data=None):
     super(TensorBoard, self).__init__()
     self.log_dir = log_dir
     self.histogram_freq = histogram_freq
@@ -723,10 +739,21 @@
     self.write_grads = write_grads
     self.write_images = write_images
     self.batch_size = batch_size
+    self._current_batch = 0
+    self._total_batches_seen = 0
+    # abstracted writer class to be able to stub for testing
+    self._writer_class = tf_summary.FileWriter
+    self.embeddings_freq = embeddings_freq
+    self.embeddings_layer_names = embeddings_layer_names
+    self.embeddings_metadata = embeddings_metadata
+    self.embeddings_data = embeddings_data
 
   def set_model(self, model):
+    """Sets Keras model and creates summary ops."""
+
     self.model = model
     self.sess = K.get_session()
+    # only make histogram summary op if it hasn't already been made
     if self.histogram_freq and self.merged is None:
       for layer in self.model.layers:
         for weight in layer.weights:
@@ -771,69 +798,206 @@ class TensorBoard(Callback):
             tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
         if hasattr(layer, 'output'):
-          tf_summary.histogram('{}_out'.format(layer.name), layer.output)
+          if isinstance(layer.output, list):
+            for i, output in enumerate(layer.output):
+              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
+          else:
+            tf_summary.histogram('{}_out'.format(layer.name), layer.output)
     self.merged = tf_summary.merge_all()
 
     if self.write_graph:
-      self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph)
+      self.writer = self._writer_class(self.log_dir, self.sess.graph)
     else:
-      self.writer = tf_summary.FileWriter(self.log_dir)
-
-  def on_epoch_end(self, epoch, logs=None):
-    logs = logs or {}
+      self.writer = self._writer_class(self.log_dir)
+
+    # If both embeddings_freq and embeddings_data are available, we will
+    # visualize embeddings.
+    if self.embeddings_freq and self.embeddings_data is not None:
+      self.embeddings_data = standardize_input_data(self.embeddings_data,
+                                                    model.input_names)
+
+      # If embeddings_layer_names are not provided, get all of the embedding
+      # layers from the model.
+      embeddings_layer_names = self.embeddings_layer_names
+      if not embeddings_layer_names:
+        embeddings_layer_names = [
+            layer.name
+            for layer in self.model.layers
+            if type(layer).__name__ == 'Embedding'
+        ]
+
+      self.assign_embeddings = []
+      embeddings_vars = {}
+
+      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
+      self.step = step = array_ops.placeholder(dtypes.int32)
 
-    if not self.validation_data and self.histogram_freq:
-      raise ValueError('If printing histograms, validation_data must be '
-                       'provided, and cannot be a generator.')
-    if self.validation_data and self.histogram_freq:
-      if epoch % self.histogram_freq == 0:
-
-        val_data = self.validation_data
-        tensors = (
-            self.model.inputs + self.model.targets + self.model.sample_weights)
-
-        if self.model.uses_learning_phase:
-          tensors += [K.learning_phase()]
 
+      for layer in self.model.layers:
+        if layer.name in embeddings_layer_names:
+          embedding_input = self.model.get_layer(layer.name).output
+          embedding_size = np.prod(embedding_input.shape[1:])
+          embedding_input = array_ops.reshape(embedding_input,
+                                              (step, int(embedding_size)))
+          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
+          embedding = variables.Variable(
+              array_ops.zeros(shape), name=layer.name + '_embedding')
+          embeddings_vars[layer.name] = embedding
+          batch = state_ops.assign(embedding[batch_id:batch_id + step],
+                                   embedding_input)
+          self.assign_embeddings.append(batch)
+
+      self.saver = saver.Saver(list(embeddings_vars.values()))
+
+      # Create embeddings_metadata dictionary
+      if isinstance(self.embeddings_metadata, str):
+        embeddings_metadata = {
+            layer_name: self.embeddings_metadata
+            for layer_name in embeddings_vars.keys()
+        }
+      else:
+        # If embeddings_metadata is already a dictionary
+        embeddings_metadata = self.embeddings_metadata
+
+      try:
+        from tensorboard.plugins import projector
+      except ImportError:
+        raise ImportError('Failed to import TensorBoard. Please make sure that '
+                          'TensorBoard integration is complete.')
+
+      # TODO(psv): Add integration tests to test embedding visualization
+      # with TensorBoard callback. We are unable to write a unit test for this
+      # because TensorBoard dependency assumes TensorFlow package is installed.
+      config = projector.ProjectorConfig()
+      for layer_name, tensor in embeddings_vars.items():
+        embedding = config.embeddings.add()
+        embedding.tensor_name = tensor.name
+
+        if (embeddings_metadata is not None and
+            layer_name in embeddings_metadata):
+          embedding.metadata_path = embeddings_metadata[layer_name]
+
+      projector.visualize_embeddings(self.writer, config)
+
+  def _fetch_callback(self, summary):
+    self.writer.add_summary(
+        summary,
+        self._epoch + self._current_val_batch / self._validation_batches)
+    self._current_val_batch += 1
+
+  def _write_custom_summaries(self, step, logs=None):
+    """Writes metrics out as custom scalar summaries.
-        assert len(val_data) == len(tensors)
-        val_size = val_data[0].shape[0]
-        i = 0
-        while i < val_size:
-          step = min(self.batch_size, val_size - i)
-          batch_val = []
-          batch_val.append(val_data[0][i:i + step]
-                           if val_data[0] is not None else None)
-          batch_val.append(val_data[1][i:i + step]
-                           if val_data[1] is not None else None)
-          batch_val.append(val_data[2][i:i + step]
-                           if val_data[2] is not None else None)
-          if self.model.uses_learning_phase:
-            # do not slice the learning phase
-            batch_val = [x[i:i + step] if x is not None else None
-                         for x in val_data[:-1]]
-            batch_val.append(val_data[-1])
-          else:
-            batch_val = [x[i:i + step] if x is not None else None
-                         for x in val_data]
-          feed_dict = {}
-          for key, val in zip(tensors, batch_val):
-            if val is not None:
-              feed_dict[key] = val
-          result = self.sess.run([self.merged], feed_dict=feed_dict)
-          summary_str = result[0]
-          self.writer.add_summary(summary_str, epoch)
-          i += self.batch_size
+    Arguments:
+        step: the global step to use for TensorBoard.
+        logs: dict. Keys are scalar summary names, values are
+            NumPy scalars.
+    """
+    logs = logs or {}
     for name, value in logs.items():
-      if name in ['batch', 'size']:
-        continue
       summary = tf_summary.Summary()
       summary_value = summary.value.add()
       summary_value.simple_value = value.item()
       summary_value.tag = name
-      self.writer.add_summary(summary, epoch)
+      self.writer.add_summary(summary, step)
     self.writer.flush()
 
+  def on_train_begin(self, logs=None):
+    """Checks if histogram summaries can be run."""
+
+    if self.histogram_freq:
+      if 'validation_steps' in self.params:
+        self._validation_batches = self.params['validation_steps']
+      elif self.validation_data:
+        self._validation_batches = math.ceil(
+            self.validation_data[0].shape[0] / self.batch_size)
+      else:
+        raise ValueError('If printing histograms, validation data must be '
+                         'provided.')
+      if self._validation_batches == 0:
+        raise ValueError(
+            'If printing histograms, validation data must have length > 0.')
+
+  def on_batch_end(self, batch, logs=None):
+    """Writes scalar summaries for metrics on every training batch."""
+    # Don't output batch_size and batch number as TensorBoard summaries
+    logs = logs or {}
+    batch_logs = {('batch_' + k): v
+                  for k, v in logs.items()
+                  if k not in ['batch', 'size']}
+    self._write_custom_summaries(self._total_batches_seen, batch_logs)
+    self._total_batches_seen += 1
+
+  def on_epoch_begin(self, epoch, logs=None):
+    """Add histogram op to Model test_function callbacks, reset batch count."""
+
+    # check if histogram summary should be run for this epoch
+    if self.histogram_freq and epoch % self.histogram_freq == 0:
+      self._epoch = epoch
+      self._current_val_batch = 0
+      # add the histogram summary op if it should run this epoch
+      if self.merged not in self.model.test_function.fetches:
+        self.model.test_function.fetches.append(self.merged)
+        self.model.test_function.fetch_callbacks[
+            self.merged] = self._fetch_callback
+
+  def on_epoch_end(self, epoch, logs=None):
+    """Checks if summary ops should run next epoch, logs scalar summaries."""
+
+    # don't output batch_size and
+    # batch number as TensorBoard summaries
+    logs = {('epoch_' + k): v
+            for k, v in logs.items()
+            if k not in ['batch', 'size']}
+    self._write_custom_summaries(epoch, logs)
+
+    # pop the histogram summary op after each epoch
+    if self.histogram_freq:
+      if self.merged in self.model.test_function.fetches:
+        self.model.test_function.fetches.remove(self.merged)
+      if self.merged in self.model.test_function.fetch_callbacks:
+        self.model.test_function.fetch_callbacks.pop(self.merged)
+
+    if self.embeddings_data is None and self.embeddings_freq:
+      raise ValueError('To visualize embeddings, embeddings_data must '
+                       'be provided.')
+
+    if self.embeddings_freq and self.embeddings_data is not None:
+      if epoch % self.embeddings_freq == 0:
+        # We need a second forward pass here because we're passing
+        # the `embeddings_data` explicitly. This design allows passing
+        # arbitrary data as `embeddings_data` and results from the fact
+        # that we need to know the size of the `tf.Variable`s which
+        # hold the embeddings in `set_model`. At this point, however,
+        # the `validation_data` is not yet set.
+
+        embeddings_data = self.embeddings_data
+        n_samples = embeddings_data[0].shape[0]
+        i = 0
+        while i < n_samples:
+          step = min(self.batch_size, n_samples - i)
+          batch = slice(i, i + step)
+
+          if isinstance(self.model.input, list):
+            feed_dict = {
+                model_input: embeddings_data[idx][batch]
+                for idx, model_input in enumerate(self.model.input)
+            }
+          else:
+            feed_dict = {self.model.input: embeddings_data[0][batch]}
+
+          feed_dict.update({self.batch_id: i, self.step: step})
+
+          if self.model.uses_learning_phase:
+            feed_dict[K.learning_phase()] = False
+
+          self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
+          self.saver.save(self.sess,
+                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
+                          epoch)
+
+          i += self.batch_size
+
   def on_train_end(self, logs=None):
     self.writer.close()
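
For reference, a minimal usage sketch of the embeddings support this diff adds. The sketch is not part of the commit: the toy model, data, and log directory below are illustrative assumptions; only the four new `embeddings_*` arguments come from the diff.

    import numpy as np
    from tensorflow import keras

    # Toy model with one Embedding layer for the projector to visualize.
    model = keras.Sequential([
        keras.layers.Embedding(input_dim=1000, output_dim=16, input_length=10),
        keras.layers.Flatten(),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy')

    # Hypothetical integer-sequence data.
    x = np.random.randint(0, 1000, size=(512, 10))
    y = np.random.randint(0, 2, size=(512, 1))

    # embeddings_freq > 0 requires embeddings_data (on_epoch_end raises a
    # ValueError otherwise); embeddings_layer_names=None watches every
    # Embedding layer, matching the default lookup in set_model.
    tensorboard = keras.callbacks.TensorBoard(
        log_dir='./logs',
        embeddings_freq=1,          # save keras_embedding.ckpt every epoch
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=x[:100])    # a subset is enough for visualization

    model.fit(x, y, batch_size=32, epochs=2, callbacks=[tensorboard])

Because the embedding variables are sized from `embeddings_data[0].shape[0]` in `set_model`, the data must be supplied up front rather than reusing `validation_data`, which explains the second forward pass noted in `on_epoch_end`.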