aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-09-26 20:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 20:33:41 -0700
commitde2bcdc7ad149419e270e1443b63581163d75d5d (patch)
tree203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/python/keras
parent0d5c68e30f4637329fa233df506d7b97802a5e9b (diff)
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training.py6
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py51
-rw-r--r--tensorflow/python/keras/metrics.py16
-rw-r--r--tensorflow/python/keras/models.py9
4 files changed, 60 insertions, 22 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ade8a4b32d..46bffd7068 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -647,12 +647,6 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
- # If using distribution strategy and stateful_metrics, raise an error
- # since we currently don't support stateful metrics.
- if self._distribution_strategy is not None and self.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
# Prepare gradient updates and state updates.
self.total_loss = total_loss
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 8b434ca444..1b64f904d5 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
@@ -153,6 +154,9 @@ def fit_loop(
assert steps_per_epoch is not None
for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
for step_index in range(steps_per_epoch):
@@ -171,8 +175,9 @@ def fit_loop(
if not isinstance(outs, list):
outs = [outs]
- outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, out_labels, outs)
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names, outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(step_index, batch_logs)
@@ -437,6 +442,13 @@ def test_loop(model, iterator, verbose=0, steps=None):
else:
ins = dataset_inputs + dataset_targets
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
+
outs = []
if verbose == 1:
progbar = Progbar(target=steps)
@@ -452,12 +464,16 @@ def test_loop(model, iterator, verbose=0, steps=None):
for step in range(steps):
batch_outs = distributed_test_function(ins)
batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names, batch_outs)
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
if isinstance(batch_outs, list):
if step == 0:
outs = [0.] * len(batch_outs)
for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
else:
if step == 0:
outs.append(0.)
@@ -465,7 +481,8 @@ def test_loop(model, iterator, verbose=0, steps=None):
if verbose >= 1:
progbar.update(step + 1)
for i in range(len(outs)):
- outs[i] /= steps
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
if len(outs) == 1:
return outs[0]
@@ -816,10 +833,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
cloned_model.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=targets)
return cloned_model
@@ -834,8 +851,9 @@ def clone_model_on_towers(
model._make_callback_model()
-def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
- """Aggregate metrics values across all towers.
+def _aggregate_metrics_across_towers(num_devices, out_labels,
+ stateful_metric_names, outs):
+ """Aggregates stateless metrics values across towers.
When using `MirroredStrategy`, the number of towers is equal to the
number of devices over which training is distributed. This may not always be
@@ -844,6 +862,7 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
Args:
num_devices: Number of devices over which the model is being distributed.
out_labels: The list of metric names passed to `compile`.
+ stateful_metric_names: List of stateful metric names on the model.
outs: The output from all the towers.
Returns:
@@ -858,10 +877,16 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
# Each label in `out_labels` corresponds to one set of metrics. The
# number of metric values corresponds to the number of devices. We
# currently take the mean of the values.
- for _ in out_labels[1:]:
- m = np.mean(outs[current_index:current_index + num_devices])
- merged_output.append(m)
- current_index += num_devices
+ for metric_name in out_labels[1:]:
+ if metric_name in stateful_metric_names:
+ # For stateful metrics, we get one aggregated result value.
+ merged_output.append(outs[current_index])
+ current_index += 1
+ else:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+
return merged_output
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e64241e5cf..f4e8419eb0 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
name, x))
+def clone_metric(metric):
+ """Returns a clone of the metric if stateful, otherwise returns it as is."""
+ if isinstance(metric, Metric):
+ return metric.__class__.from_config(metric.get_config())
+ return metric
+
+
+def clone_metrics(metrics):
+ """Clones the given metric list/dict."""
+ if metrics is None:
+ return None
+ if isinstance(metrics, dict):
+ return {key: clone_metric(value) for key, value in metrics.items()}
+ return [clone_metric(metric) for metric in metrics]
+
+
def update_state_wrapper(update_state_fn):
"""Decorator to wrap metric `update_state()` with `add_update()`.
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 41c5e3cccf..b04b4df257 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
@@ -290,7 +291,9 @@ def _in_place_subclassed_model_reset(model):
if isinstance(value, Layer):
attributes_cache[name] = value
assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ elif isinstance(
+ value, (list, tuple)) and name not in ('layers', '_layers',
+ 'stateful_metric_functions'):
# Handle case: list/tuple of layers (also tracked by the Network API).
if value and all(isinstance(val, Layer) for val in value):
raise ValueError('We do not support the use of list-of-layers '
@@ -466,10 +469,10 @@ def clone_and_build_model(
clone.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=target_tensors)
return clone