diff options
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/training_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/training_ops.cc | 85 |
1 file changed, 82 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc index a14fd4a133..973cdec13a 100644 --- a/tensorflow/core/kernels/boosted_trees/training_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/boosted_trees/resources.h" +#include "tensorflow/core/kernels/boosted_trees/tree_helper.h" namespace tensorflow { namespace { constexpr float kLayerByLayerTreeWeight = 1.0; +constexpr float kMinDeltaForCenterBias = 0.01; // TODO(nponomareva, youngheek): consider using vector. struct SplitCandidate { @@ -89,7 +91,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { // Find best splits for each active node. std::map<int32, SplitCandidate> best_splits; - FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits); + FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids, + &best_splits); int32 current_tree = UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); @@ -193,6 +196,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { void FindBestSplitsPerNode( OpKernelContext* const context, const OpInputList& node_ids_list, const OpInputList& gains_list, + const TTypes<const int32>::Vec& feature_ids, std::map<int32, SplitCandidate>* best_split_per_node) { // Find best split per node going through every feature candidate. 
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) { @@ -211,8 +215,18 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { candidate.candidate_idx = candidate_idx; candidate.gain = gain; - if (best_split_it == best_split_per_node->end() || - gain > best_split_it->second.gain) { + if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && + GainsAreEqual(gain, best_split_it->second.gain))) { + const auto best_candidate = (*best_split_per_node)[node_id]; + const int32 best_feature_id = feature_ids(best_candidate.feature_idx); + const int32 feature_id = feature_ids(candidate.feature_idx); + VLOG(2) << "Breaking ties on feature ids and buckets"; + // Breaking ties deterministically. + if (feature_id < best_feature_id) { + (*best_split_per_node)[node_id] = candidate; + } + } else if (best_split_it == best_split_per_node->end() || + GainIsLarger(gain, best_split_it->second.gain)) { (*best_split_per_node)[node_id] = candidate; } } @@ -227,4 +241,69 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU), BoostedTreesUpdateEnsembleOp); +class BoostedTreesCenterBiasOp : public OpKernel { + public: + explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context) + : OpKernel(context) {} + + void Compute(OpKernelContext* const context) override { + // Get decision tree ensemble. + BoostedTreesEnsembleResource* ensemble_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + mutex_lock l(*ensemble_resource->get_mutex()); + // Increase the ensemble stamp. 
+ ensemble_resource->set_stamp(ensemble_resource->stamp() + 1); + + // Read means of hessians and gradients + const Tensor* mean_gradients_t; + OP_REQUIRES_OK(context, + context->input("mean_gradients", &mean_gradients_t)); + + const Tensor* mean_hessians_t; + OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t)); + + // Get the regularization options. + const Tensor* l1_t; + OP_REQUIRES_OK(context, context->input("l1", &l1_t)); + const auto l1 = l1_t->scalar<float>()(); + const Tensor* l2_t; + OP_REQUIRES_OK(context, context->input("l2", &l2_t)); + const auto l2 = l2_t->scalar<float>()(); + + // For now, assume 1-dimensional weight on leaves. + float logits; + float unused_gain; + + // TODO(nponomareva): change this when supporting multiclass. + const float gradients_mean = mean_gradients_t->flat<float>()(0); + const float hessians_mean = mean_hessians_t->flat<float>()(0); + CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits, + &unused_gain); + + float current_bias = 0.0; + bool continue_centering = true; + if (ensemble_resource->num_trees() == 0) { + ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits); + current_bias = logits; + } else { + current_bias = ensemble_resource->node_value(0, 0); + continue_centering = + std::abs(logits / current_bias) > kMinDeltaForCenterBias; + current_bias += logits; + ensemble_resource->set_node_value(0, 0, current_bias); + } + + Tensor* continue_centering_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("continue_centering", TensorShape({}), + &continue_centering_t)); + // Check if we need to continue centering bias. + continue_centering_t->scalar<bool>()() = continue_centering; + } +}; +REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU), + BoostedTreesCenterBiasOp); + } // namespace tensorflow |