aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-14 19:06:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 19:10:15 -0700
commit15c1afadd13ec0eba6bc2b70d073a1769e45b679 (patch)
tree209a6d40f4000f1cebac62a7f6ce53598792b78e /tensorflow/contrib/boosted_trees
parent94ba1c4f0eccd234b4e0e5b504ddf1803067f1bc (diff)
Create a tf.constant for the weak_learner_type that's shared across all the
handlers to avoid duplicate constant creation. PiperOrigin-RevId: 208755852
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index d0d1249bd6..20ff48c360 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.constraints.min_node_weight, dtypes.float32)
loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
+ weak_learner_type = constant_op.constant(
+ self._learner_config.weak_learner_type)
epsilon = 0.01
num_quantiles = 100
strategy_tensor = constant_op.constant(strategy)
@@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object):
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type,
))
fc_name_idx += 1