From 237c6ccae40005e3b6199731c45e1c9f5cd86c5f Mon Sep 17 00:00:00 2001
From: Sourabh Bajaj
Date: Wed, 19 Sep 2018 15:21:21 -0700
Subject: Create a steps_per_run variable to be updated correctly in the fit
 loop to make sure we run fit for the right number of steps.

PiperOrigin-RevId: 213706042
---
 .../python/keras/engine/training_distributed.py   | 29 ++++++++++++++++------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index d133595793..05b40c66e3 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -293,11 +293,16 @@ def _experimental_fit_loop(
   for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
     initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
 
+  if steps_per_epoch is None:
+    raise ValueError('steps_per_epoch should be specified in the fit call.')
+  steps_per_run_var = K.variable(
+      value=min(steps_per_epoch, current_strategy.steps_per_run),
+      dtype='int32',
+      name='steps_per_run_var')
+
   with current_strategy.scope():
-    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
-    # steps_per_epoch and number of epochs.
     ctx = current_strategy.run_steps_on_dataset(
-        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        step_fn, iterator, iterations=steps_per_run_var,
         initial_loop_values=initial_loop_values)
 
     train_op = ctx.run_op
@@ -310,8 +315,6 @@
     distributed_training_utils.set_weights(
         current_strategy, distributed_model, orig_model_weights)
 
-  assert steps_per_epoch is not None
-
   # TODO(sourabhbajaj): Convert this into a proper validation function
   if callbacks:
     raise NotImplementedError(
@@ -327,17 +330,28 @@
       steps_per_epoch=steps_per_epoch,
       verbose=verbose)
   # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
-  # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
   # TODO(priyag, sourabhbajaj): Add validation.
+
+  # Calculate the steps each time on the device.
+  steps_to_run = [current_strategy.steps_per_run] * (
+      steps_per_epoch // current_strategy.steps_per_run)
+  if steps_per_epoch % current_strategy.steps_per_run:
+    steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
   callbacks.on_train_begin()
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
     epoch_logs = {}
-    for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+    step_index = 0
+    prev_step_count = None
+    for step_count in steps_to_run:
       # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
       # and batch_size
       batch_logs = {'batch': step_index, 'size': 1}
       callbacks.on_batch_begin(step_index, batch_logs)
+      if prev_step_count is None or step_count != prev_step_count:
+        steps_per_run_var.load(step_count, K.get_session())
+        prev_step_count = step_count
       try:
         _, outputs = K.get_session().run([train_op, output_tensors])
       except errors.OutOfRangeError:
@@ -350,6 +364,7 @@
 
       batch_logs.update(outputs)
       callbacks.on_batch_end(step_index, batch_logs)
+      step_index = step_index + step_count
       if callbacks.model.stop_training:
         break
 
-- 
cgit v1.2.3
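
The snippet below is a minimal, standalone sketch of the step-chunking scheme this patch introduces, shown outside the Keras fit loop for clarity: an epoch of steps_per_epoch steps is split into chunks of at most current_strategy.steps_per_run steps, with a shorter final chunk when the division is not exact, and the device-side loop length (steps_per_run_var in the patch) is reloaded only when the chunk size changes. The helper names compute_steps_to_run, run_epoch and run_chunk are illustrative only and are not TensorFlow APIs.

# Standalone sketch of the patch's chunking logic; helper names are illustrative.

def compute_steps_to_run(steps_per_epoch, steps_per_run):
  """Split an epoch into chunks of at most `steps_per_run` steps."""
  steps_to_run = [steps_per_run] * (steps_per_epoch // steps_per_run)
  if steps_per_epoch % steps_per_run:
    steps_to_run.append(steps_per_epoch % steps_per_run)
  return steps_to_run


def run_epoch(steps_per_epoch, steps_per_run, run_chunk):
  """Drive one epoch; `run_chunk(n)` stands in for the fused n-step train op."""
  step_index = 0
  prev_step_count = None
  for step_count in compute_steps_to_run(steps_per_epoch, steps_per_run):
    if prev_step_count is None or step_count != prev_step_count:
      # In the patch this is steps_per_run_var.load(step_count, K.get_session()):
      # the device-side iteration count changes only for the last, shorter chunk.
      prev_step_count = step_count
    run_chunk(step_count)
    step_index += step_count
  return step_index


if __name__ == '__main__':
  # steps_per_epoch=10, steps_per_run=3 -> chunks [3, 3, 3, 1]; 10 steps in total.
  assert compute_steps_to_run(10, 3) == [3, 3, 3, 1]
  assert run_epoch(10, 3, run_chunk=lambda n: None) == 10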