aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor/predictor_factories.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/predictor/predictor_factories.py')
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
index 9485187c5d..04b5d5bdf1 100644
--- a/tensorflow/contrib/predictor/predictor_factories.py
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -21,6 +21,8 @@ from __future__ import print_function
from tensorflow.contrib.predictor import contrib_estimator_predictor
from tensorflow.contrib.predictor import core_estimator_predictor
from tensorflow.contrib.predictor import saved_model_predictor
+
+from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator
from tensorflow.python.estimator import estimator as core_estimator
@@ -85,7 +87,7 @@ def from_estimator(estimator,
TypeError: if `estimator` is a contrib `Estimator` instead of a core
`Estimator`.
"""
- if isinstance(estimator, estimator.Estimator):
+ if isinstance(estimator, contrib_estimator.Estimator):
raise TypeError('Espected estimator to be of type '
'tf.python.estimator.Estimator, but got type '
'tf.contrib.learn.Estimator. You likely want to call '