diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/estimator.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 70 |
1 files changed, 47 insertions, 23 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 91d900395b..2ec5a0659a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -330,8 +330,8 @@ class BaseEstimator( # Features and labels TensorSignature objects. # TODO(wicke): Rename these to something more descriptive - self._features_info = None - self._labels_info = None + self._features_info = {} + self._labels_info = {} self._graph = None @@ -641,28 +641,29 @@ class BaseEstimator( return tensor_signature.create_example_parser_from_signatures( self._features_info, examples_batch) - def _check_inputs(self, features, labels): - if self._features_info is not None: - logging.debug('Given features: %s, required signatures: %s.', - str(features), str(self._features_info)) - if not tensor_signature.tensors_compatible(features, self._features_info): - raise ValueError('Features are incompatible with given information. ' + def _check_inputs(self, features, labels, mode): + if mode in self._features_info: + logging.debug('Given features for mode %s: %s, required signatures: %s.', + mode, str(features), str(self._features_info[mode])) + + if not tensor_signature.tensors_compatible(features, self._features_info[mode]): + raise ValueError('Features for mode %s are incompatible with given information. ' 'Given features: %s, required signatures: %s.' % - (str(features), str(self._features_info))) + (mode, str(features), str(self._features_info[mode]))) else: - self._features_info = tensor_signature.create_signatures(features) - logging.debug('Setting feature info to %s.', str(self._features_info)) + self._features_info[mode] = tensor_signature.create_signatures(features) + logging.debug('Setting feature info for mode %s to %s.', mode, str(self._features_info[mode])) if labels is not None: - if self._labels_info is not None: + if mode in self._labels_info: logging.debug('Given labels: %s, required signatures: %s.', str(labels), str(self._labels_info)) - if not tensor_signature.tensors_compatible(labels, self._labels_info): - raise ValueError('Labels are incompatible with given information. ' + if not tensor_signature.tensors_compatible(labels, self._labels_info[mode]): + raise ValueError('Labels for mode %s are incompatible with given information. ' 'Given labels: %s, required signatures: %s.' % - (str(labels), str(self._labels_info))) + (mode, str(labels), str(self._labels_info[mode]))) else: - self._labels_info = tensor_signature.create_signatures(labels) - logging.debug('Setting labels info to %s', str(self._labels_info)) + self._labels_info[mode] = tensor_signature.create_signatures(labels) + logging.debug('Setting labels info for mode %s to %s', mode, str(self._labels_info[mode])) def _train_model(self, input_fn, @@ -699,8 +700,7 @@ class BaseEstimator( random_seed.set_random_seed(self._config.tf_random_seed) global_step = contrib_framework.create_global_step(g) features, labels = input_fn() - self._check_inputs(features, labels) - + self._check_inputs(features, labels, model_fn_lib.ModeKeys.TRAIN) # The default return type of _get_train_ops is ModelFnOps. But there are # some subclasses of tf.contrib.learn.Estimator which override this # method and use the legacy signature, namely _get_train_ops returns a @@ -800,8 +800,7 @@ class BaseEstimator( random_seed.set_random_seed(self._config.tf_random_seed) global_step = contrib_framework.create_global_step(g) features, labels = input_fn() - self._check_inputs(features, labels) - + self._check_inputs(features, labels, model_fn_lib.ModeKeys.EVAL) # The default return type of _get_eval_ops is ModelFnOps. But there are # some subclasses of tf.contrib.learn.Estimator which override this # method and use the legacy signature, namely _get_eval_ops returns an @@ -835,6 +834,29 @@ class BaseEstimator( return result[0] return result + def _set_infer_mode_feature_signature(self, features): + for mode in list(self._features_info.keys()): + if tensor_signature.tensors_compatible(features, self._features_info[mode]): + self._features_info[model_fn_lib.ModeKeys.INFER] = self._features_info[mode] + if mode in self._labels_info: + self._labels_info[model_fn_lib.ModeKeys.INFER] = ( + self._labels_info[mode]) + else: + self._labels_info[model_fn_lib.ModeKeys.INFER] = None + break + + if model_fn_lib.ModeKeys.INFER not in self._features_info: + logging.warning('Features for mode %s are incompatible with neither train mode nor eval mode.' + ' Given features: %s' % (model_fn_lib.ModeKeys.INFER, str(features))) + for mode in list(self._features_info.keys()): + logging.warning('Whereas %s mode signatures: %s' % (mode, str(self._features_info[mode]))) + self._check_inputs(features, None, model_fn_lib.ModeKeys.INFER) + if model_fn_lib.ModeKeys.TRAIN in self._labels_info: + logging.warning('Setting labels info for mode infer equal to that of labels info for train mode') + self._labels_info[model_fn_lib.ModeKeys.INFER] = self._labels_info[model_fn_lib.ModeKeys.TRAIN] + else: + self._labels_info[model_fn_lib.ModeKeys.INFER] = {} + def _infer_model( self, input_fn, feed_fn=None, outputs=None, as_iterable=True): # Check that model has been trained. @@ -1134,8 +1156,10 @@ class Estimator(BaseEstimator): Returns: `ModelFnOps` object. """ + + self._set_infer_mode_feature_signature(features) labels = tensor_signature.create_placeholders_from_signatures( - self._labels_info) + self._labels_info[model_fn_lib.ModeKeys.INFER]) return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) @experimental @@ -1239,7 +1263,7 @@ class Estimator(BaseEstimator): return export_dir -# For time of deprecation x,y from Estimator allow direct access. +# For time of deprecation x,y from Estimator allow direct access # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): """Scikit learn wrapper for TensorFlow Learn Estimator.""" |