author     Anjali Sridhar <anjalisridhar@google.com>      2018-08-06 16:07:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-06 16:25:08 -0700
commit     b8b5866d82ce7adbb34acccb8e6392fb8a130886 (patch)
tree       06efa0b1f570a8029f28266da1b6b3bf09a21e70
parent     9b4aa35e068974fbdc4822841ac04c0f44686610 (diff)
Raise an error when using stateful metrics with DistributionStrategy.
PiperOrigin-RevId: 207625731
-rw-r--r--  tensorflow/contrib/distribute/python/keras_test.py      35
-rw-r--r--  tensorflow/python/keras/engine/training.py              41
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py  73
3 files changed, 94 insertions, 55 deletions
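
For context, a minimal sketch of the user-facing behavior this commit introduces, assuming the contrib-era API shown in the diff below (`MirroredStrategy` plus the `distribute=` argument to `Model.compile`); the metric class here is a hypothetical stand-in for any metric layer that sets `stateful = True`:

    import tensorflow as tf

    class TruePositives(tf.keras.layers.Layer):
      """Toy stateful metric: any metric layer with `stateful = True`."""

      def __init__(self, name='true_positives', **kwargs):
        super(TruePositives, self).__init__(name=name, **kwargs)
        self.stateful = True

      def __call__(self, y_true, y_pred):
        return y_pred - y_true

    inputs = tf.keras.layers.Input(shape=(3,))
    outputs = tf.keras.layers.Dense(4)(inputs)
    model = tf.keras.Model(inputs, outputs)

    strategy = tf.contrib.distribute.MirroredStrategy(
        ['/device:GPU:0', '/device:GPU:1'])
    # After this change, compiling with a stateful metric raises
    # NotImplementedError('Stateful metrics are not supported with
    # DistributionStrategy.') instead of silently producing incorrect
    # per-tower aggregates.
    model.compile(optimizer=tf.train.GradientDescentOptimizer(0.001),
                  loss='mse',
                  metrics=['mae', TruePositives()],
                  distribute=strategy)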
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index fbdb376fcc..600ccc1425 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -246,6 +246,32 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
+ def test_raise_error_for_stateful_metrics(self):
+
+ class ExampleStatefulMetric(keras.layers.Layer):
+
+ def __init__(self, name='true_positives', **kwargs):
+ super(ExampleStatefulMetric, self).__init__(name=name, **kwargs)
+ self.stateful = True
+
+ def __call__(self, y_true, y_pred):
+ return y_pred - y_true
+
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae', ExampleStatefulMetric()]
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+ with self.assertRaisesRegexp(
+ NotImplementedError, 'Stateful metrics are not supported with '
+ 'DistributionStrategy.'):
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
def test_unsupported_features(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
@@ -268,8 +294,9 @@ class TestWithDistributionStrategy(test.TestCase):
# Test with validation split
with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
+ ValueError, '`validation_split` argument is not '
+ 'supported when input `x` is a dataset or a '
+ 'dataset iterator.+'):
model.fit(dataset,
epochs=1, steps_per_epoch=2, verbose=0,
validation_split=0.5, validation_steps=2)
@@ -277,8 +304,8 @@ class TestWithDistributionStrategy(test.TestCase):
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
- ValueError, 'sample_weight is currently not supported when using '
- 'DistributionStrategy.'):
+ NotImplementedError, 'sample_weight is currently not supported when '
+ 'using DistributionStrategy.'):
model.fit(
dataset,
epochs=1,
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 3db0b4c8ad..2cdd00a48d 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -201,16 +201,17 @@ class Model(Network):
# DistributionStrategy.
if distribute and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise ValueError('Only TF native optimizers are supported with '
- 'DistributionStrategy.')
+ raise NotImplementedError('Only TF native optimizers are supported with '
+ 'DistributionStrategy.')
if distribute and context.executing_eagerly():
- raise ValueError('DistributionStrategy is not supported in Eager mode.')
+ raise NotImplementedError('DistributionStrategy is not supported in '
+ 'Eager mode.')
if distribute and sample_weight_mode:
- raise ValueError('sample_weight_mode is not supported with '
- 'DistributionStrategy.')
+ raise NotImplementedError('sample_weight_mode is not supported with '
+ 'DistributionStrategy.')
if distribute and weighted_metrics:
- raise ValueError('weighted_metrics is not supported with '
- 'DistributionStrategy.')
+ raise NotImplementedError('weighted_metrics is not supported with '
+ 'DistributionStrategy.')
if distribute and target_tensors:
raise ValueError('target_tensors is not supported with '
'DistributionStrategy.')
@@ -245,6 +246,12 @@ class Model(Network):
with self._distribution_strategy.scope():
first_replicated_model = self._distribution_strategy.unwrap(
self._grouped_model)[0]
+ # If the specified metrics in `compile` are stateful, raise an error
+ # since we currently don't support stateful metrics.
+ if first_replicated_model.stateful_metric_names:
+ raise NotImplementedError('Stateful metrics are not supported with '
+ 'DistributionStrategy.')
+
# We initialize the callback model with the first replicated model.
self._replicated_model = DistributedCallbackModel(first_replicated_model)
self._replicated_model.set_original_model(self)
@@ -665,11 +672,11 @@ class Model(Network):
RuntimeError: If the model was never compiled.
"""
if sample_weight is not None and sample_weight.all():
- raise ValueError('sample_weight is currently not supported when using '
- 'DistributionStrategy.')
+ raise NotImplementedError('sample_weight is currently not supported when '
+ 'using DistributionStrategy.')
if class_weight:
- raise ValueError('class_weight is currently not supported when using '
- 'DistributionStrategy.')
+ raise NotImplementedError('class_weight is currently not supported when '
+ 'using DistributionStrategy.')
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
@@ -1653,8 +1660,8 @@ class Model(Network):
ValueError: In case of invalid user-provided arguments.
"""
if self._distribution_strategy:
- raise ValueError('`train_on_batch` is not supported for models '
- 'compiled with DistributionStrategy.')
+ raise NotImplementedError('`train_on_batch` is not supported for models '
+ 'compiled with DistributionStrategy.')
# Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight, class_weight=class_weight)
@@ -1712,8 +1719,8 @@ class Model(Network):
ValueError: In case of invalid user-provided arguments.
"""
if self._distribution_strategy:
- raise ValueError('`test_on_batch` is not supported for models '
- 'compiled with DistributionStrategy.')
+ raise NotImplementedError('`test_on_batch` is not supported for models '
+ 'compiled with DistributionStrategy.')
# Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight)
@@ -1752,8 +1759,8 @@ class Model(Network):
expectations of the model.
"""
if self._distribution_strategy:
- raise ValueError('`predict_on_batch` is not supported for models '
- 'compiled with DistributionStrategy.')
+ raise NotImplementedError('`predict_on_batch` is not supported for '
+ 'models compiled with DistributionStrategy.')
# Validate and standardize user data.
inputs, _, _ = self._standardize_user_data(x)
if context.executing_eagerly():
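
The same NotImplementedError policy now covers the batch-level entry points. A minimal sketch, continuing the compile example above with a model that compiled successfully (stateless metrics only); `x_batch` and `y_batch` are illustrative:

    import numpy as np

    x_batch = np.random.random((10, 3))
    y_batch = np.random.random((10, 4))
    try:
      model.train_on_batch(x_batch, y_batch)
    except NotImplementedError as e:
      # '`train_on_batch` is not supported for models compiled with
      # DistributionStrategy.'
      print(e)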
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 75e466d593..5fa6c3c47d 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -183,9 +183,7 @@ def fit_loop(
if steps_per_epoch is not None:
epoch_logs = {}
for step_index in range(steps_per_epoch):
- batch_logs = {}
- batch_logs['batch'] = step_index
- batch_logs['size'] = 1
+ batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
try:
outs = distributed_train_function(ins)
@@ -200,21 +198,8 @@ def fit_loop(
if not isinstance(outs, list):
outs = [outs]
- # TODO(anjalisridhar): Temporary workaround for aggregating metrics
- # across towers. Replace with the new metrics module eventually.
- merged_output = []
- # The first output is the total loss.
- merged_output.append(outs[0])
- current_index = 1
- num_devices = len(current_strategy._devices)
- # Each label in `out_labels` corresponds to one set of metrics. The
- # number of metric values corresponds to the number of devices. We
- # currently take the mean of the values.
- for _ in out_labels[1:]:
- m = np.mean(outs[current_index:current_index + num_devices])
- merged_output.append(m)
- current_index += num_devices
-
+ outs = _aggregate_metrics_across_towers(
+ len(current_strategy._devices), out_labels, outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(step_index, batch_logs)
@@ -302,16 +287,6 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
else:
ins = dataset_inputs + dataset_targets
- if hasattr(model, 'metrics'):
- for m in model.stateful_metric_functions:
- m.reset_states()
- stateful_metric_indices = [
- i for i, name in enumerate(model.metrics_names)
- if str(name) in model.stateful_metric_names
- ]
- else:
- stateful_metric_indices = []
-
outs = []
if verbose == 1:
progbar = Progbar(target=steps)
@@ -326,15 +301,14 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
if steps is not None:
for step in range(steps):
batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ len(current_strategy._devices), model.metrics_names, batch_outs)
if isinstance(batch_outs, list):
if step == 0:
for _ in enumerate(batch_outs):
outs.append(0.)
for i, batch_out in enumerate(batch_outs):
- if i in stateful_metric_indices:
- outs[i] = batch_out
- else:
- outs[i] += batch_out
+ outs[i] += batch_out
else:
if step == 0:
outs.append(0.)
@@ -342,8 +316,8 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
if verbose == 1:
progbar.update(step + 1)
for i in range(len(outs)):
- if i not in stateful_metric_indices:
- outs[i] /= steps
+ outs[i] /= steps
+
if len(outs) == 1:
return outs[0]
return outs
@@ -453,3 +427,34 @@ def clone_and_build_model(model):
sample_weight_mode=model.sample_weight_mode,
weighted_metrics=model.weighted_metrics)
return cloned_model
+
+
+def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
+ """Aggregate metrics values across all towers.
+
+ When using `MirroredStrategy`, the number of towers is equal to the
+ number of devices over which training is distributed. This may not always be
+ the case.
+
+ Args:
+ num_devices: Number of devices over which the model is being distributed.
+ out_labels: The list of metric names passed to `compile`.
+ outs: The output from all the towers.
+
+ Returns:
+ The average value of each metric across the towers.
+ """
+ # TODO(anjalisridhar): Temporary workaround for aggregating metrics
+ # across towers. Replace with the new metrics module eventually.
+ merged_output = []
+ # The first output is the total loss.
+ merged_output.append(outs[0])
+ current_index = 1
+ # Each label in `out_labels` corresponds to one set of metrics. The
+ # number of metric values corresponds to the number of devices. We
+ # currently take the mean of the values.
+ for _ in out_labels[1:]:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+ return merged_output