author     Sourabh Bajaj <sourabhbajaj@google.com>  2018-09-20 09:39:49 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-20 09:44:01 -0700
commit     06ad4ad47bef99d4a8f6856bbb121387e8edcfa5 (patch)
tree       5931082a614006cc7f306d45af9b0b3ce362033d /tensorflow/python/keras
parent     32047f490d0892056ae4e0214d2f049887fdcf35 (diff)
Callbacks should count the steps correctly in the multi-step case
PiperOrigin-RevId: 213829360
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--  tensorflow/python/keras/callbacks.py                     | 16
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py  | 10
2 files changed, 12 insertions(+), 14 deletions(-)
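The effect of the change is easiest to see in terms of how epoch-level averages are computed from the `seen` counter. The following is a minimal, stand-alone sketch (not the real Keras BaseLogger; the class and weighting scheme here are simplified assumptions) showing why a callback invocation that covers several steps must scale its increment by `num_steps`:

# A minimal sketch, assuming a BaseLogger-like averager; not the actual
# Keras implementation.
class TinyBaseLogger(object):
  """Averages per-batch metrics, weighted by the samples each call covered."""

  def on_epoch_begin(self):
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, logs):
    batch_size = logs.get('size', 0)
    # The idea behind this commit: one callback call may stand for
    # `num_steps` actual training steps.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps
    for k, v in logs.items():
      if k in ('batch', 'size', 'num_steps'):
        continue
      self.totals[k] = self.totals.get(k, 0.) + v * batch_size * num_steps

  def on_epoch_end(self):
    return {k: total / self.seen for k, total in self.totals.items()}


logger = TinyBaseLogger()
logger.on_epoch_begin()
# One invocation that actually covered 4 steps of batch size 32.
logger.on_batch_end({'size': 32, 'num_steps': 4, 'loss': 0.5})
assert logger.seen == 128  # 32 * 4; without the scaling this would be 32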
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f4ec..6dfbbf3694 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@ class BaseLogger(Callback):
   def on_batch_end(self, batch, logs=None):
     logs = logs or {}
     batch_size = logs.get('size', 0)
-    self.seen += batch_size
+    # In case of distribution strategy we can potentially run multiple steps
+    # at the same time, we should account for that in the `seen` calculation.
+    num_steps = logs.get('num_steps', 1)
+    self.seen += batch_size * num_steps
 
     for k, v in logs.items():
       if k in self.stateful_metrics:
@@ -448,10 +451,13 @@ class ProgbarLogger(Callback):
   def on_batch_end(self, batch, logs=None):
     logs = logs or {}
     batch_size = logs.get('size', 0)
+    # In case of distribution strategy we can potentially run multiple steps
+    # at the same time, we should account for that in the `seen` calculation.
+    num_steps = logs.get('num_steps', 1)
     if self.use_steps:
-      self.seen += 1
+      self.seen += num_steps
     else:
-      self.seen += batch_size
+      self.seen += batch_size * num_steps
 
     for k in self.params['metrics']:
       if k in logs:
@@ -1068,7 +1074,7 @@ class TensorBoard(Callback):
     logs = logs or {}
     batch_logs = {('batch_' + k): v
                   for k, v in logs.items()
-                  if k not in ['batch', 'size']}
+                  if k not in ['batch', 'size', 'num_steps']}
     self._write_custom_summaries(self._total_batches_seen, batch_logs)
     self._total_batches_seen += 1
 
@@ -1092,7 +1098,7 @@ class TensorBoard(Callback):
     # batch number as Tensorboard summaries
     logs = {('epoch_' + k): v
             for k, v in logs.items()
-            if k not in ['batch', 'size']}
+            if k not in ['batch', 'size', 'num_steps']}
     self._write_custom_summaries(epoch, logs)
 
     # pop the histogram summary op after each epoch
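For ProgbarLogger the same idea applies to the counter that drives the progress bar: when progress is tracked in steps, each `on_batch_end` call may now advance the bar by several steps. A small self-contained sketch of that arithmetic (an assumed simplification, not the actual callback code):

# Assumed simplification of the counter update in the ProgbarLogger hunk
# above; `use_steps` mirrors the callback attribute of the same name.
def advance_seen(seen, logs, use_steps):
  """Return the updated `seen` counter for one on_batch_end invocation."""
  num_steps = logs.get('num_steps', 1)  # defaults to 1 for single-step runs
  if use_steps:
    return seen + num_steps             # count steps, not invocations
  return seen + logs.get('size', 0) * num_steps  # count samples


# A distributed run that executed 8 steps of batch size 64 in one invocation.
assert advance_seen(0, {'size': 64, 'num_steps': 8}, use_steps=True) == 8
assert advance_seen(0, {'size': 64, 'num_steps': 8}, use_steps=False) == 512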
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 05b40c66e3..26c5ec4efc 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -314,12 +314,6 @@ def _experimental_fit_loop(
   distributed_model = current_strategy.unwrap(model._grouped_model)[0]
   distributed_training_utils.set_weights(
       current_strategy, distributed_model, orig_model_weights)
-
-  # TODO(sourabhbajaj): Convert this into a proper validation function
-  if callbacks:
-    raise NotImplementedError(
-        'Callbacks are not supported with TPUStrategy right now.')
-
   callbacks = cbks.configure_callbacks(
       callbacks,
       model,
@@ -345,9 +339,7 @@ def _experimental_fit_loop(
     step_index = 0
     prev_step_count = None
     for step_count in steps_to_run:
-      # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
-      # and batch_size
-      batch_logs = {'batch': step_index, 'size': 1}
+      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
       callbacks.on_batch_begin(step_index, batch_logs)
       if prev_step_count is None or step_count != prev_step_count:
         steps_per_run_var.load(step_count, K.get_session())
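On the producer side, the TPU fit loop now advertises how many steps each callback invocation covers through the `num_steps` key of `batch_logs`. The following is an illustration-only sketch of that hand-off; the `_PrintingCallbacks` class and the hard-coded `steps_to_run` values are stand-ins, not the real objects in `_experimental_fit_loop`:

class _PrintingCallbacks(object):
  """Stand-in for the configured Keras callback list."""

  def on_batch_begin(self, batch, logs):
    pass

  def on_batch_end(self, batch, logs):
    print('batch %d covered %d step(s)' % (batch, logs['num_steps']))


callbacks = _PrintingCallbacks()
steps_to_run = [8, 8, 4]  # e.g. steps_per_run-sized chunks plus a remainder
step_index = 0
for step_count in steps_to_run:
  # 'num_steps' is what BaseLogger/ProgbarLogger/TensorBoard read above.
  batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
  callbacks.on_batch_begin(step_index, batch_logs)
  # ... the real loop runs `step_count` training steps here ...
  callbacks.on_batch_end(step_index, batch_logs)
  step_index += step_count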