about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-09-06 20:28:08 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-06 20:32:41 -0700
commit: a7e3047fea74a43174c063320fd0cb6bb6dcceb1 (patch)
tree: 68a144dec2b5bae2f539459324337acb1aea2a47 /tensorflow/contrib/boosted_trees
parent: ac8cf2ad5d01010b978c5b41c2fac22ee69a90c4 (diff)
Make num_quantiles configurable; update the epsilon value as well since epsilon controls the maximum number of quantiles generated.
PiperOrigin-RevId: 211914388
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r-- tensorflow/contrib/boosted_trees/estimator_batch/estimator.py | 43
-rw-r--r-- tensorflow/contrib/boosted_trees/estimator_batch/model.py | 8
-rw-r--r-- tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 9
3 files changed, 43 insertions, 17 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce2442b..4c7a538b38 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
@@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
label_keys=None,
logits_modifier_function=None,
center_bias=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3483..a6e422847d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@ def model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@ def model_builder(features,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@ def ranking_model_builder(features,
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@ def ranking_model_builder(features,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b008c6e534..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns