author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-14 13:46:42 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-14 13:50:00 -0700
commit    af827be63a9d8aff06a438ac9769a6ff6870c60d
tree      8fd9715680d9953efe6a5d8177b05c8ba379431d /tensorflow/contrib/boosted_trees
parent    0c98648e9a8722344bf8445ae46b1dff507b4859
First iteration of oblivious tree split handling for dense features.
PiperOrigin-RevId: 208705535
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
 tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc                    | 184
 tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py      |  75
 tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py | 127
 tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc                        |   3
 tensorflow/contrib/boosted_trees/proto/learner.proto                             |   8
 tensorflow/contrib/boosted_trees/proto/split_info.proto                          |   7
 tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py   |   9
 7 files changed, 376 insertions(+), 37 deletions(-)
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 401bec84a2..d9e7a0f466 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -34,7 +34,9 @@
namespace tensorflow {
+using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
+using boosted_trees::learner::ObliviousSplitInfo;
using boosted_trees::learner::SplitInfo;
using boosted_trees::learner::stochastic::GradientStats;
using boosted_trees::learner::stochastic::NodeStats;
@@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
- Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer.
+ int32 size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+ Tensor* gains_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ &state, normalizer_ratio, num_elements, partition_boundaries,
+ bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+ hessians_t, &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
- dense_split->set_feature_column(state.feature_column_group_id());
+ dense_split->set_feature_column(state->feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ }
+ }
+ void ComputeObliviousDecisionTree(
+ SplitBuilderState* state, const float normalizer_ratio,
+ const int num_elements, const std::vector<int32>& partition_boundaries,
+ const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats for each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ GradientStats root_gradient_stats;
+ for (int64 bucket_idx = start_index; bucket_idx < end_index;
+ ++bucket_idx) {
+ root_gradient_stats +=
+ GradientStats(*gradients_t, *hessians_t, bucket_idx);
+ }
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_bucket_idx = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_bucket_id = 0;
+ int64 last_bucket_id = -1;
+ // Per-partition offsets into the bucket range, used to access the
+ // gradients of each partition for the bucket currently being considered.
+ std::vector<int> current_layer_offsets(num_elements, 0);
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ // The idea is to try every bucket id in increasing order. In each iteration
+ // we calculate the gain of the layer using the current bucket id as the
+ // split value, and we also obtain the next bucket id to try.
+ while (current_bucket_id > last_bucket_id) {
+ last_bucket_id = current_bucket_id;
+ int64 next_bucket_id = -1;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] += g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
+ next_bucket_id = bucket_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ }
+ current_bucket_id = next_bucket_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
+ ->mutable_dense_float_binary_split();
+ oblivious_dense_split->set_feature_column(state->feature_column_group_id());
+ oblivious_dense_split->set_threshold(
+ bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_children = oblivious_split_info.add_children_leaves();
+ auto* right_children = oblivious_split_info.add_children_leaves();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_children);
+ state->FillLeaf(best_right_node_stats[root_idx], right_children);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
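For intuition, here is a minimal NumPy sketch (not the kernel itself) of the oblivious sweep implemented above: because every partition in the layer must share a single threshold, candidate bucket ids are tried in increasing order and the whole layer is scored at once. The kernel does this in one merge pass over sorted bucket ids; the quadratic loop below trades that efficiency for readability and omits the l1/min_node_weight handling.

    import numpy as np

    def best_oblivious_gain(buckets, grads, hess, l2=1.0):
        """buckets/grads/hess: per-partition arrays, sorted by bucket id."""
        def node_gain(g, h):
            return g * g / (h + l2)  # simplified node gain, l2 only

        candidates = sorted(set(np.concatenate(buckets)))
        best = float("-inf")
        for split_bucket in candidates:
            layer_gain = 0.0
            for b, g, h in zip(buckets, grads, hess):
                left = b <= split_bucket  # shared threshold across partitions
                layer_gain += node_gain(g[left].sum(), h[left].sum())
                layer_gain += node_gain(g[~left].sum(), h[~left].sum())
            best = max(best, layer_gain)
        return best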
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 2559fe9913..f45010ec26 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -64,6 +64,7 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
@@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
+ self._weak_learner_type = weak_learner_type
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
dense_make_stats_update.add_to_graph(g)
@@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
- self._min_node_weight, self._loss_uses_sum_reduction))
-
+ self._min_node_weight, self._loss_uses_sum_reduction,
+ self._weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_dense_split(
- quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
- next_stamp_token, multiclass_strategy, class_id, feature_column_id,
- l1_regularization, l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional,
+ loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -327,7 +332,8 @@ def _make_dense_split(
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- multiclass_strategy=multiclass_strategy))
+ multiclass_strategy=multiclass_strategy,
+ weak_learner_type=weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
@@ -507,7 +513,40 @@ def _make_sparse_split(
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_make_split(func, is_multi_dimentional):
+def _specialize_make_split_dense(func, is_multi_dimentional):
+ """Builds a specialized version of the function."""
+
+ @function.Defun(
+ dtypes.resource,
+ dtypes.resource,
+ dtypes.int64,
+ dtypes.int64,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.bool,
+ dtypes.int32,
+ noinline=True)
+ def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, loss_uses_sum_reduction, weak_learner_type):
+ """Function that builds splits for a dense feature column."""
+ return func(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy, class_id,
+ feature_column_id, l1_regularization, l2_regularization,
+ tree_complexity_regularization, min_node_weight,
+ is_multi_dimentional, loss_uses_sum_reduction,
+ weak_learner_type)
+
+ return f
+
+
+def _specialize_make_split_sparse(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
@@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
return f
-make_dense_split_scalar = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=False)
-make_dense_split_tensor = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=True)
+make_dense_split_scalar = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=False)
+
+make_dense_split_tensor = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=True)
+
+make_sparse_split_scalar = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=True)
@function.Defun(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 5d82c4cae5..6572f2f414 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
+ def testObliviousFeatureSplitGeneration(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Dense Quantile |
+ # i0 | (0.2, 0.12) | 0 | 2 |
+ # i1 | (-0.5, 0.07) | 0 | 2 |
+ # i2 | (1.2, 0.2) | 0 | 0 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ class_id = -1
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ split_handler = ordinal_split_handler.DenseSplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
+ epsilon=0.001,
+ num_quantiles=10,
+ feature_column_group_id=0,
+ dense_float_column=dense_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
+ with ops.control_dependencies([are_splits_ready]):
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+ are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+ sess.run([
+ are_splits_ready, are_splits_ready2, partitions, gains, splits
+ ]))
+
+ # During the first iteration, inequality split handlers are not going to
+ # have any splits. Make sure that we return not_ready in that case.
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+
+ self.assertAllEqual([0, 1], partitions)
+
+ oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
+ oblivious_split_info.ParseFromString(splits[0])
+ split_node = oblivious_split_info.split_node.dense_float_binary_split
+
+ self.assertAllClose(0.3, split_node.threshold, 0.00001)
+ self.assertEqual(0, split_node.feature_column)
+
+ # Check the split on partition 0.
+ # -(1.2 - 0.1) / (0.2 + 1)
+ expected_left_weight_0 = -0.9166666666666666
+
+ # expected_left_weight_0 * -(1.2 - 0.1)
+ expected_left_gain_0 = 1.008333333333333
+
+ # -(-0.5 + 0.2 + 0.1) / (0.19 + 1)
+ expected_right_weight_0 = 0.1680672
+
+ # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)
+ expected_right_gain_0 = 0.033613445378151252
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain_0 = 0.46043165467625896
+
+ left_child = oblivious_split_info.children_leaves[0].vector
+ right_child = oblivious_split_info.children_leaves[1].vector
+
+ self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+
+ # Check the split on partition 1.
+ expected_left_weight_1 = 0
+ expected_left_gain_1 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight_1 = -3.4513274336283186
+ # expected_right_weight_1 * -(4 - 0.1)
+ expected_right_gain_1 = 13.460176991150442
+ # (-4 + 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain_1 = 13.460176991150442
+
+ left_child = oblivious_split_info.children_leaves[2].vector
+ right_child = oblivious_split_info.children_leaves[3].vector
+
+ self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+
+ # The layer gain is the sum of the gains of each partition.
+ layer_gain = (
+ expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
+ expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+ self.assertAllClose(layer_gain, gains[0], 0.00001)
+
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:
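As a sanity check, the layer gain expected by testObliviousFeatureSplitGeneration can be reproduced outside TensorFlow; the sketch below assumes the l1=0.1, l2=1.0 soft-threshold gain formula used throughout these tests.

    def node_gain(g, h, l1=0.1, l2=1.0):
        g_shrunk = max(abs(g) - l1, 0.0)  # l1 soft-thresholding of the gradient sum
        return g_shrunk ** 2 / (h + l2)

    # Partition 0: i2 (bucket 0) goes left; i0 and i1 (bucket 2) go right.
    gain_0 = node_gain(1.2, 0.2) + node_gain(0.2 - 0.5, 0.19) - node_gain(0.9, 0.39)
    # Partition 1: i3 (bucket 1) falls entirely right of the shared threshold 0.3.
    gain_1 = node_gain(0.0, 0.0) + node_gain(4.0, 0.13) - node_gain(4.0, 0.13)
    print(gain_0 + gain_1)  # ~0.58151, the layer_gain asserted in the test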
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index ca5c7f3d8c..9b68a9de96 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
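Given this contract, a consumer might decode the outputs as sketched below; the helper name is illustrative, not part of this change. For NORMAL_DECISION_TREE the op emits one (partition id, gain, split) triple per partition, while for OBLIVIOUS_DECISION_TREE gains and split_infos have size one and describe the whole layer.

    from tensorflow.contrib.boosted_trees.proto import learner_pb2
    from tensorflow.contrib.boosted_trees.proto import split_info_pb2

    def decode_split_outputs(weak_learner_type, gains, split_infos):
        """Pairs each gain with its decoded split proto (illustrative only)."""
        if weak_learner_type == learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE:
            info = split_info_pb2.ObliviousSplitInfo()
            info.ParseFromString(split_infos[0])  # one split for the whole layer
            return [(gains[0], info)]
        decoded = []
        for gain, serialized in zip(gains, split_infos):
            info = split_info_pb2.SplitInfo()
            info.ParseFromString(serialized)
            decoded.append((gain, info))
        return decoded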
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
index d84ba7438e..c49cb48cde 100644
--- a/tensorflow/contrib/boosted_trees/proto/learner.proto
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -108,6 +108,11 @@ message LearnerConfig {
DIAGONAL_HESSIAN = 3;
}
+ enum WeakLearnerType {
+ NORMAL_DECISION_TREE = 0;
+ OBLIVIOUS_DECISION_TREE = 1;
+ }
+
// Number of classes.
uint32 num_classes = 1;
@@ -141,4 +146,7 @@ message LearnerConfig {
// If you want to average the ensembles (for regularization), provide the
// config below.
AveragingConfig averaging_config = 11;
+
+ // By default we use NORMAL_DECISION_TREE as the weak learner.
+ WeakLearnerType weak_learner_type = 12;
}
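For reference, a minimal sketch of opting in to the new field from Python; the default remains NORMAL_DECISION_TREE (enum value 0), so existing configs are unaffected.

    from tensorflow.contrib.boosted_trees.proto import learner_pb2

    config = learner_pb2.LearnerConfig()
    config.num_classes = 2
    config.weak_learner_type = learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE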
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto
index a300c24c8e..850340f5c2 100644
--- a/tensorflow/contrib/boosted_trees/proto/split_info.proto
+++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto
@@ -17,3 +17,10 @@ message SplitInfo {
// Right Leaf node.
tensorflow.boosted_trees.trees.Leaf right_child = 3;
}
+
+message ObliviousSplitInfo {
+ // The split node with the feature_column and threshold defined.
+ tensorflow.boosted_trees.trees.TreeNode split_node = 1;
+ // The new leaves of the tree.
+ repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
+}
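The children_leaves field stores the (left, right) leaf pair of every partition in the layer, interleaved, as the test above exercises. A short reading sketch:

    from tensorflow.contrib.boosted_trees.proto import split_info_pb2

    def leaf_pairs(serialized_split):
        """Yields the (left, right) leaves of each partition, in order."""
        info = split_info_pb2.ObliviousSplitInfo()
        info.ParseFromString(serialized_split)
        for p in range(len(info.children_leaves) // 2):
            yield info.children_leaves[2 * p], info.children_leaves[2 * p + 1]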
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5cd37ec67e..2589504762 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
# .assertEmpty doesn't exist on ubuntu-contrib
self.assertEqual(0, len(partitions))