diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index d0d1249bd6..20ff48c360 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.constraints.min_node_weight, dtypes.float32) loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) + weak_learner_type = constant_op.constant( + self._learner_config.weak_learner_type) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object): multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type, )) fc_name_idx += 1 |