Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--  tensorflow/python/keras/callbacks.py  266
1 file changed, 215 insertions, 51 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 00a9c479fb..d1b9dc27bd 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -24,17 +24,23 @@ from collections import Iterable
from collections import OrderedDict
import csv
import json
+import math
import os
import time
import numpy as np
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.training_utils import standardize_input_data
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
+from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -696,7 +702,9 @@ class TensorBoard(Callback):
write_images: whether to write model weights to visualize as
image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
- layers will be saved.
+ layers will be saved. If set to 0, embeddings won't be computed.
+ Data to be visualized in TensorBoard's Embedding tab must be passed
+ as `embeddings_data`.
embeddings_layer_names: a list of names of layers to keep an eye on. If
None or an empty list, all the embedding layers will be watched.
embeddings_metadata: a dictionary which maps layer name to a file name
@@ -704,6 +712,10 @@ class TensorBoard(Callback):
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
about the metadata file format. In case the same metadata file is
used for all embedding layers, a single string can be passed.
+ embeddings_data: data to be embedded at layers specified in
+ `embeddings_layer_names`. Numpy array (if the model has a single
+ input) or list of Numpy arrays (if the model has multiple inputs).
+ Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
"""
# pylint: enable=line-too-long
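
A minimal usage sketch of the embedding arguments documented above (the toy
model, data, and log directory are illustrative assumptions, not part of this
change; the projector additionally requires the `tensorboard` package):

    import numpy as np
    import tensorflow as tf

    # Toy model with a single Embedding layer (hypothetical example).
    model = tf.keras.models.Sequential([
        tf.keras.layers.Embedding(1000, 16, input_length=10, name='embed'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy')

    x = np.random.randint(0, 1000, size=(256, 10))
    y = np.random.randint(0, 2, size=(256, 1))

    # Embeddings computed from `x` are checkpointed every 5 epochs and show
    # up in TensorBoard's Projector tab.
    tb = tf.keras.callbacks.TensorBoard(
        log_dir='/tmp/logs',
        embeddings_freq=5,
        embeddings_layer_names=['embed'],
        embeddings_data=x)

    model.fit(x, y, epochs=10, batch_size=32, callbacks=[tb])
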
@@ -714,7 +726,11 @@ class TensorBoard(Callback):
batch_size=32,
write_graph=True,
write_grads=False,
- write_images=False):
+ write_images=False,
+ embeddings_freq=0,
+ embeddings_layer_names=None,
+ embeddings_metadata=None,
+ embeddings_data=None):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
@@ -723,10 +739,21 @@ class TensorBoard(Callback):
self.write_grads = write_grads
self.write_images = write_images
self.batch_size = batch_size
+ self._current_batch = 0
+ self._total_batches_seen = 0
+ # abstracted writer class so it can be stubbed out for testing
+ self._writer_class = tf_summary.FileWriter
+ self.embeddings_freq = embeddings_freq
+ self.embeddings_layer_names = embeddings_layer_names
+ self.embeddings_metadata = embeddings_metadata
+ self.embeddings_data = embeddings_data
def set_model(self, model):
+ """Sets Keras model and creates summary ops."""
+
self.model = model
self.sess = K.get_session()
+ # only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
for weight in layer.weights:
@@ -771,69 +798,206 @@ class TensorBoard(Callback):
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if hasattr(layer, 'output'):
- tf_summary.histogram('{}_out'.format(layer.name), layer.output)
+ if isinstance(layer.output, list):
+ for i, output in enumerate(layer.output):
+ tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
+ else:
+ tf_summary.histogram('{}_out'.format(layer.name), layer.output)
self.merged = tf_summary.merge_all()
if self.write_graph:
- self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph)
+ self.writer = self._writer_class(self.log_dir, self.sess.graph)
else:
- self.writer = tf_summary.FileWriter(self.log_dir)
-
- def on_epoch_end(self, epoch, logs=None):
- logs = logs or {}
+ self.writer = self._writer_class(self.log_dir)
+
+ # If both embeddings_freq and embeddings_data are available, we will
+ # visualize embeddings.
+ if self.embeddings_freq and self.embeddings_data is not None:
+ self.embeddings_data = standardize_input_data(self.embeddings_data,
+ model.input_names)
+
+ # If embeddings_layer_names is not provided, get all of the embedding
+ # layers from the model.
+ embeddings_layer_names = self.embeddings_layer_names
+ if not embeddings_layer_names:
+ embeddings_layer_names = [
+ layer.name
+ for layer in self.model.layers
+ if type(layer).__name__ == 'Embedding'
+ ]
+
+ self.assign_embeddings = []
+ embeddings_vars = {}
+
+ self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
+ self.step = step = array_ops.placeholder(dtypes.int32)
- if not self.validation_data and self.histogram_freq:
- raise ValueError('If printing histograms, validation_data must be '
- 'provided, and cannot be a generator.')
- if self.validation_data and self.histogram_freq:
- if epoch % self.histogram_freq == 0:
-
- val_data = self.validation_data
- tensors = (
- self.model.inputs + self.model.targets + self.model.sample_weights)
-
- if self.model.uses_learning_phase:
- tensors += [K.learning_phase()]
+ for layer in self.model.layers:
+ if layer.name in embeddings_layer_names:
+ embedding_input = self.model.get_layer(layer.name).output
+ embedding_size = np.prod(embedding_input.shape[1:])
+ embedding_input = array_ops.reshape(embedding_input,
+ (step, int(embedding_size)))
+ shape = (self.embeddings_data[0].shape[0], int(embedding_size))
+ embedding = variables.Variable(
+ array_ops.zeros(shape), name=layer.name + '_embedding')
+ embeddings_vars[layer.name] = embedding
+ batch = state_ops.assign(embedding[batch_id:batch_id + step],
+ embedding_input)
+ self.assign_embeddings.append(batch)
+
+ self.saver = saver.Saver(list(embeddings_vars.values()))
+
+ # Create embeddings_metadata dictionary
+ if isinstance(self.embeddings_metadata, str):
+ embeddings_metadata = {
+ layer_name: self.embeddings_metadata
+ for layer_name in embeddings_vars.keys()
+ }
+ else:
+ # If embeddings_metadata is already a dictionary
+ embeddings_metadata = self.embeddings_metadata
+
+ try:
+ from tensorboard.plugins import projector
+ except ImportError:
+ raise ImportError('Failed to import TensorBoard. Please make sure that '
+ 'TensorBoard integration is complete.')
+
+ # TODO(psv): Add integration tests to test embedding visualization
+ # with TensorBoard callback. We are unable to write a unit test for this
+ # because TensorBoard dependency assumes TensorFlow package is installed.
+ config = projector.ProjectorConfig()
+ for layer_name, tensor in embeddings_vars.items():
+ embedding = config.embeddings.add()
+ embedding.tensor_name = tensor.name
+
+ if (embeddings_metadata is not None and
+ layer_name in embeddings_metadata):
+ embedding.metadata_path = embeddings_metadata[layer_name]
+
+ projector.visualize_embeddings(self.writer, config)
+
+ def _fetch_callback(self, summary):
+ self.writer.add_summary(
+ summary,
+ self._epoch + self._current_val_batch / self._validation_batches)
+ self._current_val_batch += 1
+
+ def _write_custom_summaries(self, step, logs=None):
+ """Writes metrics out as custom scalar summaries.
- assert len(val_data) == len(tensors)
- val_size = val_data[0].shape[0]
- i = 0
- while i < val_size:
- step = min(self.batch_size, val_size - i)
- batch_val = []
- batch_val.append(val_data[0][i:i + step]
- if val_data[0] is not None else None)
- batch_val.append(val_data[1][i:i + step]
- if val_data[1] is not None else None)
- batch_val.append(val_data[2][i:i + step]
- if val_data[2] is not None else None)
- if self.model.uses_learning_phase:
- # do not slice the learning phase
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data[:-1]]
- batch_val.append(val_data[-1])
- else:
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data]
- feed_dict = {}
- for key, val in zip(tensors, batch_val):
- if val is not None:
- feed_dict[key] = val
- result = self.sess.run([self.merged], feed_dict=feed_dict)
- summary_str = result[0]
- self.writer.add_summary(summary_str, epoch)
- i += self.batch_size
+ Arguments:
+ step: the global step to use for TensorBoard.
+ logs: dict. Keys are scalar summary names, values are
+ NumPy scalars.
+ """
+ logs = logs or {}
for name, value in logs.items():
- if name in ['batch', 'size']:
- continue
summary = tf_summary.Summary()
summary_value = summary.value.add()
summary_value.simple_value = value.item()
summary_value.tag = name
- self.writer.add_summary(summary, epoch)
+ self.writer.add_summary(summary, step)
self.writer.flush()
+ def on_train_begin(self, logs=None):
+ """Checks if histogram summaries can be run."""
+
+ if self.histogram_freq:
+ if 'validation_steps' in self.params:
+ self._validation_batches = self.params['validation_steps']
+ elif self.validation_data:
+ self._validation_batches = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ else:
+ raise ValueError('If printing histograms, validation data must be '
+ 'provided.')
+ if self._validation_batches == 0:
+ raise ValueError(
+ 'If printing histograms, validation data must have length > 0.')
+
+ def on_batch_end(self, batch, logs=None):
+ """Writes scalar summaries for metrics on every training batch."""
+ # Don't output batch_size and batch number as TensorBoard summaries.
+ logs = logs or {}
+ batch_logs = {('batch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(self._total_batches_seen, batch_logs)
+ self._total_batches_seen += 1
+
+ def on_epoch_begin(self, epoch, logs=None):
+ """Add histogram op to Model test_function callbacks, reset batch count."""
+
+ # check if histogram summary should be run for this epoch
+ if self.histogram_freq and epoch % self.histogram_freq == 0:
+ self._epoch = epoch
+ self._current_val_batch = 0
+ # add the histogram summary op if it should run this epoch
+ if self.merged not in self.model.test_function.fetches:
+ self.model.test_function.fetches.append(self.merged)
+ self.model.test_function.fetch_callbacks[
+ self.merged] = self._fetch_callback
+
+ def on_epoch_end(self, epoch, logs=None):
+ """Checks if summary ops should run next epoch, logs scalar summaries."""
+
+ # don't output batch_size and
+ # batch number as TensorBoard summaries
+ logs = logs or {}
+ logs = {('epoch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(epoch, logs)
+
+ # pop the histogram summary op after each epoch
+ if self.histogram_freq:
+ if self.merged in self.model.test_function.fetches:
+ self.model.test_function.fetches.remove(self.merged)
+ if self.merged in self.model.test_function.fetch_callbacks:
+ self.model.test_function.fetch_callbacks.pop(self.merged)
+
+ if self.embeddings_data is None and self.embeddings_freq:
+ raise ValueError('To visualize embeddings, embeddings_data must '
+ 'be provided.')
+
+ if self.embeddings_freq and self.embeddings_data is not None:
+ if epoch % self.embeddings_freq == 0:
+ # We need a second forward-pass here because we're passing
+ # the `embeddings_data` explicitly. This design allows passing
+ # arbitrary data as `embeddings_data` and results from the fact
+ # that we need to know the size of the `tf.Variable`s which
+ # hold the embeddings in `set_model`. At this point, however,
+ # the `validation_data` is not yet set.
+
+ embeddings_data = self.embeddings_data
+ n_samples = embeddings_data[0].shape[0]
+ i = 0
+ while i < n_samples:
+ step = min(self.batch_size, n_samples - i)
+ batch = slice(i, i + step)
+
+ if isinstance(self.model.input, list):
+ feed_dict = {
+ model_input: embeddings_data[idx][batch]
+ for idx, model_input in enumerate(self.model.input)
+ }
+ else:
+ feed_dict = {self.model.input: embeddings_data[0][batch]}
+
+ feed_dict.update({self.batch_id: i, self.step: step})
+
+ if self.model.uses_learning_phase:
+ feed_dict[K.learning_phase()] = False
+
+ self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
+ self.saver.save(self.sess,
+ os.path.join(self.log_dir, 'keras_embedding.ckpt'),
+ epoch)
+
+ i += self.batch_size
+
def on_train_end(self, logs=None):
self.writer.close()
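
For reference, the batched slice-assignment pattern that `set_model` and
`on_epoch_end` use to fill each embedding variable can be exercised on its own
roughly as below (a standalone sketch in TF 1.x style; the shapes, names, and
random data are illustrative assumptions):

    import numpy as np
    import tensorflow as tf

    data = np.random.rand(100, 8).astype(np.float32)  # rows to store
    batch_size = 32

    # Full-size variable holding every row, as created in `set_model`.
    embedding_var = tf.Variable(tf.zeros((100, 8)), name='layer_embedding')

    # Placeholders mirroring `self.batch_id` and `self.step`.
    batch_id = tf.placeholder(tf.int32)
    step = tf.placeholder(tf.int32)
    batch_input = tf.placeholder(tf.float32, shape=(None, 8))

    # Assign one batch into the matching slice of the variable.
    assign_op = tf.assign(embedding_var[batch_id:batch_id + step], batch_input)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        i = 0
        while i < data.shape[0]:
            step_size = min(batch_size, data.shape[0] - i)
            sess.run(assign_op, feed_dict={
                batch_id: i,
                step: step_size,
                batch_input: data[i:i + step_size],
            })
            i += batch_size
        # The filled variable could then be checkpointed with tf.train.Saver,
        # as the callback does with 'keras_embedding.ckpt'.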