diff options
author | Pavithra Vijay <psv@google.com> | 2018-08-10 15:31:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-10 15:40:37 -0700 |
commit | 306dc604991c7f5fa45622a8c8236f59039d2c6a (patch) | |
tree | 759651696318dc34da847e2bd55401ab98fe3a9c /tensorflow/python/keras/callbacks.py | |
parent | 190015477a88457198548f829e9d814686cdf365 (diff) |
Add support for the new metrics in Keras.
- Add support for stateful metrics, weighted metrics, metric masking in eager mode.
- Updated masking logic for loss and metrics. (similar to weights in cl/207311700)
- Add weighted_metrics to save and load model.
- Add Categorical accuracy metric.
- Migrating #21071
PiperOrigin-RevId: 208278596
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 3e112b3132..f2feeb85a1 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -94,9 +94,14 @@ def configure_callbacks(callbacks, # Add additional callbacks model.history = History() - callbacks = [BaseLogger()] + (callbacks or []) + [model.history] + stateful_metric_names = None + if hasattr(model, 'stateful_metric_names'): + stateful_metric_names = model.stateful_metric_names + callbacks = [BaseLogger(stateful_metrics=stateful_metric_names) + ] + (callbacks or []) + [model.history] if verbose: - callbacks.append(ProgbarLogger(count_mode)) + callbacks.append( + ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)) callback_list = CallbackList(callbacks) # Set callback model @@ -110,6 +115,8 @@ def configure_callbacks(callbacks, # Set callback parameters callback_metrics = [] + # When we have deferred build scenario with iterator input, we will compile + # when we standardize first batch of data. if model._is_compiled: # pylint: disable=protected-access callback_metrics = copy.copy(model.metrics_names) if do_validation: |