about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/keras/engine/training_distributed.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_distributed.py')
-rw-r--r-- tensorflow/python/keras/engine/training_distributed.py 57
1 files changed, 43 insertions, 14 deletions
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index e440e02bfb..939732cd67 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -70,7 +70,8 @@ def fit_loop(
# TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
if current_strategy.__class__.__name__ == 'TPUStrategy':
return _experimental_fit_loop(
- model, iterator, epochs, initial_epoch, steps_per_epoch)
+ model, iterator, epochs, verbose, callbacks, initial_epoch,
+ steps_per_epoch)
clone_model_on_towers(
model, current_strategy, make_callback_model=True)
@@ -201,6 +202,8 @@ def _experimental_fit_loop(
model,
iterator,
epochs=100,
+ verbose=1,
+ callbacks=None,
initial_epoch=0,
steps_per_epoch=None):
"""fit function when using TPU DistributionStrategy for training.
@@ -209,6 +212,8 @@ def _experimental_fit_loop(
model: Keras Model instance.
iterator: Iterator that returns inputs and targets
epochs: Number of times to iterate over the data
+ verbose: Verbosity mode, 0, 1 or 2
+ callbacks: List of callbacks to be called during training
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
@@ -225,7 +230,6 @@ def _experimental_fit_loop(
# TODO(priyag): Add validation that shapes are fully defined for TPU case.
- # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
@@ -298,19 +302,35 @@ def _experimental_fit_loop(
assert steps_per_epoch is not None
- # TODO(priyag, sourabhbajaj): Add callbacks support.
+ # TODO(sourabhbajaj): Convert this into a proper validation function
+ if callbacks:
+ raise NotImplementedError(
+ 'Callbacks are not supported with TPUStrategy right now.')
+
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=False,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
+ # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+ callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
- for step_index in range(
- 0, steps_per_epoch, current_strategy.steps_per_run):
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+ # TODO(sourabhbajaj): Add the size parameter in batch_logs once callbacks
+ # are fixed as we need to replace size with a combination of steps_per_run
+ # and batch_size
+ batch_logs = {'batch': step_index}
+ callbacks.on_batch_begin(step_index, batch_logs)
try:
- _, outs = K.get_session().run([train_op, output_tensors])
- # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
- # summaries through callbacks.
- print('Epoch: {}, step_index: {}, loss: {}'.format(
- epoch, step_index, outs['loss']))
- for label, out in outs.items():
- print(label, ': ', out)
+ _, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
@@ -319,6 +339,16 @@ def _experimental_fit_loop(
steps_per_epoch * epochs)
break
+ batch_logs.update(outputs)
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+
# Copy the weights back from the replicated model to the original model.
with current_strategy.scope():
updated_weights = current_strategy.unwrap(
@@ -326,8 +356,7 @@ def _experimental_fit_loop(
model.set_weights(updated_weights)
K.get_session().run(current_strategy.finalize())
-
- # TODO(priyag, sourabhbajaj): Return history.
+ return model.history
def test_loop(model, iterator, verbose=0, steps=None):