diff options
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index befe82f4ec..6dfbbf3694 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -360,7 +360,10 @@ class BaseLogger(Callback): def on_batch_end(self, batch, logs=None): logs = logs or {} batch_size = logs.get('size', 0) - self.seen += batch_size + # In case of distribution strategy we can potentially run multiple steps + # at the same time, we should account for that in the `seen` calculation. + num_steps = logs.get('num_steps', 1) + self.seen += batch_size * num_steps for k, v in logs.items(): if k in self.stateful_metrics: @@ -448,10 +451,13 @@ class ProgbarLogger(Callback): def on_batch_end(self, batch, logs=None): logs = logs or {} batch_size = logs.get('size', 0) + # In case of distribution strategy we can potentially run multiple steps + # at the same time, we should account for that in the `seen` calculation. + num_steps = logs.get('num_steps', 1) if self.use_steps: - self.seen += 1 + self.seen += num_steps else: - self.seen += batch_size + self.seen += batch_size * num_steps for k in self.params['metrics']: if k in logs: @@ -1068,7 +1074,7 @@ class TensorBoard(Callback): logs = logs or {} batch_logs = {('batch_' + k): v for k, v in logs.items() - if k not in ['batch', 'size']} + if k not in ['batch', 'size', 'num_steps']} self._write_custom_summaries(self._total_batches_seen, batch_logs) self._total_batches_seen += 1 @@ -1092,7 +1098,7 @@ class TensorBoard(Callback): # batch number as Tensorboard summaries logs = {('epoch_' + k): v for k, v in logs.items() - if k not in ['batch', 'size']} + if k not in ['batch', 'size', 'num_steps']} self._write_custom_summaries(epoch, logs) # pop the histogram summary op after each epoch |