aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-08-10 15:31:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 15:40:37 -0700
commit306dc604991c7f5fa45622a8c8236f59039d2c6a (patch)
tree759651696318dc34da847e2bd55401ab98fe3a9c /tensorflow/python/keras/callbacks.py
parent190015477a88457198548f829e9d814686cdf365 (diff)
Add support for the new metrics in Keras.
- Add support for stateful metrics, weighted metrics, metric masking in eager mode. - Updated masking logic for loss and metrics. (similar to weights in cl/207311700) - Add weighted_metrics to save and load model. - Add Categorical accuracy metric. - Migrating #21071 PiperOrigin-RevId: 208278596
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 3e112b3132..f2feeb85a1 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -94,9 +94,14 @@ def configure_callbacks(callbacks,
# Add additional callbacks
model.history = History()
- callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
+ stateful_metric_names = None
+ if hasattr(model, 'stateful_metric_names'):
+ stateful_metric_names = model.stateful_metric_names
+ callbacks = [BaseLogger(stateful_metrics=stateful_metric_names)
+ ] + (callbacks or []) + [model.history]
if verbose:
- callbacks.append(ProgbarLogger(count_mode))
+ callbacks.append(
+ ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names))
callback_list = CallbackList(callbacks)
# Set callback model
@@ -110,6 +115,8 @@ def configure_callbacks(callbacks,
# Set callback parameters
callback_metrics = []
+ # When we have deferred build scenario with iterator input, we will compile
+ # when we standardize first batch of data.
if model._is_compiled: # pylint: disable=protected-access
callback_metrics = copy.copy(model.metrics_names)
if do_validation: