aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-07-15 11:28:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-15 11:32:18 -0700
commite6ce9ea5a156873c5b927e99d8935e32122538b9 (patch)
treed86b172c9586ee0736807b25c41c2d40c92e6bd8 /tensorflow/python/keras/callbacks.py
parentfe7d1d9447a31562acb26aad7a9ffca60686c38a (diff)
Partial update of tf.keras to the Keras 2.2.0 API.
Changes included are: - Embedding visualization is added to TensorBoard callback (from older Keras API.) - Fix: learning phase info being left out in multi-input models (from older Keras API.) - Fix: Tensorboard callback only supports logging Embeddings layer weights - Fix: Tensorboard callback with layer with multiple outputs PiperOrigin-RevId: 204659796
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py137
1 files changed, 134 insertions, 3 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 53d907a2cc..0857a3279f 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -31,11 +31,16 @@ import time
import numpy as np
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.training_utils import standardize_input_data
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
+from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -697,7 +702,9 @@ class TensorBoard(Callback):
write_images: whether to write model weights to visualize as
image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
- layers will be saved.
+ layers will be saved. If set to 0, embeddings won't be computed.
+ Data to be visualized in TensorBoard's Embedding tab must be passed
+ as `embeddings_data`.
embeddings_layer_names: a list of names of layers to keep eye on. If
None or empty list all the embedding layer will be watched.
embeddings_metadata: a dictionary which maps layer name to a file name
@@ -705,6 +712,10 @@ class TensorBoard(Callback):
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
about metadata files format. In case if the same metadata file is
used for all embedding layers, string can be passed.
+ embeddings_data: data to be embedded at layers specified in
+ `embeddings_layer_names`. Numpy array (if the model has a single
+ input) or list of Numpy arrays (if the model has multiple inputs).
+ Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
"""
# pylint: enable=line-too-long
@@ -715,7 +726,11 @@ class TensorBoard(Callback):
batch_size=32,
write_graph=True,
write_grads=False,
- write_images=False):
+ write_images=False,
+ embeddings_freq=0,
+ embeddings_layer_names=None,
+ embeddings_metadata=None,
+ embeddings_data=None):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
@@ -727,6 +742,10 @@ class TensorBoard(Callback):
self._current_batch = 0
# abstracted writer class to be able to stub for testing
self._writer_class = tf_summary.FileWriter
+ self.embeddings_freq = embeddings_freq
+ self.embeddings_layer_names = embeddings_layer_names
+ self.embeddings_metadata = embeddings_metadata
+ self.embeddings_data = embeddings_data
def set_model(self, model):
"""Sets Keras model and creates summary ops."""
@@ -778,7 +797,11 @@ class TensorBoard(Callback):
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if hasattr(layer, 'output'):
- tf_summary.histogram('{}_out'.format(layer.name), layer.output)
+ if isinstance(layer.output, list):
+ for i, output in enumerate(layer.output):
+ tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
+ else:
+ tf_summary.histogram('{}_out'.format(layer.name), layer.output)
self.merged = tf_summary.merge_all()
if self.write_graph:
@@ -786,6 +809,74 @@ class TensorBoard(Callback):
else:
self.writer = self._writer_class(self.log_dir)
+ # If both embedding_freq and embeddings_data are available, we will
+ # visualize embeddings.
+ if self.embeddings_freq and self.embeddings_data is not None:
+ self.embeddings_data = standardize_input_data(self.embeddings_data,
+ model.input_names)
+
+ # If embedding_layer_names are not provided, get all of the embedding
+ # layers from the model.
+ embeddings_layer_names = self.embeddings_layer_names
+ if not embeddings_layer_names:
+ embeddings_layer_names = [
+ layer.name
+ for layer in self.model.layers
+ if type(layer).__name__ == 'Embedding'
+ ]
+
+ self.assign_embeddings = []
+ embeddings_vars = {}
+
+ self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
+ self.step = step = array_ops.placeholder(dtypes.int32)
+
+ for layer in self.model.layers:
+ if layer.name in embeddings_layer_names:
+ embedding_input = self.model.get_layer(layer.name).output
+ embedding_size = np.prod(embedding_input.shape[1:])
+ embedding_input = array_ops.reshape(embedding_input,
+ (step, int(embedding_size)))
+ shape = (self.embeddings_data[0].shape[0], int(embedding_size))
+ embedding = variables.Variable(
+ array_ops.zeros(shape), name=layer.name + '_embedding')
+ embeddings_vars[layer.name] = embedding
+ batch = state_ops.assign(embedding[batch_id:batch_id + step],
+ embedding_input)
+ self.assign_embeddings.append(batch)
+
+ self.saver = saver.Saver(list(embeddings_vars.values()))
+
+ # Create embeddings_metadata dictionary
+ if isinstance(self.embeddings_metadata, str):
+ embeddings_metadata = {
+ layer_name: self.embeddings_metadata
+ for layer_name in embeddings_vars.keys()
+ }
+ else:
+ # If embedding_metadata is already a dictionary
+ embeddings_metadata = self.embeddings_metadata
+
+ try:
+ from tensorboard.plugins import projector
+ except ImportError:
+ raise ImportError('Failed to import TensorBoard. Please make sure that '
+ 'TensorBoard integration is complete."')
+
+ # TODO(psv): Add integration tests to test embedding visualization
+ # with TensorBoard callback. We are unable to write a unit test for this
+ # because TensorBoard dependency assumes TensorFlow package is installed.
+ config = projector.ProjectorConfig()
+ for layer_name, tensor in embeddings_vars.items():
+ embedding = config.embeddings.add()
+ embedding.tensor_name = tensor.name
+
+ if (embeddings_metadata is not None and
+ layer_name in embeddings_metadata):
+ embedding.metadata_path = embeddings_metadata[layer_name]
+
+ projector.visualize_embeddings(self.writer, config)
+
def _fetch_callback(self, summary):
self.writer.add_summary(
summary,
@@ -833,6 +924,46 @@ class TensorBoard(Callback):
if self.merged in self.model.test_function.fetch_callbacks:
self.model.test_function.fetch_callbacks.pop(self.merged)
+ if self.embeddings_data is None and self.embeddings_freq:
+ raise ValueError('To visualize embeddings, embeddings_data must '
+ 'be provided.')
+
+ if self.embeddings_freq and self.embeddings_data is not None:
+ if epoch % self.embeddings_freq == 0:
+ # We need a second forward-pass here because we're passing
+ # the `embeddings_data` explicitly. This design allows to pass
+ # arbitrary data as `embeddings_data` and results from the fact
+ # that we need to know the size of the `tf.Variable`s which
+ # hold the embeddings in `set_model`. At this point, however,
+ # the `validation_data` is not yet set.
+
+ embeddings_data = self.embeddings_data
+ n_samples = embeddings_data[0].shape[0]
+ i = 0
+ while i < n_samples:
+ step = min(self.batch_size, n_samples - i)
+ batch = slice(i, i + step)
+
+ if isinstance(self.model.input, list):
+ feed_dict = {
+ model_input: embeddings_data[idx][batch]
+ for idx, model_input in enumerate(self.model.input)
+ }
+ else:
+ feed_dict = {self.model.input: embeddings_data[0][batch]}
+
+ feed_dict.update({self.batch_id: i, self.step: step})
+
+ if self.model.uses_learning_phase:
+ feed_dict[K.learning_phase()] = False
+
+ self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
+ self.saver.save(self.sess,
+ os.path.join(self.log_dir, 'keras_embedding.ckpt'),
+ epoch)
+
+ i += self.batch_size
+
for name, value in logs.items():
if name in ['batch', 'size']:
continue