diff options
author | Pavithra Vijay <psv@google.com> | 2018-09-26 20:27:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 20:33:41 -0700 |
commit | de2bcdc7ad149419e270e1443b63581163d75d5d (patch) | |
tree | 203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/contrib/tpu | |
parent | 0d5c68e30f4637329fa233df506d7b97802a5e9b (diff) |
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator
PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 31 |
1 files changed, 8 insertions, 23 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 448676c95e..956d0142a3 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -325,18 +325,6 @@ def _replicated_optimizer(opt): return KerasCrossShardOptimizer(opt) -def _clone_metrics(metrics): - """Returns a copy of metrics. A copy is created for stateful metrics.""" - if metrics is None: - return None - with variable_scope.variable_scope( - 'metrics', reuse=variable_scope.AUTO_REUSE): - return [ - m.__class__.from_config(m.get_config()) if isinstance( - m, metrics_module.Metric) else m for m in metrics - ] - - def _clone_optimizer(optimizer, config=None): """Returns a cloned optimizer with the provided optimizer.config or config.""" if not isinstance(optimizer, keras_optimizers.Optimizer): @@ -963,8 +951,9 @@ class TPUFunction(object): optimizer=_replicated_optimizer(self._cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, - metrics=_clone_metrics(self.model.metrics), - weighted_metrics=_clone_metrics(self.model.weighted_metrics), + metrics=metrics_module.clone_metrics(self.model.metrics), + weighted_metrics=metrics_module.clone_metrics( + self.model.weighted_metrics), target_tensors=tpu_targets, ) @@ -1364,13 +1353,9 @@ class KerasTPUModel(models.Model): raise ValueError('target_tensors is not supported for TPU execution.') self._cpu_model.compile( - _clone_optimizer(optimizer), - loss, - _clone_metrics(metrics), - loss_weights, - sample_weight_mode, - _clone_metrics(weighted_metrics), - target_tensors, + _clone_optimizer(optimizer), loss, + metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode, + metrics_module.clone_metrics(weighted_metrics), target_tensors, **kwargs) super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights, @@ -2126,10 +2111,10 @@ def tpu_model(model, strategy=None): cpu_model.compile( _clone_optimizer(model.optimizer, optimizer_config), model.loss, - _clone_metrics(model.metrics), + metrics_module.clone_metrics(model.metrics), model.loss_weights, model.sample_weight_mode, - _clone_metrics(model.weighted_metrics), + metrics_module.clone_metrics(model.weighted_metrics), ) if model_weights: |