aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-09-17 18:54:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 18:59:06 -0700
commit1bd2804869355a7cd0cbfbe9e6aab7591b8a20de (patch)
tree4f839e5c958d23905f3a9ac5cbea75a46dc09b3d /tensorflow/contrib/tpu
parentf2a577888be8368121fe7ce16d4b72f91f53be60 (diff)
Add Keras TPU support for the new metrics.
PiperOrigin-RevId: 213378552
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py15
1 files changed, 13 insertions, 2 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 776b9bff0f..bf445256b6 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -76,6 +76,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
@@ -293,6 +294,16 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
+def clone_metrics(metrics):
+ """Returns a copy of metrics. A copy is created for stateful metrics."""
+ if metrics is None:
+ return None
+ return [
+ m.__class__.from_config(m.get_config())
+ if isinstance(m, metrics_module.Metric) else m for m in metrics
+ ]
+
+
class TPURewriteContext(object):
"""Prepare the environment for a Keras model during `tpu.rewrite`.
@@ -811,8 +822,8 @@ class TPUFunction(object):
optimizer=_replicated_optimizer(cloned_optimizer),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
- metrics=self.model.metrics,
- weighted_metrics=self.model.weighted_metrics,
+ metrics=clone_metrics(self.model.metrics),
+ weighted_metrics=clone_metrics(self.model.weighted_metrics),
target_tensors=tpu_targets,
)