aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-30 10:34:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 10:38:30 -0700
commit538c7198fdba4ac9b71ba2ceb8c2fb0bb31d20e5 (patch)
tree29930329068e13c71dcfe02240fbb602764fcaca /tensorflow/python/keras/callbacks.py
parent1ece2e8e96be2eb39922951619ea99208df93284 (diff)
BEGIN_PUBLIC
Resubmission of Keras Tensorboard Callback - enable metrics logging in Eager END_PUBLIC Previously a test added in this CL was causing Windows cmake build to fail, I fixed the test. *** Reason for rollback *** Fixed test execution *** Original change description *** Automated rollback of commit d4cb01f242dc3ff0f7b0aae7284def46281755f2 PiperOrigin-RevId: 206606179
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py73
1 files changed, 53 insertions, 20 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index d1b9dc27bd..070d41147d 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -31,12 +31,14 @@ import time
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.training_utils import standardize_input_data
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
@@ -716,6 +718,15 @@ class TensorBoard(Callback):
`embeddings_layer_names`. Numpy array (if the model has a single
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
+
+ Raises:
+ ValueError: If histogram_freq is set and no validation data is provided.
+
+ @compatbility(eager)
+ Using `Tensorboard` callback will work while eager execution is enabled,
+ however outputting histogram summaries of weights and gradients is not
+ supported, and thus `histogram_freq` will be ignored.
+ @end_compatibility
"""
# pylint: enable=line-too-long
@@ -734,6 +745,11 @@ class TensorBoard(Callback):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
+ if self.histogram_freq and context.executing_eagerly():
+ logging.warning(
+ UserWarning('Weight and gradient histograms not supported for eager'
+ 'execution, setting `histogram_freq` to `0`.'))
+ self.histogram_freq = 0
self.merged = None
self.write_graph = write_graph
self.write_grads = write_grads
@@ -741,18 +757,22 @@ class TensorBoard(Callback):
self.batch_size = batch_size
self._current_batch = 0
self._total_batches_seen = 0
- # abstracted writer class to be able to stub for testing
- self._writer_class = tf_summary.FileWriter
self.embeddings_freq = embeddings_freq
self.embeddings_layer_names = embeddings_layer_names
self.embeddings_metadata = embeddings_metadata
self.embeddings_data = embeddings_data
- def set_model(self, model):
- """Sets Keras model and creates summary ops."""
+ def _init_writer(self):
+ """Sets file writer."""
+ if context.executing_eagerly():
+ self.writer = summary_ops_v2.create_file_writer(self.log_dir)
+ elif self.write_graph:
+ self.writer = tf_summary.FileWriter(self.log_dir, K.get_session().graph)
+ else:
+ self.writer = tf_summary.FileWriter(self.log_dir)
- self.model = model
- self.sess = K.get_session()
+ def _make_histogram_ops(self, model):
+ """Defines histogram ops when histogram_freq > 0."""
# only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
@@ -793,8 +813,10 @@ class TensorBoard(Callback):
def is_indexed_slices(grad):
return type(grad).__name__ == 'IndexedSlices'
- grads = [grad.values if is_indexed_slices(grad) else grad
- for grad in grads]
+ grads = [
+ grad.values if is_indexed_slices(grad) else grad
+ for grad in grads
+ ]
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if hasattr(layer, 'output'):
@@ -803,12 +825,16 @@ class TensorBoard(Callback):
tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
else:
tf_summary.histogram('{}_out'.format(layer.name), layer.output)
- self.merged = tf_summary.merge_all()
- if self.write_graph:
- self.writer = self._writer_class(self.log_dir, self.sess.graph)
- else:
- self.writer = self._writer_class(self.log_dir)
+ def set_model(self, model):
+ """Sets Keras model and creates summary ops."""
+
+ self.model = model
+ self._init_writer()
+ # histogram summaries only enabled in graph mode
+ if not context.executing_eagerly():
+ self._make_histogram_ops(model)
+ self.merged = tf_summary.merge_all()
# If both embedding_freq and embeddings_data are available, we will
# visualize embeddings.
@@ -894,17 +920,24 @@ class TensorBoard(Callback):
"""
logs = logs or {}
- for name, value in logs.items():
- summary = tf_summary.Summary()
- summary_value = summary.value.add()
- summary_value.simple_value = value.item()
- summary_value.tag = name
- self.writer.add_summary(summary, step)
+ if context.executing_eagerly():
+ # use v2 summary ops
+ with self.writer.as_default(), summary_ops_v2.always_record_summaries():
+ for name, value in logs.items():
+ summary_ops_v2.scalar(name, value.item(), step=step)
+ else:
+ # use FileWriter from v1 summary
+ for name, value in logs.items():
+ summary = tf_summary.Summary()
+ summary_value = summary.value.add()
+ summary_value.simple_value = value.item()
+ summary_value.tag = name
+ self.writer.add_summary(summary, step)
self.writer.flush()
def on_train_begin(self, logs=None):
"""Checks if histogram summaries can be run."""
-
+ # will never be set when in eager
if self.histogram_freq:
if 'validation_steps' in self.params:
self._validation_batches = self.params['validation_steps']