diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training_distributed.py')
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 57 |
1 files changed, 43 insertions, 14 deletions
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index e440e02bfb..939732cd67 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -70,7 +70,8 @@ def fit_loop( # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. if current_strategy.__class__.__name__ == 'TPUStrategy': return _experimental_fit_loop( - model, iterator, epochs, initial_epoch, steps_per_epoch) + model, iterator, epochs, verbose, callbacks, initial_epoch, + steps_per_epoch) clone_model_on_towers( model, current_strategy, make_callback_model=True) @@ -201,6 +202,8 @@ def _experimental_fit_loop( model, iterator, epochs=100, + verbose=1, + callbacks=None, initial_epoch=0, steps_per_epoch=None): """fit function when using TPU DistributionStrategy for training. @@ -209,6 +212,8 @@ def _experimental_fit_loop( model: Keras Model instance. iterator: Iterator that returns inputs and targets epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) @@ -225,7 +230,6 @@ def _experimental_fit_loop( # TODO(priyag): Add validation that shapes are fully defined for TPU case. - # TODO(priyag, sourabhbajaj): This should be moved into a callback instead. K.get_session().run(current_strategy.initialize()) def _per_device_train_function(model): @@ -298,19 +302,35 @@ def _experimental_fit_loop( assert steps_per_epoch is not None - # TODO(priyag, sourabhbajaj): Add callbacks support. + # TODO(sourabhbajaj): Convert this into a proper validation function + if callbacks: + raise NotImplementedError( + 'Callbacks are not supported with TPUStrategy right now.') + + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=False, + val_inputs=None, + val_targets=None, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=verbose) + # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback + # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run # TODO(priyag, sourabhbajaj): Add validation. + callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): - for step_index in range( - 0, steps_per_epoch, current_strategy.steps_per_run): + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run): + # TODO(sourabhbajaj): Add the size parameter in batch_logs once callbacks + # are fixed as we need to replace size with a combination of steps_per_run + # and batch_size + batch_logs = {'batch': step_index} + callbacks.on_batch_begin(step_index, batch_logs) try: - _, outs = K.get_session().run([train_op, output_tensors]) - # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper - # summaries through callbacks. - print('Epoch: {}, step_index: {}, loss: {}'.format( - epoch, step_index, outs['loss'])) - for label, out in outs.items(): - print(label, ': ', out) + _, outputs = K.get_session().run([train_op, output_tensors]) except errors.OutOfRangeError: logging.warning('Your dataset iterator ran out of data; ' 'interrupting training. Make sure that your dataset ' @@ -319,6 +339,16 @@ def _experimental_fit_loop( steps_per_epoch * epochs) break + batch_logs.update(outputs) + callbacks.on_batch_end(step_index, batch_logs) + if callbacks.model.stop_training: + break + + callbacks.on_epoch_end(epoch, epoch_logs) + if callbacks.model.stop_training: + break + callbacks.on_train_end() + # Copy the weights back from the replicated model to the original model. with current_strategy.scope(): updated_weights = current_strategy.unwrap( @@ -326,8 +356,7 @@ def _experimental_fit_loop( model.set_weights(updated_weights) K.get_session().run(current_strategy.finalize()) - - # TODO(priyag, sourabhbajaj): Return history. + return model.history def test_loop(model, iterator, verbose=0, steps=None): |