aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b008c6e534..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns