path: root/tensorflow/python/keras
author    Anjali Sridhar <anjalisridhar@google.com>  2018-09-27 22:57:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-27 23:03:09 -0700
commit    6ebe9baae06c06d0a70a424a55c78f5af07b49f7 (patch)
tree      191f4e58bf50302b782f45e92b14d3369f6d42ea /tensorflow/python/keras
parent    d56c298f1ef14b5a738e1e0b7bbc66fcd736be3e (diff)
Fix error that occurs when attempting to use TensorFlow optimizers with Keras and DistributionStrategy
PiperOrigin-RevId: 214890580
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--  tensorflow/python/keras/engine/training.py              |   3
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py  | 341
2 files changed, 171 insertions, 173 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 46bffd7068..5091cac836 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -851,7 +851,8 @@ class Model(Network):
# able to clone a Dataset on multiple workers we can remove this lambda.
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
- K.get_session().run(iterator.initializer)
+ with self._distribution_strategy.scope():
+ K.get_session().run(iterator.initializer)
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
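
Note: the training.py hunk above is the core of the fix. The iterator
initializer is now run while the model's DistributionStrategy scope is
active, so initialization happens in the strategy-aware session. A minimal
sketch of the resulting pattern, assuming a TF 1.x-era strategy object with
the distribute_dataset API shown above (the helper name
initialize_distributed_iterator is illustrative, not from this commit):

    from tensorflow.python.keras import backend as K

    def initialize_distributed_iterator(strategy, make_dataset):
      # Distribute the dataset across the strategy's devices, then run the
      # iterator's initializer inside the strategy scope so K.get_session()
      # resolves to the session configured for the strategy.
      result = strategy.distribute_dataset(make_dataset)
      iterator = result.make_initializable_iterator()
      with strategy.scope():
        K.get_session().run(iterator.initializer)
      return iterator
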
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 1b64f904d5..a6470458d2 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -112,100 +112,99 @@ def fit_loop(
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- # Create a train function that is composed of all the parameters above.
- distributed_train_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_train_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [1]
- else:
- ins = dataset_inputs + dataset_targets
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
- do_validation = False
- if validation_steps:
- do_validation = True
+ do_validation = False
+ if validation_steps:
+ do_validation = True
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- callbacks = cbks.configure_callbacks(
- callbacks,
- model,
- do_validation=do_validation,
- val_inputs=None,
- val_targets=None,
- epochs=epochs,
- steps_per_epoch=steps_per_epoch,
- verbose=verbose)
- out_labels = model.metrics_names or []
- callbacks.on_train_begin()
-
- assert steps_per_epoch is not None
-
- for epoch in range(initial_epoch, epochs):
- # Reset stateful metrics
- for m in model.stateful_metric_functions:
- m.reset_states()
- callbacks.on_epoch_begin(epoch)
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
- out_labels,
- model.stateful_metric_names, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ out_labels = model.metrics_names or []
+ callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+ steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names,
+ outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ callbacks.on_train_end()
- callbacks.on_epoch_end(epoch, epoch_logs)
- if callbacks.model.stop_training:
- break
- callbacks.on_train_end()
-
- # Copy the weights back from the replicated model to the original model.
- with current_strategy.scope():
+ # Copy the weights back from the replicated model to the original model.
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
- return model.history
+ return model.history
def _experimental_fit_loop(
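
Note: the fit_loop hunk above is mostly an indentation change. Building
distributed_train_function with K.Function, the epoch/step loop, and the
final weight copy-back all move inside the with current_strategy.scope():
block that previously wrapped only the weight copies, presumably because
TensorFlow optimizers create ops and variables lazily when the train
function is built, and those must be created under the strategy scope. A
schematic of the restructured control flow, where build_train_function and
run_one_epoch are hypothetical stand-ins for the inlined K.Function
construction and the per-batch loop (not the commit's actual code):

    def fit_loop_outline(model, current_strategy, build_train_function,
                         run_one_epoch, epochs, initial_epoch=0):
      with current_strategy.scope():
        # After this commit, everything below runs inside one strategy
        # scope: weight sync, train-function construction, the epoch loop,
        # and the copy-back. Previously only the weight copies were scoped.
        orig_model_weights = model.get_weights()
        distributed_model = current_strategy.unwrap(model._grouped_model)[0]
        distributed_model.set_weights(orig_model_weights)   # simplified sync
        train_function = build_train_function(model)        # was outside scope
        for epoch in range(initial_epoch, epochs):
          run_one_epoch(train_function, epoch)              # was outside scope
        model.set_weights(distributed_model.get_weights())  # copy back, in scope
      return model.history
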
@@ -427,66 +426,65 @@ def test_loop(model, iterator, verbose=0, steps=None):
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- distributed_test_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_test_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [0]
- else:
- ins = dataset_inputs + dataset_targets
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
- for m in model.stateful_metric_functions:
- m.reset_states()
- stateful_metric_indices = [
- i for i, name in enumerate(model.metrics_names)
- if str(name) in model.stateful_metric_names
- ]
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
- outs = []
- if verbose == 1:
- progbar = Progbar(target=steps)
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps is not None
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names,
- model.stateful_metric_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- outs = [0.] * len(batch_outs)
- for i, batch_out in enumerate(batch_outs):
- if i in stateful_metric_indices:
- outs[i] = batch_out
- else:
- outs[i] += batch_out
- else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose >= 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- if i not in stateful_metric_indices:
- outs[i] /= steps
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
- if len(outs) == 1:
- return outs[0]
- return outs
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def _experimental_test_loop(model, iterator, verbose=0, steps=None):
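
Note: test_loop gets the same treatment, with distributed_test_function
construction and the evaluation loop moved inside the strategy scope. The
relocated loop body also shows how per-step outputs are reduced: stateless
metrics are summed across steps and averaged at the end, while stateful
metrics, which accumulate internally, keep only their latest value. A
standalone sketch of that reduction, with aggregate_eval_outs as an assumed
helper name (the commit inlines this logic in test_loop):

    def aggregate_eval_outs(batch_outs_per_step, stateful_metric_indices,
                            steps):
      # Accumulate one slot per model output/metric.
      outs = [0.] * len(batch_outs_per_step[0])
      for batch_outs in batch_outs_per_step:
        for i, batch_out in enumerate(batch_outs):
          if i in stateful_metric_indices:
            outs[i] = batch_out   # stateful metric: keep latest value
          else:
            outs[i] += batch_out  # stateless metric: sum per step
      for i in range(len(outs)):
        if i not in stateful_metric_indices:
          outs[i] /= steps        # average stateless metrics over steps
      return outs[0] if len(outs) == 1 else outs
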
@@ -647,51 +645,50 @@ def predict_loop(model, iterator, verbose=0, steps=None):
dataset_inputs = distributed_training_utils.flatten_perdevice_values(
current_strategy, inputs)
- distributed_predict_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_predict_function',
- **all_session_args)
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + [0]
- else:
- ins = dataset_inputs
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
- if verbose == 1:
- progbar = Progbar(target=steps)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- # Since we do not know how many samples we will see, we cannot pre-allocate
- # the returned Numpy arrays. Instead, we store one array per batch seen
- # and concatenate them upon returning.
- unconcatenated_outs = []
- for step in range(steps):
- batch_outs = distributed_predict_function(ins)
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if step == 0:
- for _ in batch_outs:
- unconcatenated_outs.append([])
- # TODO(anjalisridhar): Should combine the outputs from multiple towers
- # correctly here.
- for i, batch_out in enumerate(batch_outs):
- unconcatenated_outs[i].append(batch_out)
- if verbose >= 1:
- progbar.update(step + 1)
- if len(unconcatenated_outs) == 1:
- return np.concatenate(unconcatenated_outs[0], axis=0)
- return [
- np.concatenate(unconcatenated_outs[i], axis=0)
- for i in range(len(unconcatenated_outs))
- ]
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot
+ # pre-allocate the returned Numpy arrays. Instead, we store one array per
+ # batch seen and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose >= 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
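
Note: predict_loop follows the same pattern, moving function construction
and the prediction loop into the strategy scope. As the hunk's comment
says, the number of samples is not known up front, so the returned NumPy
arrays cannot be pre-allocated; instead one list of per-batch arrays is
kept per model output and concatenated on return. A minimal sketch of that
assembly step, with concat_batch_outs as an assumed helper name (the commit
inlines this in predict_loop):

    import numpy as np

    def concat_batch_outs(batch_outs_per_step):
      # One list per model output, each holding that output's per-batch
      # arrays across all prediction steps.
      unconcatenated = [[] for _ in batch_outs_per_step[0]]
      for batch_outs in batch_outs_per_step:
        for i, batch_out in enumerate(batch_outs):
          unconcatenated[i].append(batch_out)
      # A single-output model returns one array, not a one-element list.
      if len(unconcatenated) == 1:
        return np.concatenate(unconcatenated[0], axis=0)
      return [np.concatenate(outs, axis=0) for outs in unconcatenated]
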