diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b008c6e534..c7eb2493a8 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object): feature_columns=None, use_core_columns=False, output_leaf_index=False, - output_leaf_index_modes=None): + output_leaf_index_modes=None, + num_quantiles=100): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object): output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which dictates when leaf indices will be outputted. By default, leaf indices are only outputted in INFER mode. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: if inputs are not valid. @@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._num_quantiles = num_quantiles self._max_tree_depth = variables.Variable( initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( @@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object): loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) weak_learner_type = constant_op.constant( self._learner_config.weak_learner_type) - epsilon = 0.01 - num_quantiles = 100 + num_quantiles = self._num_quantiles + epsilon = 1.0 / num_quantiles strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns |