aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc')
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc184
1 files changed, 167 insertions, 17 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 401bec84a2..d9e7a0f466 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -34,7 +34,9 @@
namespace tensorflow {
+using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
+using boosted_trees::learner::ObliviousSplitInfo;
using boosted_trees::learner::SplitInfo;
using boosted_trees::learner::stochastic::GradientStats;
using boosted_trees::learner::stochastic::NodeStats;
@@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
- Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer
+ int32 size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+ Tensor* gains_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
- dense_split->set_feature_column(state.feature_column_group_id());
+ dense_split->set_feature_column(state->feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ }
+ }
+ void ComputeObliviousDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats per each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ GradientStats root_gradient_stats;
+ for (int64 bucket_idx = start_index; bucket_idx < end_index;
+ ++bucket_idx) {
+ root_gradient_stats +=
+ GradientStats(*gradients_t, *hessians_t, bucket_idx);
+ }
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_bucket_idx = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_bucket_id = 0;
+ int64 last_bucket_id = -1;
+ // Indexes offsets for each of the partitions that can be used to access
+ // gradients of a partition for a current bucket we consider.
+ std::vector<int> current_layer_offsets(num_elements, 0);
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ // The idea is to try every bucket id in increasing order. In each iteration
+ // we calculate the gain of the layer using the current bucket id as split
+ // value, and we also obtain the following bucket id to try.
+ while (current_bucket_id > last_bucket_id) {
+ last_bucket_id = current_bucket_id;
+ int64 next_bucket_id = -1;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] += g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
+ next_bucket_id = bucket_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ }
+ current_bucket_id = next_bucket_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
+ ->mutable_dense_float_binary_split();
+ oblivious_dense_split->set_feature_column(state->feature_column_group_id());
+ oblivious_dense_split->set_threshold(
+ bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_children = oblivious_split_info.add_children_leaves();
+ auto* right_children = oblivious_split_info.add_children_leaves();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_children);
+ state->FillLeaf(best_right_node_stats[root_idx], right_children);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),