about · summary · refs · log · tree · commit · diff · homepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
author: Sourabh Bajaj <sourabhbajaj@google.com> 2018-09-19 15:21:21 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-19 15:25:00 -0700
commit 237c6ccae40005e3b6199731c45e1c9f5cd86c5f (patch)
tree b4ee30fcae88ce63306191dbf0595e56710d511a /tensorflow/python/keras
parent c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 (diff)
Create a steps_per_run variable to be updated correctly in the fit loop to make sure we run fit for the right number of steps.
PiperOrigin-RevId: 213706042
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py29
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index d133595793..05b40c66e3 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -293,11 +293,16 @@ def _experimental_fit_loop(
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ if steps_per_epoch is None:
+ raise ValueError('steps_per_epoch should be specified in the fit call.')
+ steps_per_run_var = K.variable(
+ value=min(steps_per_epoch, current_strategy.steps_per_run),
+ dtype='int32',
+ name='steps_per_run_var')
+
with current_strategy.scope():
- # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
- # steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=current_strategy.steps_per_run,
+ step_fn, iterator, iterations=steps_per_run_var,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -310,8 +315,6 @@ def _experimental_fit_loop(
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps_per_epoch is not None
-
# TODO(sourabhbajaj): Convert this into a proper validation function
if callbacks:
raise NotImplementedError(
@@ -327,17 +330,28 @@ def _experimental_fit_loop(
steps_per_epoch=steps_per_epoch,
verbose=verbose)
# TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
- # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+
+ # Calculate the steps each time on the device.
+ steps_to_run = [current_strategy.steps_per_run] * (
+ steps_per_epoch // current_strategy.steps_per_run)
+ if steps_per_epoch % current_strategy.steps_per_run:
+ steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+ step_index = 0
+ prev_step_count = None
+ for step_count in steps_to_run:
# TODO(sourabhbajaj): Replace size with a combination of steps_per_run
# and batch_size
batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
+ if prev_step_count is None or step_count != prev_step_count:
+ steps_per_run_var.load(step_count, K.get_session())
+ prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
@@ -350,6 +364,7 @@ def _experimental_fit_loop(
batch_logs.update(outputs)
callbacks.on_batch_end(step_index, batch_logs)
+ step_index = step_index + step_count
if callbacks.model.stop_training:
break