aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/callbacks.py')
-rw-r--r--tensorflow/python/keras/callbacks.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f4ec..6dfbbf3694 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@ class BaseLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
- self.seen += batch_size
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
+ self.seen += batch_size * num_steps
for k, v in logs.items():
if k in self.stateful_metrics:
@@ -448,10 +451,13 @@ class ProgbarLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
if self.use_steps:
- self.seen += 1
+ self.seen += num_steps
else:
- self.seen += batch_size
+ self.seen += batch_size * num_steps
for k in self.params['metrics']:
if k in logs:
@@ -1068,7 +1074,7 @@ class TensorBoard(Callback):
logs = logs or {}
batch_logs = {('batch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._total_batches_seen += 1
@@ -1092,7 +1098,7 @@ class TensorBoard(Callback):
# batch number as Tensorboard summaries
logs = {('epoch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch