aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-09-26 20:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 20:33:41 -0700
commitde2bcdc7ad149419e270e1443b63581163d75d5d (patch)
tree203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/contrib/tpu
parent0d5c68e30f4637329fa233df506d7b97802a5e9b (diff)
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py31
1 files changed, 8 insertions, 23 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 448676c95e..956d0142a3 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -325,18 +325,6 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
-def _clone_metrics(metrics):
- """Returns a copy of metrics. A copy is created for stateful metrics."""
- if metrics is None:
- return None
- with variable_scope.variable_scope(
- 'metrics', reuse=variable_scope.AUTO_REUSE):
- return [
- m.__class__.from_config(m.get_config()) if isinstance(
- m, metrics_module.Metric) else m for m in metrics
- ]
-
-
def _clone_optimizer(optimizer, config=None):
"""Returns a cloned optimizer with the provided optimizer.config or config."""
if not isinstance(optimizer, keras_optimizers.Optimizer):
@@ -963,8 +951,9 @@ class TPUFunction(object):
optimizer=_replicated_optimizer(self._cloned_optimizer),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
- metrics=_clone_metrics(self.model.metrics),
- weighted_metrics=_clone_metrics(self.model.weighted_metrics),
+ metrics=metrics_module.clone_metrics(self.model.metrics),
+ weighted_metrics=metrics_module.clone_metrics(
+ self.model.weighted_metrics),
target_tensors=tpu_targets,
)
@@ -1364,13 +1353,9 @@ class KerasTPUModel(models.Model):
raise ValueError('target_tensors is not supported for TPU execution.')
self._cpu_model.compile(
- _clone_optimizer(optimizer),
- loss,
- _clone_metrics(metrics),
- loss_weights,
- sample_weight_mode,
- _clone_metrics(weighted_metrics),
- target_tensors,
+ _clone_optimizer(optimizer), loss,
+ metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode,
+ metrics_module.clone_metrics(weighted_metrics), target_tensors,
**kwargs)
super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
@@ -2126,10 +2111,10 @@ def tpu_model(model, strategy=None):
cpu_model.compile(
_clone_optimizer(model.optimizer, optimizer_config),
model.loss,
- _clone_metrics(model.metrics),
+ metrics_module.clone_metrics(model.metrics),
model.loss_weights,
model.sample_weight_mode,
- _clone_metrics(model.weighted_metrics),
+ metrics_module.clone_metrics(model.weighted_metrics),
)
if model_weights: