diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-09-27 22:57:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 23:03:09 -0700 |
commit | 6ebe9baae06c06d0a70a424a55c78f5af07b49f7 (patch) | |
tree | 191f4e58bf50302b782f45e92b14d3369f6d42ea /tensorflow/python/keras | |
parent | d56c298f1ef14b5a738e1e0b7bbc66fcd736be3e (diff) |
Fix error that occurs when attempting to use TensorFlow optimizers with Keras and DistributionStrategy
PiperOrigin-RevId: 214890580
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 3 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 341 |
2 files changed, 171 insertions, 173 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 46bffd7068..5091cac836 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -851,7 +851,8 @@ class Model(Network): # able to clone a Dataset on multiple workers we can remove this lambda. result = self._distribution_strategy.distribute_dataset(lambda: x) iterator = result.make_initializable_iterator() - K.get_session().run(iterator.initializer) + with self._distribution_strategy.scope(): + K.get_session().run(iterator.initializer) training_utils.validate_iterator_input(x, y, sample_weight, validation_split) diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 1b64f904d5..a6470458d2 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -112,100 +112,99 @@ def fit_loop( dataset_targets = distributed_training_utils.flatten_perdevice_values( current_strategy, targets) - # Create a train function that is composed of all the parameters above. - distributed_train_function = K.Function( - all_inputs, all_outputs, - updates=all_updates, - name='distributed_train_function', - **all_session_args) - - # We need to set sample_weights to None since there are sample weight - # placeholders that are created with default values. - sample_weights = [None for _ in range(len(model.outputs) * - current_strategy.num_towers)] - if model.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = dataset_inputs + dataset_targets + sample_weights + [1] - else: - ins = dataset_inputs + dataset_targets + # Create a train function that is composed of all the parameters above. + distributed_train_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_train_function', + **all_session_args) + + # We need to set sample_weights to None since there are sample weight + # placeholders that are created with default values. + sample_weights = [None for _ in range(len(model.outputs) * + current_strategy.num_towers)] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + dataset_targets + sample_weights + [1] + else: + ins = dataset_inputs + dataset_targets - do_validation = False - if validation_steps: - do_validation = True + do_validation = False + if validation_steps: + do_validation = True - # Copy the weights from the original model to each of the replicated models. - orig_model_weights = model.get_weights() - with current_strategy.scope(): + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() distributed_model = current_strategy.unwrap(model._grouped_model)[0] distributed_training_utils.set_weights( current_strategy, distributed_model, orig_model_weights) - callbacks = cbks.configure_callbacks( - callbacks, - model, - do_validation=do_validation, - val_inputs=None, - val_targets=None, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - verbose=verbose) - out_labels = model.metrics_names or [] - callbacks.on_train_begin() - - assert steps_per_epoch is not None - - for epoch in range(initial_epoch, epochs): - # Reset stateful metrics - for m in model.stateful_metric_functions: - m.reset_states() - callbacks.on_epoch_begin(epoch) - epoch_logs = {} - for step_index in range(steps_per_epoch): - batch_logs = {'batch': step_index, 'size': 1} - callbacks.on_batch_begin(step_index, batch_logs) - try: - outs = distributed_train_function(ins) - except errors.OutOfRangeError: - logging.warning('Your dataset iterator ran out of data; ' - 'interrupting training. Make sure that your dataset ' - 'can generate at least `steps_per_epoch * epochs` ' - 'batches (in this case, %d batches).' % - steps_per_epoch * epochs) - break - - if not isinstance(outs, list): - outs = [outs] - - outs = _aggregate_metrics_across_towers(current_strategy.num_towers, - out_labels, - model.stateful_metric_names, outs) - for l, o in zip(out_labels, outs): - batch_logs[l] = o - callbacks.on_batch_end(step_index, batch_logs) + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + val_inputs=None, + val_targets=None, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=verbose) + out_labels = model.metrics_names or [] + callbacks.on_train_begin() + + assert steps_per_epoch is not None + + for epoch in range(initial_epoch, epochs): + # Reset stateful metrics + for m in model.stateful_metric_functions: + m.reset_states() + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + for step_index in range(steps_per_epoch): + batch_logs = {'batch': step_index, 'size': 1} + callbacks.on_batch_begin(step_index, batch_logs) + try: + outs = distributed_train_function(ins) + except errors.OutOfRangeError: + logging.warning('Your dataset iterator ran out of data; ' + 'interrupting training. Make sure that your dataset ' + 'can generate at least `steps_per_epoch * epochs` ' + 'batches (in this case, %d batches).' % + steps_per_epoch * epochs) + break + + if not isinstance(outs, list): + outs = [outs] + + outs = _aggregate_metrics_across_towers(current_strategy.num_towers, + out_labels, + model.stateful_metric_names, + outs) + for l, o in zip(out_labels, outs): + batch_logs[l] = o + callbacks.on_batch_end(step_index, batch_logs) + if callbacks.model.stop_training: + break + if do_validation: + val_outs = test_loop( + model, + val_iterator, + steps=validation_steps, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + + callbacks.on_epoch_end(epoch, epoch_logs) if callbacks.model.stop_training: break - if do_validation: - val_outs = test_loop( - model, - val_iterator, - steps=validation_steps, - verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o + callbacks.on_train_end() - callbacks.on_epoch_end(epoch, epoch_logs) - if callbacks.model.stop_training: - break - callbacks.on_train_end() - - # Copy the weights back from the replicated model to the original model. - with current_strategy.scope(): + # Copy the weights back from the replicated model to the original model. updated_weights = current_strategy.unwrap( model._grouped_model)[0].get_weights() model.set_weights(updated_weights) - return model.history + return model.history def _experimental_fit_loop( @@ -427,66 +426,65 @@ def test_loop(model, iterator, verbose=0, steps=None): dataset_targets = distributed_training_utils.flatten_perdevice_values( current_strategy, targets) - distributed_test_function = K.Function( - all_inputs, all_outputs, - updates=all_updates, - name='distributed_test_function', - **all_session_args) - - # We need to set sample_weights to None since there are sample weight - # placeholders that are created with default values. - sample_weights = [None for _ in range(len(model.outputs) * - current_strategy.num_towers)] - if model.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = dataset_inputs + dataset_targets + sample_weights + [0] - else: - ins = dataset_inputs + dataset_targets + distributed_test_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_test_function', + **all_session_args) - for m in model.stateful_metric_functions: - m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names - ] + # We need to set sample_weights to None since there are sample weight + # placeholders that are created with default values. + sample_weights = [None for _ in range(len(model.outputs) * + current_strategy.num_towers)] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + dataset_targets + sample_weights + [0] + else: + ins = dataset_inputs + dataset_targets - outs = [] - if verbose == 1: - progbar = Progbar(target=steps) + for m in model.stateful_metric_functions: + m.reset_states() + stateful_metric_indices = [ + i for i, name in enumerate(model.metrics_names) + if str(name) in model.stateful_metric_names + ] - # Copy the weights from the original model to each of the replicated models. - orig_model_weights = model.get_weights() - with current_strategy.scope(): + outs = [] + if verbose == 1: + progbar = Progbar(target=steps) + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() distributed_model = current_strategy.unwrap(model._grouped_model)[0] distributed_training_utils.set_weights( current_strategy, distributed_model, orig_model_weights) - assert steps is not None - for step in range(steps): - batch_outs = distributed_test_function(ins) - batch_outs = _aggregate_metrics_across_towers( - current_strategy.num_towers, model.metrics_names, - model.stateful_metric_names, batch_outs) - if isinstance(batch_outs, list): - if step == 0: - outs = [0.] * len(batch_outs) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out - else: - if step == 0: - outs.append(0.) - outs[0] += batch_outs - if verbose >= 1: - progbar.update(step + 1) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= steps + assert steps is not None + for step in range(steps): + batch_outs = distributed_test_function(ins) + batch_outs = _aggregate_metrics_across_towers( + current_strategy.num_towers, model.metrics_names, + model.stateful_metric_names, batch_outs) + if isinstance(batch_outs, list): + if step == 0: + outs = [0.] * len(batch_outs) + for i, batch_out in enumerate(batch_outs): + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out + else: + if step == 0: + outs.append(0.) + outs[0] += batch_outs + if verbose >= 1: + progbar.update(step + 1) + for i in range(len(outs)): + if i not in stateful_metric_indices: + outs[i] /= steps - if len(outs) == 1: - return outs[0] - return outs + if len(outs) == 1: + return outs[0] + return outs def _experimental_test_loop(model, iterator, verbose=0, steps=None): @@ -647,51 +645,50 @@ def predict_loop(model, iterator, verbose=0, steps=None): dataset_inputs = distributed_training_utils.flatten_perdevice_values( current_strategy, inputs) - distributed_predict_function = K.Function( - all_inputs, all_outputs, - updates=all_updates, - name='distributed_predict_function', - **all_session_args) + distributed_predict_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_predict_function', + **all_session_args) - if model.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = dataset_inputs + [0] - else: - ins = dataset_inputs + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + [0] + else: + ins = dataset_inputs - if verbose == 1: - progbar = Progbar(target=steps) + if verbose == 1: + progbar = Progbar(target=steps) - # Copy the weights from the original model to each of the replicated models. - orig_model_weights = model.get_weights() - with current_strategy.scope(): + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() distributed_model = current_strategy.unwrap(model._grouped_model)[0] distributed_training_utils.set_weights( current_strategy, distributed_model, orig_model_weights) - if steps is not None: - # Since we do not know how many samples we will see, we cannot pre-allocate - # the returned Numpy arrays. Instead, we store one array per batch seen - # and concatenate them upon returning. - unconcatenated_outs = [] - for step in range(steps): - batch_outs = distributed_predict_function(ins) - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if step == 0: - for _ in batch_outs: - unconcatenated_outs.append([]) - # TODO(anjalisridhar): Should combine the outputs from multiple towers - # correctly here. - for i, batch_out in enumerate(batch_outs): - unconcatenated_outs[i].append(batch_out) - if verbose >= 1: - progbar.update(step + 1) - if len(unconcatenated_outs) == 1: - return np.concatenate(unconcatenated_outs[0], axis=0) - return [ - np.concatenate(unconcatenated_outs[i], axis=0) - for i in range(len(unconcatenated_outs)) - ] + if steps is not None: + # Since we do not know how many samples we will see, we cannot + # pre-allocate the returned Numpy arrays. Instead, we store one array per + # batch seen and concatenate them upon returning. + unconcatenated_outs = [] + for step in range(steps): + batch_outs = distributed_predict_function(ins) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if step == 0: + for _ in batch_outs: + unconcatenated_outs.append([]) + # TODO(anjalisridhar): Should combine the outputs from multiple towers + # correctly here. + for i, batch_out in enumerate(batch_outs): + unconcatenated_outs[i].append(batch_out) + if verbose >= 1: + progbar.update(step + 1) + if len(unconcatenated_outs) == 1: + return np.concatenate(unconcatenated_outs[0], axis=0) + return [ + np.concatenate(unconcatenated_outs[i], axis=0) + for i in range(len(unconcatenated_outs)) + ] def _experimental_predict_loop(model, iterator, verbose=0, steps=None): |