aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/boosted_trees/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/training_ops.cc')
-rw-r--r--tensorflow/core/kernels/boosted_trees/training_ops.cc85
1 file changed, 82 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index a14fd4a133..973cdec13a 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
namespace tensorflow {
namespace {
constexpr float kLayerByLayerTreeWeight = 1.0;
+constexpr float kMinDeltaForCenterBias = 0.01;
// TODO(nponomareva, youngheek): consider using vector.
struct SplitCandidate {
@@ -89,7 +91,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Find best splits for each active node.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits);
+ FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
+ &best_splits);
int32 current_tree =
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
@@ -193,6 +196,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
void FindBestSplitsPerNode(
OpKernelContext* const context, const OpInputList& node_ids_list,
const OpInputList& gains_list,
+ const TTypes<const int32>::Vec& feature_ids,
std::map<int32, SplitCandidate>* best_split_per_node) {
// Find best split per node going through every feature candidate.
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
@@ -211,8 +215,18 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
candidate.candidate_idx = candidate_idx;
candidate.gain = gain;
- if (best_split_it == best_split_per_node->end() ||
- gain > best_split_it->second.gain) {
+ if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
+ GainsAreEqual(gain, best_split_it->second.gain))) {
+ const auto best_candidate = (*best_split_per_node)[node_id];
+ const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
+ const int32 feature_id = feature_ids(candidate.feature_idx);
+ VLOG(2) << "Breaking ties on feature ids and buckets";
+ // Breaking ties deterministically.
+ if (feature_id < best_feature_id) {
+ (*best_split_per_node)[node_id] = candidate;
+ }
+ } else if (best_split_it == best_split_per_node->end() ||
+ GainIsLarger(gain, best_split_it->second.gain)) {
(*best_split_per_node)[node_id] = candidate;
}
}
@@ -227,4 +241,69 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
BoostedTreesUpdateEnsembleOp);
+// Kernel for the BoostedTreesCenterBias op: adjusts the value stored in the
+// ensemble's first leaf (the bias) from mean gradient/hessian statistics and
+// reports whether another centering iteration is needed.
+class BoostedTreesCenterBiasOp : public OpKernel {
+ public:
+ explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ // Inputs: ensemble resource handle (input 0), "mean_gradients",
+ // "mean_hessians", and scalar regularizers "l1"/"l2".
+ // Output: scalar bool "continue_centering".
+ void Compute(OpKernelContext* const context) override {
+ // Get decision tree ensemble.
+ BoostedTreesEnsembleResource* ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &ensemble_resource));
+ // Release the reference taken by LookupResource when Compute returns.
+ core::ScopedUnref unref_me(ensemble_resource);
+ // Serialize concurrent mutations of the shared ensemble resource.
+ mutex_lock l(*ensemble_resource->get_mutex());
+ // Increase the ensemble stamp.
+ ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
+
+ // Read means of hessians and gradients
+ const Tensor* mean_gradients_t;
+ OP_REQUIRES_OK(context,
+ context->input("mean_gradients", &mean_gradients_t));
+
+ const Tensor* mean_hessians_t;
+ OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));
+
+ // Get the regularization options.
+ const Tensor* l1_t;
+ OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+ const auto l1 = l1_t->scalar<float>()();
+ const Tensor* l2_t;
+ OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+ const auto l2 = l2_t->scalar<float>()();
+
+ // For now, assume 1-dimensional weight on leaves.
+ float logits;
+ float unused_gain;
+
+ // TODO(nponomareva): change this when supporting multiclass.
+ // Only element 0 of each mean tensor is used (single-class assumption).
+ const float gradients_mean = mean_gradients_t->flat<float>()(0);
+ const float hessians_mean = mean_hessians_t->flat<float>()(0);
+ // Helper from tree_helper.h (added by this patch's include); computes the
+ // regularized leaf weight into `logits`. The gain output is discarded.
+ CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits,
+ &unused_gain);
+
+ float current_bias = 0.0;
+ bool continue_centering = true;
+ if (ensemble_resource->num_trees() == 0) {
+ // First call: create the bias tree whose single leaf holds `logits`.
+ ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
+ current_bias = logits;
+ } else {
+ // Subsequent calls: add the delta into the existing bias leaf (node 0 of
+ // tree 0) and stop once the relative change drops below the threshold.
+ current_bias = ensemble_resource->node_value(0, 0);
+ // NOTE(review): if the stored bias is exactly 0 this divides by zero,
+ // yielding inf/NaN and keeping centering on — confirm this is intended.
+ continue_centering =
+ std::abs(logits / current_bias) > kMinDeltaForCenterBias;
+ current_bias += logits;
+ ensemble_resource->set_node_value(0, 0, current_bias);
+ }
+
+ Tensor* continue_centering_t = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("continue_centering", TensorShape({}),
+ &continue_centering_t));
+ // Check if we need to continue centering bias.
+ continue_centering_t->scalar<bool>()() = continue_centering;
+ }
+};
+// CPU-only kernel registration for the op defined above.
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
+ BoostedTreesCenterBiasOp);
+
} // namespace tensorflow