aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index d0d1249bd6..20ff48c360 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.constraints.min_node_weight, dtypes.float32)
loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
+ weak_learner_type = constant_op.constant(
+ self._learner_config.weak_learner_type)
epsilon = 0.01
num_quantiles = 100
strategy_tensor = constant_op.constant(strategy)
@@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object):
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type,
))
fc_name_idx += 1