aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-27 08:06:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-27 08:09:24 -0700
commited2b2dbd1dbbb5c0e41f0bc6bb77219f0bfa2c96 (patch)
tree3f0f4bb0d89a29a016cdb0150a76ca275836f5f7 /tensorflow/contrib/boosted_trees
parentdc437d53a395438070739c3d509fa0c21b3bffbb (diff)
Adding core estimator for ranking.
PiperOrigin-RevId: 206318440
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py129
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py30
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py73
3 files changed, 210 insertions, 22 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 38fa8c3834..2df879f924 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -26,6 +26,12 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
+# ================== Old estimator interface===================================
+# The estimators below were designed for old feature columns and old estimator
+# interface. They can be used with new feature columns and losses by setting
+# use_core_libs = True.
+
+
class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
"""An estimator using gradient boosted decision trees."""
@@ -356,9 +362,16 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
config=config,
feature_engineering_fn=feature_engineering_fn)
+# ================== New Estimator interface===================================
+# The estimators below use new core Estimator interface and must be used with
+# new feature columns and heads.
+
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
- """An estimator using gradient boosted decision trees."""
+ """An estimator using gradient boosted decision trees.
+
+ Useful for training with user specified `Head`.
+ """
def __init__(self,
learner_config,
@@ -374,6 +387,36 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
output_leaf_index=False):
+ """Initializes a core version of GradientBoostedDecisionTreeEstimator.
+
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
+ """
def _model_fn(features, labels, mode, config):
return model.model_builder(
@@ -397,3 +440,87 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+ This is an estimator that can be trained off the pairwise data and can be
+ used for inference on non-paired data. This is essentially LambdaMart.
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+
+ def _model_fn(features, labels, mode, config):
+ return model.ranking_model_builder(
+ features=features,
+ labels=labels,
+ mode=mode,
+ config=config,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': True,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
+ },
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
+
+ super(CoreGradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index f787d3cdb8..9e9febbbef 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -203,7 +203,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.predict(input_fn=_infer_ranking_train_input_fn)
-class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
+class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
def testTrainEvaluateInferDoesNotThrowError(self):
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
@@ -229,6 +229,34 @@ class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
est.evaluate(input_fn=_eval_input_fn, steps=1)
est.predict(input_fn=_eval_input_fn)
+ def testRankingDontThrowExceptionForForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ est = estimator.CoreGradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ # Train for a few steps.
+ est.train(input_fn=_ranking_train_input_fn, steps=1000)
+ est.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ est.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index dbfee16b9e..161cc42cb0 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -59,6 +59,8 @@ def model_builder(features,
* center_bias: Whether a separate tree should be created for first fitting
the bias.
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
Returns:
A `ModelFnOps` object.
@@ -176,7 +178,12 @@ def model_builder(features,
return model_fn_ops
-def ranking_model_builder(features, labels, mode, params, config):
+def ranking_model_builder(features,
+ labels,
+ mode,
+ params,
+ config,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model for ranking.
Args:
@@ -200,6 +207,9 @@ def ranking_model_builder(features, labels, mode, params, config):
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+
Returns:
A `ModelFnOps` object.
@@ -327,31 +337,54 @@ def ranking_model_builder(features, labels, mode, params, config):
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
- if use_core_libs and callable(create_estimator_spec_op):
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
- model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
- gbdt_batch.LEAF_INDEX]
+ training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
+
finalized_trees, attempted_trees = (
gbdt_model_main.get_number_of_trees_tensor())
- model_fn_ops.training_hooks.append(
+ training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+
+ model_fn_ops.training_hooks.extend(training_hooks)
+ return model_fn_ops
+
+ elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
+ assert callable(create_estimator_spec_op)
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ return estimator_spec
+
return model_fn_ops