aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc')
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 44a8ffaf4b..04e32267cc 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -422,6 +422,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
present_gradient_stats *= normalizer_ratio;
+ GradientStats not_present =
+ root_gradient_stats - present_gradient_stats;
+ // If there was (almost) no sparsity, fix the default direction to LEFT.
+ bool fixed_default_direction = not_present.IsAlmostZero();
GradientStats left_gradient_stats;
for (int64 element_idx = start_index; element_idx < end_index;
@@ -441,6 +445,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
// backward pass gradients.
GradientStats right_gradient_stats =
present_gradient_stats - left_gradient_stats;
+
{
NodeStats left_stats_default_left =
ComputeNodeStats(root_gradient_stats - right_gradient_stats);
@@ -457,7 +462,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
best_dimension_idx = dimension_id;
}
}
- {
+ // Consider calculating the default direction only when there were
+ // enough missing examples.
+ if (!fixed_default_direction) {
NodeStats left_stats_default_right =
ComputeNodeStats(left_gradient_stats);
NodeStats right_stats_default_right =