diff options
author | 2018-08-22 16:41:19 -0700 | |
---|---|---|
committer | 2018-08-22 16:53:32 -0700 | |
commit | 8101d3f9eeb923f9ea07cd495a748a34bdf1aee9 (patch) | |
tree | 05063a0b00724499fa327a39e4e01d205cb1ec94 | |
parent | 73c7768904554b5b2b6420556b52bfaf43453423 (diff) |
Turn off prefetching for `predict` calls since it is currently not
deterministic.
PiperOrigin-RevId: 209852547
-rw-r--r-- | tensorflow/contrib/distribute/python/keras_test.py | 3 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 13 |
2 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 287db6a88d..d39fd57294 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -565,8 +565,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) dataset_with = dataset_with.batch(32) strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', - '/device:GPU:0'], - prefetch_on_device=False) + '/device:GPU:0']) model.compile(loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 502635c408..85d25411b4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1727,6 +1727,13 @@ class Model(Network): if batch_size is None and steps is None: batch_size = 32 + # Turn off prefetching since this is currently not deterministic. Once + # b/112498930 is fixed we can turn it back on. + # `_prefetch_on_device` is currently a property of only `MirroredStrategy`. + if (self._distribution_strategy and + hasattr(self._distribution_strategy, '_prefetch_on_device')): + self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access + # Validate and standardize user data. x, _, _ = self._standardize_user_data( x, check_steps=True, steps_name='steps', steps=steps) @@ -1735,8 +1742,12 @@ class Model(Network): return training_eager.predict_loop( self, x, batch_size=batch_size, verbose=verbose, steps=steps) elif self._distribution_strategy: - return training_distributed.predict_loop( + results = training_distributed.predict_loop( self, x, verbose=verbose, steps=steps) + # Turn prefetching back on since we turned it off previously. + if hasattr(self._distribution_strategy, '_prefetch_on_device'): + self._distribution_strategy._prefetch_on_device = True # pylint: disable=protected-access + return results else: return training_arrays.predict_loop( self, x, batch_size=batch_size, verbose=verbose, steps=steps) |