Diffstat (limited to 'tensorflow/python/keras/engine/training_distributed.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 73
1 file changed, 39 insertions, 34 deletions
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 75e466d593..5fa6c3c47d 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -183,9 +183,7 @@ def fit_loop(
     if steps_per_epoch is not None:
       epoch_logs = {}
       for step_index in range(steps_per_epoch):
-        batch_logs = {}
-        batch_logs['batch'] = step_index
-        batch_logs['size'] = 1
+        batch_logs = {'batch': step_index, 'size': 1}
         callbacks.on_batch_begin(step_index, batch_logs)
         try:
           outs = distributed_train_function(ins)
@@ -200,21 +198,8 @@
 
         if not isinstance(outs, list):
           outs = [outs]
-        # TODO(anjalisridhar): Temporary workaround for aggregating metrics
-        # across towers. Replace with the new metrics module eventually.
-        merged_output = []
-        # The first output is the total loss.
-        merged_output.append(outs[0])
-        current_index = 1
-        num_devices = len(current_strategy._devices)
-        # Each label in `out_labels` corresponds to one set of metrics. The
-        # number of metric values corresponds to the number of devices. We
-        # currently take the mean of the values.
-        for _ in out_labels[1:]:
-          m = np.mean(outs[current_index:current_index + num_devices])
-          merged_output.append(m)
-          current_index += num_devices
-
+        outs = _aggregate_metrics_across_towers(
+            len(current_strategy._devices), out_labels, outs)
         for l, o in zip(out_labels, outs):
           batch_logs[l] = o
         callbacks.on_batch_end(step_index, batch_logs)
@@ -302,16 +287,6 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
   else:
     ins = dataset_inputs + dataset_targets
 
-  if hasattr(model, 'metrics'):
-    for m in model.stateful_metric_functions:
-      m.reset_states()
-    stateful_metric_indices = [
-        i for i, name in enumerate(model.metrics_names)
-        if str(name) in model.stateful_metric_names
-    ]
-  else:
-    stateful_metric_indices = []
-
   outs = []
   if verbose == 1:
     progbar = Progbar(target=steps)
@@ -326,15 +301,14 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
   if steps is not None:
     for step in range(steps):
       batch_outs = distributed_test_function(ins)
+      batch_outs = _aggregate_metrics_across_towers(
+          len(current_strategy._devices), model.metrics_names, batch_outs)
       if isinstance(batch_outs, list):
         if step == 0:
           for _ in enumerate(batch_outs):
             outs.append(0.)
         for i, batch_out in enumerate(batch_outs):
-          if i in stateful_metric_indices:
-            outs[i] = batch_out
-          else:
-            outs[i] += batch_out
+          outs[i] += batch_out
       else:
         if step == 0:
           outs.append(0.)
@@ -342,8 +316,8 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
       if verbose == 1:
         progbar.update(step + 1)
     for i in range(len(outs)):
-      if i not in stateful_metric_indices:
-        outs[i] /= steps
+      outs[i] /= steps
+
   if len(outs) == 1:
     return outs[0]
   return outs
@@ -453,3 +427,34 @@ def clone_and_build_model(model):
       sample_weight_mode=model.sample_weight_mode,
       weighted_metrics=model.weighted_metrics)
   return cloned_model
+
+
+def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
+  """Aggregate metrics values across all towers.
+
+  When using `MirroredStrategy`, the number of towers is equal to the
+  number of devices over which training is distributed. This may not always be
+  the case.
+
+  Args:
+    num_devices: Number of devices over which the model is being distributed.
+    out_labels: The list of metric names passed to `compile`.
+    outs: The output from all the towers.
+
+  Returns:
+    The average value of each metric across the towers.
+  """
+  # TODO(anjalisridhar): Temporary workaround for aggregating metrics
+  # across towers. Replace with the new metrics module eventually.
+  merged_output = []
+  # The first output is the total loss.
+  merged_output.append(outs[0])
+  current_index = 1
+  # Each label in `out_labels` corresponds to one set of metrics. The
+  # number of metric values corresponds to the number of devices. We
+  # currently take the mean of the values.
+  for _ in out_labels[1:]:
+    m = np.mean(outs[current_index:current_index + num_devices])
+    merged_output.append(m)
+    current_index += num_devices
+  return merged_output