diff options
| | |
|---|---|
| author | 2018-08-06 16:07:37 -0700 |
| committer | 2018-08-06 16:25:08 -0700 |
| commit | b8b5866d82ce7adbb34acccb8e6392fb8a130886 (patch) |
| tree | 06efa0b1f570a8029f28266da1b6b3bf09a21e70 |
| parent | 9b4aa35e068974fbdc4822841ac04c0f44686610 (diff) |
Raise an error when using stateful metrics with DistributionStrategy.
PiperOrigin-RevId: 207625731
-rw-r--r-- | tensorflow/contrib/distribute/python/keras_test.py | 35 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 41 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 73 |
3 files changed, 94 insertions, 55 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index fbdb376fcc..600ccc1425 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -246,6 +246,32 @@ class TestWithDistributionStrategy(test.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) + def test_raise_error_for_stateful_metrics(self): + + class ExampleStatefulMetric(keras.layers.Layer): + + def __init__(self, name='true_positives', **kwargs): + super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) + self.stateful = True + + def __call__(self, y_true, y_pred): + return y_pred - y_true + + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', ExampleStatefulMetric()] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + with self.assertRaisesRegexp( + NotImplementedError, 'Stateful metrics are not supported with ' + 'DistributionStrategy.'): + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + def test_unsupported_features(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -268,8 +294,9 @@ class TestWithDistributionStrategy(test.TestCase): # Test with validation split with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not supported ' - 'when input `x` is a dataset or a dataset iterator'): + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_split=0.5, validation_steps=2) @@ -277,8 +304,8 @@ class TestWithDistributionStrategy(test.TestCase): # Test with sample 
weight. sample_weight = np.random.random((10,)) with self.assertRaisesRegexp( - ValueError, 'sample_weight is currently not supported when using ' - 'DistributionStrategy.'): + NotImplementedError, 'sample_weight is currently not supported when ' + 'using DistributionStrategy.'): model.fit( dataset, epochs=1, diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 3db0b4c8ad..2cdd00a48d 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -201,16 +201,17 @@ class Model(Network): # DistributionStrategy. if distribute and not isinstance( optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): - raise ValueError('Only TF native optimizers are supported with ' - 'DistributionStrategy.') + raise NotImplementedError('Only TF native optimizers are supported with ' + 'DistributionStrategy.') if distribute and context.executing_eagerly(): - raise ValueError('DistributionStrategy is not supported in Eager mode.') + raise NotImplementedError('DistributionStrategy is not supported in ' + 'Eager mode.') if distribute and sample_weight_mode: - raise ValueError('sample_weight_mode is not supported with ' - 'DistributionStrategy.') + raise NotImplementedError('sample_weight_mode is not supported with ' + 'DistributionStrategy.') if distribute and weighted_metrics: - raise ValueError('weighted_metrics is not supported with ' - 'DistributionStrategy.') + raise NotImplementedError('weighted_metrics is not supported with ' + 'DistributionStrategy.') if distribute and target_tensors: raise ValueError('target_tensors is not supported with ' 'DistributionStrategy.') @@ -245,6 +246,12 @@ class Model(Network): with self._distribution_strategy.scope(): first_replicated_model = self._distribution_strategy.unwrap( self._grouped_model)[0] + # If the specified metrics in `compile` are stateful, raise an error + # since we currently don't support stateful metrics. 
+ if first_replicated_model.stateful_metric_names: + raise NotImplementedError('Stateful metrics are not supported with ' + 'DistributionStrategy.') + # We initialize the callback model with the first replicated model. self._replicated_model = DistributedCallbackModel(first_replicated_model) self._replicated_model.set_original_model(self) @@ -665,11 +672,11 @@ class Model(Network): RuntimeError: If the model was never compiled. """ if sample_weight is not None and sample_weight.all(): - raise ValueError('sample_weight is currently not supported when using ' - 'DistributionStrategy.') + raise NotImplementedError('sample_weight is currently not supported when ' + 'using DistributionStrategy.') if class_weight: - raise ValueError('class_weight is currently not supported when using ' - 'DistributionStrategy.') + raise NotImplementedError('class_weight is currently not supported when ' + 'using DistributionStrategy.') # TODO(anjalisridhar): Can we use the iterator and getnext op cache? # We require users to pass Datasets since we distribute the dataset across @@ -1653,8 +1660,8 @@ class Model(Network): ValueError: In case of invalid user-provided arguments. """ if self._distribution_strategy: - raise ValueError('`train_on_batch` is not supported for models ' - 'compiled with DistributionStrategy.') + raise NotImplementedError('`train_on_batch` is not supported for models ' + 'compiled with DistributionStrategy.') # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight) @@ -1712,8 +1719,8 @@ class Model(Network): ValueError: In case of invalid user-provided arguments. """ if self._distribution_strategy: - raise ValueError('`test_on_batch` is not supported for models ' - 'compiled with DistributionStrategy.') + raise NotImplementedError('`test_on_batch` is not supported for models ' + 'compiled with DistributionStrategy.') # Validate and standardize user data. 
x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight) @@ -1752,8 +1759,8 @@ class Model(Network): expectations of the model. """ if self._distribution_strategy: - raise ValueError('`predict_on_batch` is not supported for models ' - 'compiled with DistributionStrategy.') + raise NotImplementedError('`predict_on_batch` is not supported for ' + 'models compiled with DistributionStrategy.') # Validate and standardize user data. inputs, _, _ = self._standardize_user_data(x) if context.executing_eagerly(): diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 75e466d593..5fa6c3c47d 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -183,9 +183,7 @@ def fit_loop( if steps_per_epoch is not None: epoch_logs = {} for step_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = step_index - batch_logs['size'] = 1 + batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: outs = distributed_train_function(ins) @@ -200,21 +198,8 @@ def fit_loop( if not isinstance(outs, list): outs = [outs] - # TODO(anjalisridhar): Temporary workaround for aggregating metrics - # across towers. Replace with the new metrics module eventually. - merged_output = [] - # The first output is the total loss. - merged_output.append(outs[0]) - current_index = 1 - num_devices = len(current_strategy._devices) - # Each label in `out_labels` corresponds to one set of metrics. The - # number of metric values corresponds to the number of devices. We - # currently take the mean of the values. 
- for _ in out_labels[1:]: - m = np.mean(outs[current_index:current_index + num_devices]) - merged_output.append(m) - current_index += num_devices - + outs = _aggregate_metrics_across_towers( + len(current_strategy._devices), out_labels, outs) for l, o in zip(out_labels, outs): batch_logs[l] = o callbacks.on_batch_end(step_index, batch_logs) @@ -302,16 +287,6 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): else: ins = dataset_inputs + dataset_targets - if hasattr(model, 'metrics'): - for m in model.stateful_metric_functions: - m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names - ] - else: - stateful_metric_indices = [] - outs = [] if verbose == 1: progbar = Progbar(target=steps) @@ -326,15 +301,14 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): if steps is not None: for step in range(steps): batch_outs = distributed_test_function(ins) + batch_outs = _aggregate_metrics_across_towers( + len(current_strategy._devices), model.metrics_names, batch_outs) if isinstance(batch_outs, list): if step == 0: for _ in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out + outs[i] += batch_out else: if step == 0: outs.append(0.) @@ -342,8 +316,8 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): if verbose == 1: progbar.update(step + 1) for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= steps + outs[i] /= steps + if len(outs) == 1: return outs[0] return outs @@ -453,3 +427,34 @@ def clone_and_build_model(model): sample_weight_mode=model.sample_weight_mode, weighted_metrics=model.weighted_metrics) return cloned_model + + +def _aggregate_metrics_across_towers(num_devices, out_labels, outs): + """Aggregate metrics values across all towers. 
+ + When using `MirroredStrategy`, the number of towers is equal to the + number of devices over which training is distributed. This may not always be + the case. + + Args: + num_devices: Number of devices over which the model is being distributed. + out_labels: The list of metric names passed to `compile`. + outs: The output from all the towers. + + Returns: + The average value of each metric across the towers. + """ + # TODO(anjalisridhar): Temporary workaround for aggregating metrics + # across towers. Replace with the new metrics module eventually. + merged_output = [] + # The first output is the total loss. + merged_output.append(outs[0]) + current_index = 1 + # Each label in `out_labels` corresponds to one set of metrics. The + # number of metric values corresponds to the number of devices. We + # currently take the mean of the values. + for _ in out_labels[1:]: + m = np.mean(outs[current_index:current_index + num_devices]) + merged_output.append(m) + current_index += num_devices + return merged_output |