aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/estimator.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 9c36c30221..38fa8c3834 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
@@ -269,3 +270,130 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class GradientBoostedDecisionTreeRanker(estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+ This is an estimator that can be trained off the pairwise data and can be
+ used for inference on non-paired data. This is essentially LambdaMart.
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+ super(GradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=model.ranking_model_builder,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
+ },
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
+ """An estimator using gradient boosted decision trees."""
+
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=True,
+ output_leaf_index=False):
+
+ def _model_fn(features, labels, mode, config):
+ return model.model_builder(
+ features=features,
+ labels=labels,
+ mode=mode,
+ config=config,
+ params={
+ 'head': head,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': True,
+ 'output_leaf_index': output_leaf_index,
+ },
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
+
+ super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)