diff options
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/stats_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/stats_ops.cc | 41 |
1 files changed, 2 insertions, 39 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc index 48afd3fbf3..64ec1caa9c 100644 --- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc @@ -17,13 +17,10 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/boosted_trees/tree_helper.h" namespace tensorflow { -namespace { -const float kEps = 1e-15; -} // namespace - class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { public: explicit BoostedTreesCalculateBestGainsPerFeatureOp( @@ -139,7 +136,7 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { total_hess - cum_hess_bucket, l1, l2, &contrib_for_right, &gain_for_right); - if (gain_for_left + gain_for_right > best_gain) { + if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) { best_gain = gain_for_left + gain_for_right; best_bucket = bucket; best_contrib_for_left = contrib_for_left; @@ -200,40 +197,6 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { } private: - void CalculateWeightsAndGains(const float g, const float h, const float l1, - const float l2, float* weight, float* gain) { - // - // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is - // (g+l1*sgn(w))^2/(h+l2). - // This is because for each leaf we optimize - // 1/2(h+l2)*w^2+g*w+l1*abs(w) - float g_with_l1 = g; - // Apply L1 regularization. - // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1 - // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1 - // For g from (-l1, l1), thus there is no solution => set to 0. - if (l1 > 0) { - if (g > l1) { - g_with_l1 -= l1; - } else if (g < -l1) { - g_with_l1 += l1; - } else { - *weight = 0.0; - *gain = 0.0; - return; - } - } - // Apply L2 regularization. - if (h + l2 <= kEps) { - // Avoid division by 0 or infinitesimal. - *weight = 0; - *gain = 0; - } else { - *weight = -g_with_l1 / (h + l2); - *gain = -g_with_l1 * (*weight); - } - } - int max_splits_; int num_features_; }; |