aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/estimator.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 59a78515c6..38fa8c3834 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
@@ -354,3 +355,45 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
+ """An estimator using gradient boosted decision trees."""
+
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=True,
+ output_leaf_index=False):
+
+ def _model_fn(features, labels, mode, config):
+ return model.model_builder(
+ features=features,
+ labels=labels,
+ mode=mode,
+ config=config,
+ params={
+ 'head': head,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': True,
+ 'output_leaf_index': output_leaf_index,
+ },
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
+
+ super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)