aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-08-22 16:41:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 16:53:32 -0700
commit8101d3f9eeb923f9ea07cd495a748a34bdf1aee9 (patch)
tree05063a0b00724499fa327a39e4e01d205cb1ec94
parent73c7768904554b5b2b6420556b52bfaf43453423 (diff)
Turn off prefetching for `predict` calls since it is currently not
deterministic. PiperOrigin-RevId: 209852547
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py3
-rw-r--r--tensorflow/python/keras/engine/training.py13
2 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 287db6a88d..d39fd57294 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -565,8 +565,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
dataset_with = dataset_with.batch(32)
strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'],
- prefetch_on_device=False)
+ '/device:GPU:0'])
model.compile(loss=keras.losses.mean_squared_error,
optimizer=gradient_descent.GradientDescentOptimizer(0.5),
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 502635c408..85d25411b4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1727,6 +1727,13 @@ class Model(Network):
if batch_size is None and steps is None:
batch_size = 32
+ # Turn off prefetching since this is currently not deterministic. Once
+ # b/112498930 is fixed we can turn it back on.
+ # `_prefetch_on_device` is currently a property of only `MirroredStrategy`.
+ if (self._distribution_strategy and
+ hasattr(self._distribution_strategy, '_prefetch_on_device')):
+ self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+
# Validate and standardize user data.
x, _, _ = self._standardize_user_data(
x, check_steps=True, steps_name='steps', steps=steps)
@@ -1735,8 +1742,12 @@ class Model(Network):
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
elif self._distribution_strategy:
- return training_distributed.predict_loop(
+ results = training_distributed.predict_loop(
self, x, verbose=verbose, steps=steps)
+ # Turn prefetching back on since we turned it off previously.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = True # pylint: disable=protected-access
+ return results
else:
return training_arrays.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)