diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 351fcb6423..2f1212d5a2 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -207,7 +207,8 @@ class Estimator(object): else: self._session_config = self._config.session_config - self._device_fn = _get_replica_device_setter(self._config) + self._device_fn = self._config.device_fn or \ + _get_replica_device_setter(self._config) if model_fn is None: raise ValueError('model_fn must be provided to Estimator.') @@ -716,7 +717,7 @@ class Estimator(object): batch_length = batch_length or value.shape[0] if value.shape[0] != batch_length: raise ValueError('Batch length of predictions should be same. %s has ' - 'different batch length then others.' % key) + 'different batch length than others.' % key) return batch_length def _extract_keys(self, predictions, predict_keys): |