aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 1ee7f2395e..e08b230f46 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -287,7 +287,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS,
feature_columns=None,
use_core_columns=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ output_leaf_index_modes=None):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -307,6 +308,9 @@ class GradientBoostedDecisionTreeModel(object):
used.
output_leaf_index: A boolean variable indicating whether to output leaf
index into predictions dictionary.
+ output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
+ dictates when leaf indices will be outputted. By default, leaf indices
+ are only outputted in INFER mode.
Raises:
ValueError: if inputs are not valid.
@@ -404,7 +408,16 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.TREE_PER_CLASS and
learner_config.num_classes == 2)
+
+ if output_leaf_index_modes is None:
+ output_leaf_index_modes = [learn.ModeKeys.INFER]
+ elif not all(
+ mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL,
+ learn.ModeKeys.INFER) for mode in output_leaf_index_modes):
+ raise ValueError("output_leaf_index_modes should only contain ModeKeys.")
+
self._output_leaf_index = output_leaf_index
+ self._output_leaf_index_modes = output_leaf_index_modes
def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
"""Runs prediction and returns a dictionary of the prediction results.
@@ -435,8 +448,7 @@ class GradientBoostedDecisionTreeModel(object):
# the right stamp.
with ops.control_dependencies(ensemble_stats):
leaf_index = None
- # Only used in infer (predict), not used in train and eval.
- if self._output_leaf_index and mode == learn.ModeKeys.INFER:
+ if self._output_leaf_index and mode in self._output_leaf_index_modes:
predictions, _, leaf_index = (
prediction_ops).gradient_trees_prediction_verbose(
ensemble_handle,
@@ -508,9 +520,6 @@ class GradientBoostedDecisionTreeModel(object):
if not input_deps:
raise ValueError("No input tensors for prediction.")
- if any(i.device != input_deps[0].device for i in input_deps):
- raise ValueError("All input tensors should be on the same device.")
-
# Get most current model stamp.
ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)