aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimator.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py70
1 files changed, 47 insertions, 23 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 91d900395b..2ec5a0659a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -330,8 +330,8 @@ class BaseEstimator(
# Features and labels TensorSignature objects.
# TODO(wicke): Rename these to something more descriptive
- self._features_info = None
- self._labels_info = None
+ self._features_info = {}
+ self._labels_info = {}
self._graph = None
@@ -641,28 +641,29 @@ class BaseEstimator(
return tensor_signature.create_example_parser_from_signatures(
self._features_info, examples_batch)
- def _check_inputs(self, features, labels):
- if self._features_info is not None:
- logging.debug('Given features: %s, required signatures: %s.',
- str(features), str(self._features_info))
- if not tensor_signature.tensors_compatible(features, self._features_info):
- raise ValueError('Features are incompatible with given information. '
+ def _check_inputs(self, features, labels, mode):
+ if mode in self._features_info:
+ logging.debug('Given features for mode %s: %s, required signatures: %s.',
+ mode, str(features), str(self._features_info[mode]))
+
+ if not tensor_signature.tensors_compatible(features, self._features_info[mode]):
+ raise ValueError('Features for mode %s are incompatible with given information. '
'Given features: %s, required signatures: %s.' %
- (str(features), str(self._features_info)))
+ (mode, str(features), str(self._features_info[mode])))
else:
- self._features_info = tensor_signature.create_signatures(features)
- logging.debug('Setting feature info to %s.', str(self._features_info))
+ self._features_info[mode] = tensor_signature.create_signatures(features)
+ logging.debug('Setting feature info for mode %s to %s.', mode, str(self._features_info[mode]))
if labels is not None:
- if self._labels_info is not None:
+ if mode in self._labels_info:
logging.debug('Given labels: %s, required signatures: %s.',
str(labels), str(self._labels_info))
- if not tensor_signature.tensors_compatible(labels, self._labels_info):
- raise ValueError('Labels are incompatible with given information. '
+ if not tensor_signature.tensors_compatible(labels, self._labels_info[mode]):
+ raise ValueError('Labels for mode %s are incompatible with given information. '
'Given labels: %s, required signatures: %s.' %
- (str(labels), str(self._labels_info)))
+ (mode, str(labels), str(self._labels_info[mode])))
else:
- self._labels_info = tensor_signature.create_signatures(labels)
- logging.debug('Setting labels info to %s', str(self._labels_info))
+ self._labels_info[mode] = tensor_signature.create_signatures(labels)
+ logging.debug('Setting labels info for mode %s to %s', mode, str(self._labels_info[mode]))
def _train_model(self,
input_fn,
@@ -699,8 +700,7 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
- self._check_inputs(features, labels)
-
+ self._check_inputs(features, labels, model_fn_lib.ModeKeys.TRAIN)
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
@@ -800,8 +800,7 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
- self._check_inputs(features, labels)
-
+ self._check_inputs(features, labels, model_fn_lib.ModeKeys.EVAL)
# The default return type of _get_eval_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_eval_ops returns an
@@ -835,6 +834,29 @@ class BaseEstimator(
return result[0]
return result
+ def _set_infer_mode_feature_signature(self, features):
+ for mode in list(self._features_info.keys()):
+ if tensor_signature.tensors_compatible(features, self._features_info[mode]):
+ self._features_info[model_fn_lib.ModeKeys.INFER] = self._features_info[mode]
+ if mode in self._labels_info:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = (
+ self._labels_info[mode])
+ else:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = None
+ break
+
+ if model_fn_lib.ModeKeys.INFER not in self._features_info:
+ logging.warning('Features for mode %s are incompatible with neither train mode nor eval mode.'
+ ' Given features: %s' % (model_fn_lib.ModeKeys.INFER, str(features)))
+ for mode in list(self._features_info.keys()):
+ logging.warning('Whereas %s mode signatures: %s' % (mode, str(self._features_info[mode])))
+ self._check_inputs(features, None, model_fn_lib.ModeKeys.INFER)
+ if model_fn_lib.ModeKeys.TRAIN in self._labels_info:
+ logging.warning('Setting labels info for mode infer equal to that of labels info for train mode')
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = self._labels_info[model_fn_lib.ModeKeys.TRAIN]
+ else:
+ self._labels_info[model_fn_lib.ModeKeys.INFER] = {}
+
def _infer_model(
self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
# Check that model has been trained.
@@ -1134,8 +1156,10 @@ class Estimator(BaseEstimator):
Returns:
`ModelFnOps` object.
"""
+
+ self._set_infer_mode_feature_signature(features)
labels = tensor_signature.create_placeholders_from_signatures(
- self._labels_info)
+ self._labels_info[model_fn_lib.ModeKeys.INFER])
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
@experimental
@@ -1239,7 +1263,7 @@ class Estimator(BaseEstimator):
return export_dir
-# For time of deprecation x,y from Estimator allow direct access.
+# For time of deprecation x,y from Estimator allow direct access
# pylint: disable=protected-access
class SKCompat(sklearn.BaseEstimator):
"""Scikit learn wrapper for TensorFlow Learn Estimator."""