author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-03 14:09:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-07-03 14:17:14 -0700
commit    2aac0e887ca27d9818607cd52f28044cb7673c70 (patch)
tree      d93dff30e8e0458894dcf197358ffa200ea77ce7 /tensorflow/core/kernels/boosted_trees
parent    bdd84aa59d3bdedc42647711e401229f489c7d25 (diff)
- Adding the ability to center the bias as a first step of training a GBDT (sketched below).
- Fixing non-determinism in choosing a split when gains are the same.

PiperOrigin-RevId: 203180755
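Not part of the commit: a minimal, self-contained sketch of the bias-centering step described above, assuming scalar (single-class) logits. The CalculateWeightsAndGains body mirrors tree_helper.h below, and kMinDeltaForCenterBias matches training_ops.cc; the gradient/hessian means and the loop driving them are hypothetical illustration only.

#include <cmath>
#include <cstdio>

namespace {

// Closed-form leaf weight and gain for the per-leaf objective
// 1/2*(h+l2)*w^2 + g*w + l1*|w| (same math as CalculateWeightsAndGains).
void CalculateWeightsAndGains(float g, float h, float l1, float l2,
                              float* weight, float* gain) {
  const float kEps = 1e-15f;
  float g_with_l1 = g;
  if (l1 > 0) {
    if (g > l1) {
      g_with_l1 -= l1;
    } else if (g < -l1) {
      g_with_l1 += l1;
    } else {
      *weight = 0.0f;
      *gain = 0.0f;
      return;
    }
  }
  if (h + l2 <= kEps) {
    *weight = 0.0f;  // Avoid division by zero or an infinitesimal hessian.
    *gain = 0.0f;
  } else {
    *weight = -g_with_l1 / (h + l2);
    *gain = -g_with_l1 * (*weight);
  }
}

}  // namespace

int main() {
  const float kMinDeltaForCenterBias = 0.01f;
  // Hypothetical per-iteration means of the gradients for the bias leaf.
  const float mean_gradients[] = {-0.40f, -0.10f, -0.002f};
  float bias = 0.0f;
  bool has_bias = false;
  for (float mean_g : mean_gradients) {
    float delta, unused_gain;
    CalculateWeightsAndGains(mean_g, /*h=*/1.0f, /*l1=*/0.0f, /*l2=*/0.0f,
                             &delta, &unused_gain);
    bool continue_centering = true;
    if (!has_bias) {
      bias = delta;  // First pass: the single-leaf "tree 0" holds the bias.
      has_bias = true;
    } else {
      // Stop once the relative change in the bias becomes small.
      continue_centering = std::abs(delta / bias) > kMinDeltaForCenterBias;
      bias += delta;
    }
    std::printf("bias=%.4f continue_centering=%d\n", bias, continue_centering);
    if (!continue_centering) break;
  }
  return 0;
}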
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees')
-rw-r--r--  tensorflow/core/kernels/boosted_trees/BUILD              |  7
-rw-r--r--  tensorflow/core/kernels/boosted_trees/prediction_ops.cc  | 28
-rw-r--r--  tensorflow/core/kernels/boosted_trees/resources.cc       | 23
-rw-r--r--  tensorflow/core/kernels/boosted_trees/resources.h        |  6
-rw-r--r--  tensorflow/core/kernels/boosted_trees/stats_ops.cc       | 41
-rw-r--r--  tensorflow/core/kernels/boosted_trees/training_ops.cc    | 85
-rw-r--r--  tensorflow/core/kernels/boosted_trees/tree_helper.h      | 69
7 files changed, 203 insertions(+), 56 deletions(-)
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 0244f3cd8d..4910021c63 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -45,6 +45,11 @@ cc_library(
],
)
+cc_library(
+ name = "tree_helper",
+ hdrs = ["tree_helper.h"],
+)
+
tf_kernel_library(
name = "resource_ops",
srcs = ["resource_ops.cc"],
@@ -61,6 +66,7 @@ tf_kernel_library(
name = "stats_ops",
srcs = ["stats_ops.cc"],
deps = [
+ ":tree_helper",
"//tensorflow/core:boosted_trees_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -72,6 +78,7 @@ tf_kernel_library(
srcs = ["training_ops.cc"],
deps = [
":resources",
+ ":tree_helper",
"//tensorflow/core:boosted_trees_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index 2920132a27..b2efa06941 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -104,8 +104,8 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
const int32 latest_tree = resource->num_trees() - 1;
if (latest_tree < 0) {
- // Ensemble was empty. Nothing changes.
- output_node_ids = cached_node_ids;
+ // Ensemble was empty. Output the very first node.
+ output_node_ids.setZero();
output_tree_ids = cached_tree_ids;
// All the predictions are zeros.
output_partial_logits.setZero();
@@ -120,16 +120,20 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
int32 node_id = cached_node_ids(i);
float partial_tree_logit = 0.0;
- // If the tree was pruned, returns the node id into which the
- // current_node_id was pruned, as well the correction of the cached
- // logit prediction.
- resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
- &partial_tree_logit);
-
- // Logic in the loop adds the cached node value again if it is a leaf.
- // If it is not a leaf anymore we need to subtract the old node's
- // value. The following logic handles both of these cases.
- partial_tree_logit -= resource->node_value(tree_id, node_id);
+ if (node_id >= 0) {
+ // If the tree was pruned, returns the node id into which the
+ // current_node_id was pruned, as well the correction of the cached
+ // logit prediction.
+ resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
+ &partial_tree_logit);
+ // Logic in the loop adds the cached node value again if it is a
+ // leaf. If it is not a leaf anymore we need to subtract the old
+ // node's value. The following logic handles both of these cases.
+ partial_tree_logit -= resource->node_value(tree_id, node_id);
+ } else {
+ // No cache exists, start from the very first node.
+ node_id = 0;
+ }
float partial_all_logit = 0.0;
while (true) {
if (resource->is_leaf(tree_id, node_id)) {
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index c410748c27..cc90bb2f45 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -21,6 +21,10 @@ limitations under the License.
namespace tensorflow {
+namespace {
+constexpr float kLayerByLayerTreeWeight = 1.0;
+} // namespace
+
// Constructor.
BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
: tree_ensemble_(
@@ -78,6 +82,16 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
}
}
+void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
+ const int32 node_id,
+ const float logits) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
+ DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
+ node->mutable_leaf()->set_scalar(logits);
+}
+
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
const int32 tree_id) const {
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
@@ -204,9 +218,14 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
// Add a tree to the ensemble and returns a new tree_id.
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
+ return AddNewTreeWithLogits(weight, 0.0);
+}
+
+int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
+ const float logits) {
const int32 new_tree_id = tree_ensemble_->trees_size();
auto* node = tree_ensemble_->add_trees()->add_nodes();
- node->mutable_leaf()->set_scalar(0.0);
+ node->mutable_leaf()->set_scalar(logits);
tree_ensemble_->add_tree_weights(weight);
tree_ensemble_->add_tree_metadata();
@@ -225,7 +244,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
*right_node_id = *left_node_id + 1;
auto* left_node = tree->add_nodes();
auto* right_node = tree->add_nodes();
- if (node_id != 0) {
+ if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) {
// Save previous leaf value if it is not the first leaf in the tree.
node->mutable_metadata()->mutable_original_leaf()->Swap(
node->mutable_leaf());
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
index df78d3f275..f961ed3814 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.h
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -70,6 +70,9 @@ class BoostedTreesEnsembleResource : public StampedResource {
float node_value(const int32 tree_id, const int32 node_id) const;
+ void set_node_value(const int32 tree_id, const int32 node_id,
+ const float logits);
+
int32 GetNumLayersGrown(const int32 tree_id) const;
void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
@@ -99,6 +102,9 @@ class BoostedTreesEnsembleResource : public StampedResource {
// Add a tree to the ensemble and returns a new tree_id.
int32 AddNewTree(const float weight);
+ // Adds new tree with one node to the ensemble and sets node's value to logits
+ int32 AddNewTreeWithLogits(const float weight, const float logits);
+
// Grows the tree by adding a split and leaves.
void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
const int32 feature_id, const int32 threshold,
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 48afd3fbf3..64ec1caa9c 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -17,13 +17,10 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
namespace tensorflow {
-namespace {
-const float kEps = 1e-15;
-} // namespace
-
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestGainsPerFeatureOp(
@@ -139,7 +136,7 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
total_hess - cum_hess_bucket, l1, l2,
&contrib_for_right, &gain_for_right);
- if (gain_for_left + gain_for_right > best_gain) {
+ if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
best_gain = gain_for_left + gain_for_right;
best_bucket = bucket;
best_contrib_for_left = contrib_for_left;
@@ -200,40 +197,6 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
}
private:
- void CalculateWeightsAndGains(const float g, const float h, const float l1,
- const float l2, float* weight, float* gain) {
- //
- // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
- // (g+l1*sgn(w))^2/(h+l2).
- // This is because for each leaf we optimize
- // 1/2(h+l2)*w^2+g*w+l1*abs(w)
- float g_with_l1 = g;
- // Apply L1 regularization.
- // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
- // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
- // For g from (-l1, l1), thus there is no solution => set to 0.
- if (l1 > 0) {
- if (g > l1) {
- g_with_l1 -= l1;
- } else if (g < -l1) {
- g_with_l1 += l1;
- } else {
- *weight = 0.0;
- *gain = 0.0;
- return;
- }
- }
- // Apply L2 regularization.
- if (h + l2 <= kEps) {
- // Avoid division by 0 or infinitesimal.
- *weight = 0;
- *gain = 0;
- } else {
- *weight = -g_with_l1 / (h + l2);
- *gain = -g_with_l1 * (*weight);
- }
- }
-
int max_splits_;
int num_features_;
};
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index a14fd4a133..973cdec13a 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
namespace tensorflow {
namespace {
constexpr float kLayerByLayerTreeWeight = 1.0;
+constexpr float kMinDeltaForCenterBias = 0.01;
// TODO(nponomareva, youngheek): consider using vector.
struct SplitCandidate {
@@ -89,7 +91,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Find best splits for each active node.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits);
+ FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
+ &best_splits);
int32 current_tree =
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
@@ -193,6 +196,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
void FindBestSplitsPerNode(
OpKernelContext* const context, const OpInputList& node_ids_list,
const OpInputList& gains_list,
+ const TTypes<const int32>::Vec& feature_ids,
std::map<int32, SplitCandidate>* best_split_per_node) {
// Find best split per node going through every feature candidate.
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
@@ -211,8 +215,18 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
candidate.candidate_idx = candidate_idx;
candidate.gain = gain;
- if (best_split_it == best_split_per_node->end() ||
- gain > best_split_it->second.gain) {
+ if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
+ GainsAreEqual(gain, best_split_it->second.gain))) {
+ const auto best_candidate = (*best_split_per_node)[node_id];
+ const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
+ const int32 feature_id = feature_ids(candidate.feature_idx);
+ VLOG(2) << "Breaking ties on feature ids and buckets";
+ // Breaking ties deterministically.
+ if (feature_id < best_feature_id) {
+ (*best_split_per_node)[node_id] = candidate;
+ }
+ } else if (best_split_it == best_split_per_node->end() ||
+ GainIsLarger(gain, best_split_it->second.gain)) {
(*best_split_per_node)[node_id] = candidate;
}
}
@@ -227,4 +241,69 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
BoostedTreesUpdateEnsembleOp);
+class BoostedTreesCenterBiasOp : public OpKernel {
+ public:
+ explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* const context) override {
+ // Get decision tree ensemble.
+ BoostedTreesEnsembleResource* ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &ensemble_resource));
+ core::ScopedUnref unref_me(ensemble_resource);
+ mutex_lock l(*ensemble_resource->get_mutex());
+ // Increase the ensemble stamp.
+ ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
+
+ // Read means of hessians and gradients
+ const Tensor* mean_gradients_t;
+ OP_REQUIRES_OK(context,
+ context->input("mean_gradients", &mean_gradients_t));
+
+ const Tensor* mean_hessians_t;
+ OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));
+
+ // Get the regularization options.
+ const Tensor* l1_t;
+ OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+ const auto l1 = l1_t->scalar<float>()();
+ const Tensor* l2_t;
+ OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+ const auto l2 = l2_t->scalar<float>()();
+
+ // For now, assume 1-dimensional weight on leaves.
+ float logits;
+ float unused_gain;
+
+ // TODO(nponomareva): change this when supporting multiclass.
+ const float gradients_mean = mean_gradients_t->flat<float>()(0);
+ const float hessians_mean = mean_hessians_t->flat<float>()(0);
+ CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits,
+ &unused_gain);
+
+ float current_bias = 0.0;
+ bool continue_centering = true;
+ if (ensemble_resource->num_trees() == 0) {
+ ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
+ current_bias = logits;
+ } else {
+ current_bias = ensemble_resource->node_value(0, 0);
+ continue_centering =
+ std::abs(logits / current_bias) > kMinDeltaForCenterBias;
+ current_bias += logits;
+ ensemble_resource->set_node_value(0, 0, current_bias);
+ }
+
+ Tensor* continue_centering_t = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("continue_centering", TensorShape({}),
+ &continue_centering_t));
+ // Check if we need to continue centering bias.
+ continue_centering_t->scalar<bool>()() = continue_centering;
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
+ BoostedTreesCenterBiasOp);
+
} // namespace tensorflow
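Not part of the commit: a small standalone sketch of the deterministic tie-break added in FindBestSplitsPerNode above. GainsAreEqual and GainIsLarger mirror tree_helper.h; the Candidate struct, MaybeUpdateBest helper, and the gains/feature ids are hypothetical. It shows that when gains tie within tolerance, the smaller feature id wins regardless of candidate order.

#include <cmath>
#include <cstdio>

namespace {
const float kTolerance = 1e-15f;
bool GainsAreEqual(float g1, float g2) { return std::abs(g1 - g2) < kTolerance; }
bool GainIsLarger(float g1, float g2) { return g1 - g2 >= kTolerance; }

struct Candidate {
  int feature_id;
  float gain;
};

// Larger gain wins; on (near-)equal gains the smaller feature id wins,
// so the chosen split does not depend on the order candidates arrive in.
void MaybeUpdateBest(const Candidate& candidate, Candidate* best) {
  if (GainsAreEqual(candidate.gain, best->gain)) {
    if (candidate.feature_id < best->feature_id) *best = candidate;
  } else if (GainIsLarger(candidate.gain, best->gain)) {
    *best = candidate;
  }
}
}  // namespace

int main() {
  // Hypothetical candidates with identical gains, seen in two different orders.
  Candidate best_a{7, 0.5f};
  MaybeUpdateBest({3, 0.5f}, &best_a);
  Candidate best_b{3, 0.5f};
  MaybeUpdateBest({7, 0.5f}, &best_b);
  std::printf("order A -> feature %d, order B -> feature %d\n",
              best_a.feature_id, best_b.feature_id);  // both print 3
  return 0;
}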
diff --git a/tensorflow/core/kernels/boosted_trees/tree_helper.h b/tensorflow/core/kernels/boosted_trees/tree_helper.h
new file mode 100644
index 0000000000..8b18d9e5f8
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/tree_helper.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
+#include <cmath>
+
+namespace tensorflow {
+
+static bool GainsAreEqual(const float g1, const float g2) {
+ const float kTolerance = 1e-15;
+ return std::abs(g1 - g2) < kTolerance;
+}
+
+static bool GainIsLarger(const float g1, const float g2) {
+ const float kTolerance = 1e-15;
+ return g1 - g2 >= kTolerance;
+}
+
+static void CalculateWeightsAndGains(const float g, const float h,
+ const float l1, const float l2,
+ float* weight, float* gain) {
+ const float kEps = 1e-15;
+ // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
+ // (g+l1*sgn(w))^2/(h+l2).
+ // This is because for each leaf we optimize
+ // 1/2(h+l2)*w^2+g*w+l1*abs(w)
+ float g_with_l1 = g;
+ // Apply L1 regularization.
+ // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
+ // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
+ // For g from (-l1, l1), thus there is no solution => set to 0.
+ if (l1 > 0) {
+ if (g > l1) {
+ g_with_l1 -= l1;
+ } else if (g < -l1) {
+ g_with_l1 += l1;
+ } else {
+ *weight = 0.0;
+ *gain = 0.0;
+ return;
+ }
+ }
+ // Apply L2 regularization.
+ if (h + l2 <= kEps) {
+ // Avoid division by 0 or infinitesimal.
+ *weight = 0;
+ *gain = 0;
+ } else {
+ *weight = -g_with_l1 / (h + l2);
+ *gain = -g_with_l1 * (*weight);
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
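For reference only (not part of the commit): the closed form in CalculateWeightsAndGains follows from minimizing the per-leaf objective named in its comment. A sketch of the derivation, in the scalar case and with g' denoting the L1-shifted gradient computed by the code (w* = 0 when |g| <= l1):

\mathcal{L}(w) = \tfrac{1}{2}(h + l_2)\,w^{2} + g\,w + l_1\,|w|,
\qquad g' = g + l_1\,\operatorname{sgn}(w^{*}),
\qquad w^{*} = -\frac{g'}{h + l_2},
\qquad \mathcal{L}(0) - \mathcal{L}(w^{*}) = \frac{(g')^{2}}{2\,(h + l_2)}.

As written, the code reports gain = -g' w* = (g')^2 / (h + l_2), i.e. twice this objective decrease; the constant factor does not change which bucket or feature maximizes the gain.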