aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-09 10:47:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 10:56:16 -0700
commit810f4f8d13de5a5e9ba4010addcce18f98002150 (patch)
tree32eb99c5b851b19e87a868f99a01766d45e1c9d4 /tensorflow/contrib/boosted_trees
parent27fb77281c34574306389f8b2c0ab36a38436100 (diff)
Adding ranking support (over paired data for train and eval and unpaired data for inference).
PiperOrigin-RevId: 203791296
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py85
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py47
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py183
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses.py67
4 files changed, 380 insertions, 2 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 9c36c30221..59a78515c6 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -269,3 +269,88 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class GradientBoostedDecisionTreeRanker(estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+ This is an estimator that can be trained off the pairwise data and can be
+ used for inference on non-paired data. This is essentially LambdaMart.
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+ super(GradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=model.ranking_model_builder,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
+ },
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 75ef1b0500..2c2dcb039d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -37,12 +37,31 @@ def _train_input_fn():
return features, label
+def _ranking_train_input_fn():
+ features = {
+ "a.f1": constant_op.constant([[3.], [0.3], [1.]]),
+ "a.f2": constant_op.constant([[0.1], [3.], [1.]]),
+ "b.f1": constant_op.constant([[13.], [0.4], [5.]]),
+ "b.f2": constant_op.constant([[1.], [3.], [0.01]]),
+ }
+ label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32)
+ return features, label
+
+
def _eval_input_fn():
features = {"x": constant_op.constant([[1.], [2.], [2.]])}
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
return features, label
+def _infer_ranking_train_input_fn():
+ features = {
+ "f1": constant_op.constant([[3.], [2], [1.]]),
+ "f2": constant_op.constant([[0.1], [3.], [1.]])
+ }
+ return features, None
+
+
class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -155,6 +174,34 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
regressor.evaluate(input_fn=_eval_input_fn, steps=1)
regressor.export(self._export_dir_base)
+ def testRankingDontThrowExceptionForForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ model = estimator.GradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ use_core_libs=True,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ model.fit(input_fn=_ranking_train_input_fn, steps=1000)
+ model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ model.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 1ee8911989..0e8a56e6e9 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import copy
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -28,7 +29,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
-
def model_builder(features, labels, mode, params, config):
"""Multi-machine batch gradient descent tree model.
@@ -141,3 +141,184 @@ def model_builder(features, labels, mode, params, config):
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees))
return model_fn_ops
+
+
+def ranking_model_builder(features, labels, mode, params, config):
+ """Multi-machine batch gradient descent tree model for ranking.
+
+ Args:
+ features: `Tensor` or `dict` of `Tensor` objects.
+ labels: Labels used to train on.
+ mode: Mode we are in. (TRAIN/EVAL/INFER)
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * learner_config: A config for the learner.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ * weight_column_name: The name of weight column.
+ * center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ * ranking_model_pair_keys (Optional): Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ config: `RunConfig` of the estimator.
+
+ Returns:
+ A `ModelFnOps` object.
+ Raises:
+ ValueError: if inputs are not valid.
+ """
+ head = params["head"]
+ learner_config = params["learner_config"]
+ examples_per_layer = params["examples_per_layer"]
+ feature_columns = params["feature_columns"]
+ weight_column_name = params["weight_column_name"]
+ num_trees = params["num_trees"]
+ use_core_libs = params["use_core_libs"]
+ logits_modifier_function = params["logits_modifier_function"]
+ output_leaf_index = params["output_leaf_index"]
+ ranking_model_pair_keys = params["ranking_model_pair_keys"]
+
+ if features is None:
+ raise ValueError("At least one feature must be specified.")
+
+ if config is None:
+ raise ValueError("Missing estimator RunConfig.")
+
+ center_bias = params["center_bias"]
+
+ if isinstance(features, ops.Tensor):
+ features = {features.name: features}
+
+ # Make a shallow copy of features to ensure downstream usage
+ # is unaffected by modifications in the model function.
+ training_features = copy.copy(features)
+ training_features.pop(weight_column_name, None)
+ global_step = training_util.get_global_step()
+ with ops.device(global_step.device):
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config="", # Initialize an empty ensemble.
+ name="ensemble_model")
+
+ # Extract the features.
+ if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL:
+ # For ranking pairwise training, we extract two sets of features.
+ if len(ranking_model_pair_keys) != 2:
+ raise ValueError("You must provide keys for ranking.")
+ left_pair_key = ranking_model_pair_keys[0]
+ right_pair_key = ranking_model_pair_keys[1]
+ if left_pair_key is None or right_pair_key is None:
+ raise ValueError("Both pair keys should be provided for ranking.")
+
+ features_1 = {}
+ features_2 = {}
+ for name in training_features:
+ feature = training_features[name]
+ new_name = name[2:]
+ if name.startswith(left_pair_key + "."):
+ features_1[new_name] = feature
+ else:
+ assert name.startswith(right_pair_key + ".")
+ features_2[new_name] = feature
+
+ main_features = features_1
+ supplementary_features = features_2
+ else:
+ # For non-ranking or inference ranking, we have only 1 set of features.
+ main_features = training_features
+
+ # Create GBDT model.
+ gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=main_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ with ops.name_scope("gbdt", "gbdt_optimizer"):
+ # Logits for inference.
+ if mode == learn.ModeKeys.INFER:
+ predictions_dict = gbdt_model_main.predict(mode)
+ logits = predictions_dict[gbdt_batch.PREDICTIONS]
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+ else:
+ gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=supplementary_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ # Logits for train and eval.
+ if not supplementary_features:
+ raise ValueError("Features for ranking must be specified.")
+
+ predictions_dict_1 = gbdt_model_main.predict(mode)
+ predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS]
+
+ predictions_dict_2 = gbdt_model_supplementary.predict(mode)
+ predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS]
+
+ logits = predictions_1 - predictions_2
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+
+ predictions_dict = predictions_dict_1
+ predictions_dict[gbdt_batch.PREDICTIONS] = logits
+
+ def _train_op_fn(loss):
+ """Returns the op to optimize the loss."""
+ update_op = gbdt_model_main.train(loss, predictions_dict, labels)
+ with ops.control_dependencies(
+ [update_op]), (ops.colocate_with(global_step)):
+ update_op = state_ops.assign_add(global_step, 1).op
+ return update_op
+
+ create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+ if num_trees:
+ if center_bias:
+ num_trees += 1
+ finalized_trees, attempted_trees = (
+ gbdt_model_main.get_number_of_trees_tensor())
+ model_fn_ops.training_hooks.append(
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees))
+ return model_fn_ops
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index ab7ac2aba6..b5ebaf1999 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -23,6 +23,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+
+
+def per_example_squared_hinge_loss(labels, weights, predictions):
+ loss = losses.hinge_loss(labels=labels, logits=predictions, weights=weights)
+ return math_ops.square(loss), control_flow_ops.no_op()
def per_example_logistic_loss(labels, weights, predictions):
@@ -126,7 +132,7 @@ def per_example_squared_loss(labels, weights, predictions):
def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
- """Exponential loss given labels, example weights and predictions.
+ """Trimmed exponential loss given labels, example weights and predictions.
Note that this is only for binary classification.
If logistic loss tries to make sure that the classifier is certain of its
@@ -211,3 +217,62 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
unweighted_loss = exp_with_logits(
name=name, eps=eps, labels=labels, logits=predictions)
return unweighted_loss * weights, control_flow_ops.no_op()
+
+
+def per_example_full_exp_loss(labels, weights, predictions, name=None):
+ """Full exponential loss given labels, example weights and predictions.
+
+ Note that this is only for binary classification.
+ The loss returns is exp(-targets*logits), where targets are converted to -1
+ and 1.
+
+ Args:
+ labels: Rank 2 (N, D) tensor of per-example labels.
+ weights: Rank 2 (N, 1) tensor of per-example weights.
+ predictions: Rank 2 (N, D) tensor of per-example predictions.
+ name: A name for the operation (optional).
+
+ Returns:
+ loss: A Rank 2 (N, 1) tensor of per-example exp loss
+ update_op: An update operation to update the loss's internal state.
+ """
+
+ def full_exp_with_logits(name, labels=None, logits=None):
+ """Computes exponential loss given `logits`.
+
+ Args:
+ name: A name for the operation (optional).
+ labels: A `Tensor` of the same type and shape as `logits`.
+ logits: A `Tensor` of type `float32` or `float64`.
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ exponential losses.
+
+ Raises:
+ ValueError: If `logits` and `labels` do not have the same shape.
+ """
+ with ops.name_scope(name, "exp_loss", [logits, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ labels = ops.convert_to_tensor(labels, name="labels")
+ try:
+ labels.get_shape().merge_with(logits.get_shape())
+ except ValueError:
+ raise ValueError("logits and labels must have the same shape (%s vs %s)"
+ % (logits.get_shape(), labels.get_shape()))
+
+ # Default threshold of 0 to switch between classes
+ zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
+ ones = array_ops.ones_like(logits, dtype=logits.dtype)
+ neg_ones = -array_ops.ones_like(logits, dtype=logits.dtype)
+
+ # Convert labels to 1 and -1
+ cond_labels = (labels > zeros)
+ labels_converted = array_ops.where(cond_labels, ones, neg_ones)
+
+ return math_ops.exp(-1.0 * logits * labels_converted)
+
+ labels = math_ops.to_float(labels)
+ unweighted_loss = full_exp_with_logits(
+ name=name, labels=labels, logits=predictions)
+ return unweighted_loss * weights, control_flow_ops.no_op()