diff options
author | Pavithra Vijay <psv@google.com> | 2018-09-26 20:27:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 20:33:41 -0700 |
commit | de2bcdc7ad149419e270e1443b63581163d75d5d (patch) | |
tree | 203bbff75133c9c2ff26b2950a669365cba2ee56 /tensorflow/python/keras | |
parent | 0d5c68e30f4637329fa233df506d7b97802a5e9b (diff) |
Add Mirrored distribution strategy support for new metrics with Keras and Estimator
Add support for stateful metrics in model to estimator
PiperOrigin-RevId: 214714322
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 6 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 51 | ||||
-rw-r--r-- | tensorflow/python/keras/metrics.py | 16 | ||||
-rw-r--r-- | tensorflow/python/keras/models.py | 9 |
4 files changed, 60 insertions, 22 deletions
```diff
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ade8a4b32d..46bffd7068 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -647,12 +647,6 @@ class Model(Network):
           skip_target_indices=skip_target_indices,
           sample_weights=self.sample_weights)
 
-    # If using distribution strategy and stateful_metrics, raise an error
-    # since we currently don't support stateful metrics.
-    if self._distribution_strategy is not None and self.stateful_metric_names:
-      raise NotImplementedError('Stateful metrics are not supported with '
-                                'DistributionStrategy.')
-
     # Prepare gradient updates and state updates.
     self.total_loss = total_loss
 
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 8b434ca444..1b64f904d5 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variable_scope
@@ -153,6 +154,9 @@ def fit_loop(
   assert steps_per_epoch is not None
 
   for epoch in range(initial_epoch, epochs):
+    # Reset stateful metrics
+    for m in model.stateful_metric_functions:
+      m.reset_states()
     callbacks.on_epoch_begin(epoch)
     epoch_logs = {}
     for step_index in range(steps_per_epoch):
@@ -171,8 +175,9 @@ def fit_loop(
       if not isinstance(outs, list):
         outs = [outs]
 
-      outs = _aggregate_metrics_across_towers(
-          current_strategy.num_towers, out_labels, outs)
+      outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+                                              out_labels,
+                                              model.stateful_metric_names, outs)
       for l, o in zip(out_labels, outs):
         batch_logs[l] = o
       callbacks.on_batch_end(step_index, batch_logs)
@@ -437,6 +442,13 @@ def test_loop(model, iterator, verbose=0, steps=None):
   else:
     ins = dataset_inputs + dataset_targets
 
+  for m in model.stateful_metric_functions:
+    m.reset_states()
+  stateful_metric_indices = [
+      i for i, name in enumerate(model.metrics_names)
+      if str(name) in model.stateful_metric_names
+  ]
+
   outs = []
   if verbose == 1:
     progbar = Progbar(target=steps)
@@ -452,12 +464,16 @@ def test_loop(model, iterator, verbose=0, steps=None):
   for step in range(steps):
     batch_outs = distributed_test_function(ins)
     batch_outs = _aggregate_metrics_across_towers(
-        current_strategy.num_towers, model.metrics_names, batch_outs)
+        current_strategy.num_towers, model.metrics_names,
+        model.stateful_metric_names, batch_outs)
     if isinstance(batch_outs, list):
       if step == 0:
         outs = [0.] * len(batch_outs)
       for i, batch_out in enumerate(batch_outs):
-        outs[i] += batch_out
+        if i in stateful_metric_indices:
+          outs[i] = batch_out
+        else:
+          outs[i] += batch_out
     else:
       if step == 0:
         outs.append(0.)
@@ -465,7 +481,8 @@ def test_loop(model, iterator, verbose=0, steps=None):
     if verbose >= 1:
       progbar.update(step + 1)
   for i in range(len(outs)):
-    outs[i] /= steps
+    if i not in stateful_metric_indices:
+      outs[i] /= steps
 
   if len(outs) == 1:
     return outs[0]
@@ -816,10 +833,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
   cloned_model.compile(
       optimizer,
       model.loss,
-      metrics=model.metrics,
+      metrics=metrics_module.clone_metrics(model.metrics),
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics,
+      weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
       target_tensors=targets)
   return cloned_model
 
@@ -834,8 +851,9 @@ def clone_model_on_towers(
     model._make_callback_model()
 
 
-def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
-  """Aggregate metrics values across all towers.
+def _aggregate_metrics_across_towers(num_devices, out_labels,
+                                     stateful_metric_names, outs):
+  """Aggregates stateless metrics values across towers.
 
   When using `MirroredStrategy`, the number of towers is equal to the
   number of devices over which training is distributed. This may not always be
@@ -844,6 +862,7 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
   Args:
     num_devices: Number of devices over which the model is being distributed.
     out_labels: The list of metric names passed to `compile`.
+    stateful_metric_names: List of stateful metric names on the model.
     outs: The output from all the towers.
 
   Returns:
@@ -858,10 +877,16 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
   # Each label in `out_labels` corresponds to one set of metrics. The
   # number of metric values corresponds to the number of devices. We
   # currently take the mean of the values.
-  for _ in out_labels[1:]:
-    m = np.mean(outs[current_index:current_index + num_devices])
-    merged_output.append(m)
-    current_index += num_devices
+  for metric_name in out_labels[1:]:
+    if metric_name in stateful_metric_names:
+      # For stateful metrics, we get one aggregated result value.
+      merged_output.append(outs[current_index])
+      current_index += 1
+    else:
+      m = np.mean(outs[current_index:current_index + num_devices])
+      merged_output.append(m)
+      current_index += num_devices
+
   return merged_output
 
 
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e64241e5cf..f4e8419eb0 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
         name, x))
 
 
+def clone_metric(metric):
+  """Returns a clone of the metric if stateful, otherwise returns it as is."""
+  if isinstance(metric, Metric):
+    return metric.__class__.from_config(metric.get_config())
+  return metric
+
+
+def clone_metrics(metrics):
+  """Clones the given metric list/dict."""
+  if metrics is None:
+    return None
+  if isinstance(metrics, dict):
+    return {key: clone_metric(value) for key, value in metrics.items()}
+  return [clone_metric(metric) for metric in metrics]
+
+
 def update_state_wrapper(update_state_fn):
   """Decorator to wrap metric `update_state()` with `add_update()`.
 
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 41c5e3cccf..b04b4df257 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import saving
 from tensorflow.python.keras.engine import sequential
@@ -290,7 +291,9 @@ def _in_place_subclassed_model_reset(model):
     if isinstance(value, Layer):
       attributes_cache[name] = value
       assert value in model._layers
-    elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+    elif isinstance(
+        value, (list, tuple)) and name not in ('layers', '_layers',
+                                               'stateful_metric_functions'):
       # Handle case: list/tuple of layers (also tracked by the Network API).
       if value and all(isinstance(val, Layer) for val in value):
         raise ValueError('We do not support the use of list-of-layers '
@@ -466,10 +469,10 @@ def clone_and_build_model(
   clone.compile(
       optimizer,
       model.loss,
-      metrics=model.metrics,
+      metrics=metrics_module.clone_metrics(model.metrics),
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics,
+      weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
       target_tensors=target_tensors)
 
   return clone
```