diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/model_fn.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/model_fn.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 8be9c72adf..44e6c7c52d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -23,7 +23,6 @@ import collections import six -from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib.framework import get_graph_from_inputs from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import metric_key @@ -32,6 +31,7 @@ from tensorflow.python.estimator import model_fn as core_model_fn_lib from tensorflow.python.estimator.export import export_output as core_export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -156,11 +156,11 @@ class ModelFnOps( else: if isinstance(predictions, dict): predictions = { - k: contrib_framework.convert_to_tensor_or_sparse_tensor(v) + k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v) for k, v in six.iteritems(predictions) } else: - predictions = contrib_framework.convert_to_tensor_or_sparse_tensor( + predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor( predictions) # Validate eval_metric_ops |