aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-23 14:51:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 15:10:56 -0700
commit32d121be1b105ef44fd5d4b421b78eb74dc94870 (patch)
treefe08311d30eb6a4229e1eeaa5dff2a30f1e3c47f /tensorflow/contrib/boosted_trees
parent3baa7b63edf7890b5489cf2085a79598f13af2c6 (diff)
Adding core interface to a contrib version
PiperOrigin-RevId: 205728990
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py43
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py29
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py74
3 files changed, 124 insertions, 22 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 59a78515c6..38fa8c3834 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
@@ -354,3 +355,45 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
+ """An estimator using gradient boosted decision trees."""
+
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=True,
+ output_leaf_index=False):
+
+ def _model_fn(features, labels, mode, config):
+ return model.model_builder(
+ features=features,
+ labels=labels,
+ mode=mode,
+ config=config,
+ params={
+ 'head': head,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': True,
+ 'output_leaf_index': output_leaf_index,
+ },
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
+
+ super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 2c2dcb039d..f787d3cdb8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -182,7 +182,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
config = run_config.RunConfig()
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
- loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
model = estimator.GradientBoostedDecisionTreeRanker(
head=head_fn,
@@ -203,5 +203,32 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.predict(input_fn=_infer_ranking_train_input_fn)
+class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
+
+ def testTrainEvaluateInferDoesNotThrowError(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ est.evaluate(input_fn=_eval_input_fn, steps=1)
+ est.predict(input_fn=_eval_input_fn)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 0e8a56e6e9..2fbe72951a 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -29,7 +29,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
-def model_builder(features, labels, mode, params, config):
+class ModelBuilderOutputType(object):
+ MODEL_FN_OPS = 0
+ ESTIMATOR_SPEC = 1
+
+
+def model_builder(features,
+ labels,
+ mode,
+ params,
+ config,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model.
Args:
@@ -115,31 +125,53 @@ def model_builder(features, labels, mode, params, config):
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
- if use_core_libs and callable(create_estimator_spec_op):
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
- model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
- gbdt_batch.LEAF_INDEX]
+
if num_trees:
if center_bias:
num_trees += 1
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- model_fn_ops.training_hooks.append(
+ training_hooks = [
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees)
+ ]
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+
+ model_fn_ops.training_hooks.extend(training_hooks)
+ return model_fn_ops
+ elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
+ assert callable(create_estimator_spec_op)
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ return estimator_spec
+
return model_fn_ops