diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-30 10:34:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 10:38:30 -0700 |
commit | 538c7198fdba4ac9b71ba2ceb8c2fb0bb31d20e5 (patch) | |
tree | 29930329068e13c71dcfe02240fbb602764fcaca /tensorflow/python/keras/callbacks.py | |
parent | 1ece2e8e96be2eb39922951619ea99208df93284 (diff) |
BEGIN_PUBLIC
Resubmission of Keras Tensorboard Callback - enable metrics logging in Eager
END_PUBLIC
Previously a test added in this CL was causing Windows cmake build to fail, I
fixed the test.
*** Reason for rollback ***
Fixed test execution
*** Original change description ***
Automated rollback of commit d4cb01f242dc3ff0f7b0aae7284def46281755f2
PiperOrigin-RevId: 206606179
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 73 |
1 files changed, 53 insertions, 20 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index d1b9dc27bd..070d41147d 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -31,12 +31,14 @@ import time import numpy as np import six +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine.training_utils import standardize_input_data from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary @@ -716,6 +718,15 @@ class TensorBoard(Callback): `embeddings_layer_names`. Numpy array (if the model has a single input) or list of Numpy arrays (if the model has multiple inputs). Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding) + + Raises: + ValueError: If histogram_freq is set and no validation data is provided. + + @compatbility(eager) + Using `Tensorboard` callback will work while eager execution is enabled, + however outputting histogram summaries of weights and gradients is not + supported, and thus `histogram_freq` will be ignored. + @end_compatibility """ # pylint: enable=line-too-long @@ -734,6 +745,11 @@ class TensorBoard(Callback): super(TensorBoard, self).__init__() self.log_dir = log_dir self.histogram_freq = histogram_freq + if self.histogram_freq and context.executing_eagerly(): + logging.warning( + UserWarning('Weight and gradient histograms not supported for eager' + 'execution, setting `histogram_freq` to `0`.')) + self.histogram_freq = 0 self.merged = None self.write_graph = write_graph self.write_grads = write_grads @@ -741,18 +757,22 @@ class TensorBoard(Callback): self.batch_size = batch_size self._current_batch = 0 self._total_batches_seen = 0 - # abstracted writer class to be able to stub for testing - self._writer_class = tf_summary.FileWriter self.embeddings_freq = embeddings_freq self.embeddings_layer_names = embeddings_layer_names self.embeddings_metadata = embeddings_metadata self.embeddings_data = embeddings_data - def set_model(self, model): - """Sets Keras model and creates summary ops.""" + def _init_writer(self): + """Sets file writer.""" + if context.executing_eagerly(): + self.writer = summary_ops_v2.create_file_writer(self.log_dir) + elif self.write_graph: + self.writer = tf_summary.FileWriter(self.log_dir, K.get_session().graph) + else: + self.writer = tf_summary.FileWriter(self.log_dir) - self.model = model - self.sess = K.get_session() + def _make_histogram_ops(self, model): + """Defines histogram ops when histogram_freq > 0.""" # only make histogram summary op if it hasn't already been made if self.histogram_freq and self.merged is None: for layer in self.model.layers: @@ -793,8 +813,10 @@ class TensorBoard(Callback): def is_indexed_slices(grad): return type(grad).__name__ == 'IndexedSlices' - grads = [grad.values if is_indexed_slices(grad) else grad - for grad in grads] + grads = [ + grad.values if is_indexed_slices(grad) else grad + for grad in grads + ] tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if hasattr(layer, 'output'): @@ -803,12 +825,16 @@ class TensorBoard(Callback): tf_summary.histogram('{}_out_{}'.format(layer.name, i), output) else: tf_summary.histogram('{}_out'.format(layer.name), layer.output) - self.merged = tf_summary.merge_all() - if self.write_graph: - self.writer = self._writer_class(self.log_dir, self.sess.graph) - else: - self.writer = self._writer_class(self.log_dir) + def set_model(self, model): + """Sets Keras model and creates summary ops.""" + + self.model = model + self._init_writer() + # histogram summaries only enabled in graph mode + if not context.executing_eagerly(): + self._make_histogram_ops(model) + self.merged = tf_summary.merge_all() # If both embedding_freq and embeddings_data are available, we will # visualize embeddings. @@ -894,17 +920,24 @@ class TensorBoard(Callback): """ logs = logs or {} - for name, value in logs.items(): - summary = tf_summary.Summary() - summary_value = summary.value.add() - summary_value.simple_value = value.item() - summary_value.tag = name - self.writer.add_summary(summary, step) + if context.executing_eagerly(): + # use v2 summary ops + with self.writer.as_default(), summary_ops_v2.always_record_summaries(): + for name, value in logs.items(): + summary_ops_v2.scalar(name, value.item(), step=step) + else: + # use FileWriter from v1 summary + for name, value in logs.items(): + summary = tf_summary.Summary() + summary_value = summary.value.add() + summary_value.simple_value = value.item() + summary_value.tag = name + self.writer.add_summary(summary, step) self.writer.flush() def on_train_begin(self, logs=None): """Checks if histogram summaries can be run.""" - + # will never be set when in eager if self.histogram_freq: if 'validation_steps' in self.params: self._validation_batches = self.params['validation_steps'] |