Diffstat (limited to 'tensorflow/python/keras/engine/training_distributed.py')
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py | 73
 1 file changed, 39 insertions(+), 34 deletions(-)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 75e466d593..5fa6c3c47d 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -183,9 +183,7 @@ def fit_loop(
     if steps_per_epoch is not None:
       epoch_logs = {}
       for step_index in range(steps_per_epoch):
-        batch_logs = {}
-        batch_logs['batch'] = step_index
-        batch_logs['size'] = 1
+        batch_logs = {'batch': step_index, 'size': 1}
         callbacks.on_batch_begin(step_index, batch_logs)
         try:
           outs = distributed_train_function(ins)
@@ -200,21 +198,8 @@ def fit_loop(
 
         if not isinstance(outs, list):
           outs = [outs]
-        # TODO(anjalisridhar): Temporary workaround for aggregating metrics
-        # across towers. Replace with the new metrics module eventually.
-        merged_output = []
-        # The first output is the total loss.
-        merged_output.append(outs[0])
-        current_index = 1
-        num_devices = len(current_strategy._devices)
-        # Each label in `out_labels` corresponds to one set of metrics. The
-        # number of metric values corresponds to the number of devices. We
-        # currently take the mean of the values.
-        for _ in out_labels[1:]:
-          m = np.mean(outs[current_index:current_index + num_devices])
-          merged_output.append(m)
-          current_index += num_devices
-
+        outs = _aggregate_metrics_across_towers(
+            len(current_strategy._devices), out_labels, outs)
         for l, o in zip(out_labels, outs):
           batch_logs[l] = o
         callbacks.on_batch_end(step_index, batch_logs)
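To make the refactored call concrete, here is a minimal sketch (an editor's illustration with invented values, not code from this commit) of the flat per-tower output layout the aggregation consumes: one total loss followed by one value per device for each metric.

import numpy as np

# Assumed layout with 2 towers and out_labels == ['loss', 'acc', 'mae']:
# total loss, then per-device values for 'acc', then for 'mae'.
outs = [0.9, 0.80, 0.84, 0.10, 0.14]
num_devices = 2
out_labels = ['loss', 'acc', 'mae']

merged = [outs[0]]  # the first output is the total loss
index = 1
for _ in out_labels[1:]:
  merged.append(np.mean(outs[index:index + num_devices]))
  index += num_devices
print(merged)  # -> approximately [0.9, 0.82, 0.12]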
@@ -302,16 +287,6 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
   else:
     ins = dataset_inputs + dataset_targets
 
-  if hasattr(model, 'metrics'):
-    for m in model.stateful_metric_functions:
-      m.reset_states()
-    stateful_metric_indices = [
-        i for i, name in enumerate(model.metrics_names)
-        if str(name) in model.stateful_metric_names
-    ]
-  else:
-    stateful_metric_indices = []
-
   outs = []
   if verbose == 1:
     progbar = Progbar(target=steps)
@@ -326,15 +301,14 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
   if steps is not None:
     for step in range(steps):
       batch_outs = distributed_test_function(ins)
+      batch_outs = _aggregate_metrics_across_towers(
+          len(current_strategy._devices), model.metrics_names, batch_outs)
       if isinstance(batch_outs, list):
         if step == 0:
           for _ in enumerate(batch_outs):
             outs.append(0.)
         for i, batch_out in enumerate(batch_outs):
-          if i in stateful_metric_indices:
-            outs[i] = batch_out
-          else:
-            outs[i] += batch_out
+          outs[i] += batch_out
       else:
         if step == 0:
           outs.append(0.)
@@ -342,8 +316,8 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
       if verbose == 1:
         progbar.update(step + 1)
   for i in range(len(outs)):
-    if i not in stateful_metric_indices:
-      outs[i] /= steps
+    outs[i] /= steps
+
   if len(outs) == 1:
     return outs[0]
   return outs
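With the stateful-metric special case removed, the loop above reduces to a plain mean over steps; a minimal standalone sketch with invented numbers:

# Sum each output over all steps, then divide by the step count to get
# the per-output mean (mirrors the simplified test_loop accumulation).
steps = 4
per_step = [[0.5], [0.4], [0.3], [0.2]]  # one loss value per step

outs = [0. for _ in per_step[0]]
for step in range(steps):
  for i, value in enumerate(per_step[step]):
    outs[i] += value
for i in range(len(outs)):
  outs[i] /= steps
print(outs)  # -> [0.35], up to float rounding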
@@ -453,3 +427,34 @@ def clone_and_build_model(model):
       sample_weight_mode=model.sample_weight_mode,
       weighted_metrics=model.weighted_metrics)
   return cloned_model
+
+
+def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
+  """Aggregate metric values across all towers.
+
+  When using `MirroredStrategy`, the number of towers is equal to the
+  number of devices over which training is distributed. This may not always be
+  the case.
+
+  Args:
+    num_devices: Number of devices over which the model is being distributed.
+    out_labels: The list of metric names passed to `compile`.
+    outs: The output from all the towers.
+
+  Returns:
+    The average value of each metric across the towers.
+  """
+  # TODO(anjalisridhar): Temporary workaround for aggregating metrics
+  # across towers. Replace with the new metrics module eventually.
+  merged_output = []
+  # The first output is the total loss.
+  merged_output.append(outs[0])
+  current_index = 1
+  # Each label in `out_labels` corresponds to one set of metrics. The
+  # number of metric values corresponds to the number of devices. We
+  # currently take the mean of the values.
+  for _ in out_labels[1:]:
+    m = np.mean(outs[current_index:current_index + num_devices])
+    merged_output.append(m)
+    current_index += num_devices
+  return merged_output
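A hedged usage sketch of the new helper (invented values; assumes `outs` is the flat list of one total loss followed by `num_devices` values per metric, as the docstring describes):

import numpy as np  # the helper relies on np.mean

# 3 towers, out_labels == ['loss', 'acc']: one total loss, then one
# accuracy value per tower.
outs = [1.2, 0.70, 0.72, 0.74]
merged = _aggregate_metrics_across_towers(3, ['loss', 'acc'], outs)
print(merged)  # -> approximately [1.2, 0.72]: loss passed through, metrics averaged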