diff options
author | 2018-09-17 18:54:34 -0700 | |
---|---|---|
committer | 2018-09-17 18:59:06 -0700 | |
commit | 1bd2804869355a7cd0cbfbe9e6aab7591b8a20de (patch) | |
tree | 4f839e5c958d23905f3a9ac5cbea75a46dc09b3d /tensorflow/contrib/tpu | |
parent | f2a577888be8368121fe7ce16d4b72f91f53be60 (diff) |
Add Keras TPU support for the new metrics.
PiperOrigin-RevId: 213378552
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 776b9bff0f..bf445256b6 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -76,6 +76,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import models from tensorflow.python.keras import optimizers as keras_optimizers from tensorflow.python.keras.engine import base_layer @@ -293,6 +294,16 @@ def _replicated_optimizer(opt): return KerasCrossShardOptimizer(opt) +def clone_metrics(metrics): + """Returns a copy of metrics. A copy is created for stateful metrics.""" + if metrics is None: + return None + return [ + m.__class__.from_config(m.get_config()) + if isinstance(m, metrics_module.Metric) else m for m in metrics + ] + + class TPURewriteContext(object): """Prepare the environment for a Keras model during `tpu.rewrite`. @@ -811,8 +822,8 @@ class TPUFunction(object): optimizer=_replicated_optimizer(cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, - metrics=self.model.metrics, - weighted_metrics=self.model.weighted_metrics, + metrics=clone_metrics(self.model.metrics), + weighted_metrics=clone_metrics(self.model.weighted_metrics), target_tensors=tpu_targets, ) |