aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/model_fn.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/model_fn.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index 8be9c72adf..44e6c7c52d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -23,7 +23,6 @@ import collections
import six
-from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import metric_key
@@ -32,6 +31,7 @@ from tensorflow.python.estimator import model_fn as core_model_fn_lib
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -156,11 +156,11 @@ class ModelFnOps(
else:
if isinstance(predictions, dict):
predictions = {
- k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
+ k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v)
for k, v in six.iteritems(predictions)
}
else:
- predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
+ predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor(
predictions)
# Validate eval_metric_ops