diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-14 19:06:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 19:10:15 -0700 |
commit | 15c1afadd13ec0eba6bc2b70d073a1769e45b679 (patch) | |
tree | 209a6d40f4000f1cebac62a7f6ce53598792b78e /tensorflow/contrib/boosted_trees | |
parent | 94ba1c4f0eccd234b4e0e5b504ddf1803067f1bc (diff) |
Create a tf.constant for the weak_learner_type that's shared across all the
handlers to avoid duplicate constant creation.
PiperOrigin-RevId: 208755852
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r-- | tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index d0d1249bd6..20ff48c360 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -672,6 +672,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.constraints.min_node_weight, dtypes.float32) loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) + weak_learner_type = constant_op.constant( + self._learner_config.weak_learner_type) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -696,6 +698,7 @@ class GradientBoostedDecisionTreeModel(object): multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type, )) fc_name_idx += 1 |