aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--tensorflow/python/estimator/estimator.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 351fcb6423..2f1212d5a2 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -207,7 +207,8 @@ class Estimator(object):
else:
self._session_config = self._config.session_config
- self._device_fn = _get_replica_device_setter(self._config)
+ self._device_fn = self._config.device_fn or \
+ _get_replica_device_setter(self._config)
if model_fn is None:
raise ValueError('model_fn must be provided to Estimator.')
@@ -716,7 +717,7 @@ class Estimator(object):
batch_length = batch_length or value.shape[0]
if value.shape[0] != batch_length:
raise ValueError('Batch length of predictions should be same. %s has '
- 'different batch length then others.' % key)
+ 'different batch length than others.' % key)
return batch_length
def _extract_keys(self, predictions, predict_keys):