diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-14 13:46:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 13:50:00 -0700 |
commit | af827be63a9d8aff06a438ac9769a6ff6870c60d (patch) | |
tree | 8fd9715680d9953efe6a5d8177b05c8ba379431d /tensorflow/contrib/boosted_trees | |
parent | 0c98648e9a8722344bf8445ae46b1dff507b4859 (diff) |
First iteration of oblivious tree split handling for dense features.
PiperOrigin-RevId: 208705535
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
7 files changed, 376 insertions, 37 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 401bec84a2..d9e7a0f466 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,7 +34,9 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +using boosted_trees::learner::ObliviousSplitInfo; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; @@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); + const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()(); + // Find the number of unique partitions before we allocate the output. std::vector<int32> partition_boundaries; partition_boundaries.push_back(0); @@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel { tensorflow::TTypes<int32>::Vec output_partition_ids = output_partition_ids_t->vec<int32>(); - Tensor* gains_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output("gains", TensorShape({num_elements}), - &gains_t)); + // For a normal tree, we output a split per partition. For an oblivious + // tree, we output one split for all partitions of the layer + int32 size_output = num_elements; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE && + num_elements > 0) { + size_output = 1; + } + Tensor* gains_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "gains", TensorShape({size_output}), &gains_t)); tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>(); Tensor* output_splits_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "split_infos", TensorShape({num_elements}), - &output_splits_t)); + OP_REQUIRES_OK(context, context->allocate_output("split_infos", + TensorShape({size_output}), + &output_splits_t)); tensorflow::TTypes<string>::Vec output_splits = output_splits_t->vec<string>(); + + if (num_elements == 0) { + return; + } SplitBuilderState state(context); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + ComputeNormalDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + ComputeObliviousDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + } + } + + private: + void ComputeNormalDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector<int32>& partition_boundaries, + const tensorflow::TTypes<float>::ConstVec& bucket_boundaries, + const tensorflow::TTypes<int32>::ConstVec& partition_ids, + const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes<int32>::Vec* output_partition_ids, + tensorflow::TTypes<float>::Vec* gains, + tensorflow::TTypes<string>::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits<float>::lowest(); int start_index = partition_boundaries[root_idx]; @@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(state.feature_column_group_id()); + dense_split->set_feature_column(state->feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - state.FillLeaf(best_left_node_stats, left_child); - state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); - gains(root_idx) = - best_gain - root_stats.gain - state.tree_complexity_regularization(); - output_partition_ids(root_idx) = partition_ids(start_index); + state->FillLeaf(best_left_node_stats, left_child); + state->FillLeaf(best_right_node_stats, right_child); + split_info.SerializeToString(&(*output_splits)(root_idx)); + (*gains)(root_idx) = + best_gain - root_stats.gain - state->tree_complexity_regularization(); + (*output_partition_ids)(root_idx) = partition_ids(start_index); + } + } + void ComputeObliviousDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector<int32>& partition_boundaries, + const tensorflow::TTypes<float>::ConstVec& bucket_boundaries, + const tensorflow::TTypes<int32>::ConstVec& partition_ids, + const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes<int32>::Vec* output_partition_ids, + tensorflow::TTypes<float>::Vec* gains, + tensorflow::TTypes<string>::Vec* output_splits) { + // Holds the root stats per each node to be split. + std::vector<GradientStats> current_layer_stats; + current_layer_stats.reserve(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + GradientStats root_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + root_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + root_gradient_stats *= normalizer_ratio; + current_layer_stats.push_back(root_gradient_stats); + } + + float best_gain = std::numeric_limits<float>::lowest(); + int64 best_bucket_idx = 0; + std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0)); + int64 current_bucket_id = 0; + int64 last_bucket_id = -1; + // Indexes offsets for each of the partitions that can be used to access + // gradients of a partition for a current bucket we consider. + std::vector<int> current_layer_offsets(num_elements, 0); + std::vector<GradientStats> left_gradient_stats(num_elements); + // The idea is to try every bucket id in increasing order. In each iteration + // we calculate the gain of the layer using the current bucket id as split + // value, and we also obtain the following bucket id to try. + while (current_bucket_id > last_bucket_id) { + last_bucket_id = current_bucket_id; + int64 next_bucket_id = -1; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + int idx = + current_layer_offsets[root_idx] + partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) { + GradientStats g(*gradients_t, *hessians_t, idx); + g *= normalizer_ratio; + left_gradient_stats[root_idx] += g; + current_layer_offsets[root_idx]++; + idx++; + } + if (idx < end_index && + (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) { + next_bucket_id = bucket_ids(idx, 0); + } + } + float gain_of_split = 0.0; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + GradientStats right_gradient_stats = + current_layer_stats[root_idx] - left_gradient_stats[root_idx]; + NodeStats left_stat = + state->ComputeNodeStats(left_gradient_stats[root_idx]); + NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats); + gain_of_split += left_stat.gain + right_stat.gain; + current_left_node_stats[root_idx] = left_stat; + current_right_node_stats[root_idx] = right_stat; + } + if (gain_of_split > best_gain) { + best_gain = gain_of_split; + best_left_node_stats = current_left_node_stats; + best_right_node_stats = current_right_node_stats; + } + current_bucket_id = next_bucket_id; + } + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain; + } + best_gain -= num_elements * state->tree_complexity_regularization(); + + ObliviousSplitInfo oblivious_split_info; + auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() + ->mutable_dense_float_binary_split(); + oblivious_dense_split->set_feature_column(state->feature_column_group_id()); + oblivious_dense_split->set_threshold( + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); + (*gains)(0) = best_gain; + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + auto* left_children = oblivious_split_info.add_children_leaves(); + auto* right_children = oblivious_split_info.add_children_leaves(); + + state->FillLeaf(best_left_node_stats[root_idx], left_children); + state->FillLeaf(best_right_node_stats[root_idx], right_children); + + const int start_index = partition_boundaries[root_idx]; + (*output_partition_ids)(root_idx) = partition_ids(start_index); } + oblivious_split_info.SerializeToString(&(*output_splits)(0)); } }; REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 2559fe9913..f45010ec26 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,7 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops @@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy, init_stamp_token=0, loss_uses_sum_reduction=False, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, name=None): """Initialize the internal state for this split handler. @@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler): stamped objects. loss_uses_sum_reduction: A scalar boolean tensor that specifies whether SUM or MEAN reduction was used for the loss. + weak_learner_type: Specifies the type of weak learner to use. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column + self._weak_learner_type = weak_learner_type # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() dense_make_stats_update.add_to_graph(g) @@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight, self._loss_uses_sum_reduction)) - + self._min_node_weight, self._loss_uses_sum_reduction, + self._weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos -def _make_dense_split( - quantile_accumulator_handle, stats_accumulator_handle, stamp_token, - next_stamp_token, multiclass_strategy, class_id, feature_column_id, - l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, + loss_uses_sum_reduction, weak_learner_type): """Function that builds splits for a dense feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -327,7 +332,8 @@ def _make_dense_split( l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - multiclass_strategy=multiclass_strategy)) + multiclass_strategy=multiclass_strategy, + weak_learner_type=weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos @@ -507,7 +513,40 @@ def _make_sparse_split( return are_splits_ready, partition_ids, gains, split_infos -def _specialize_make_split(func, is_multi_dimentional): +def _specialize_make_split_dense(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.bool, + dtypes.int32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, loss_uses_sum_reduction, weak_learner_type): + """Function that builds splits for a sparse feature column.""" + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction, + weak_learner_type) + + return f + + +def _specialize_make_split_sparse(func, is_multi_dimentional): """Builds a specialized version of the function.""" @function.Defun( @@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional): return f -make_dense_split_scalar = _specialize_make_split(_make_dense_split, - is_multi_dimentional=False) -make_dense_split_tensor = _specialize_make_split(_make_dense_split, - is_multi_dimentional=True) -make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=False) -make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=True) +make_dense_split_scalar = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=False) + +make_dense_split_tensor = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=True) @function.Defun( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 5d82c4cae5..6572f2f414 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testObliviousFeatureSplitGeneration(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 2 | + # i1 | (-0.5, 0.07) | 0 | 2 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.1, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + oblivious_split_info = split_info_pb2.ObliviousSplitInfo() + oblivious_split_info.ParseFromString(splits[0]) + split_node = oblivious_split_info.split_node.dense_float_binary_split + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + self.assertEqual(0, split_node.feature_column) + + # Check the split on partition 0. + # -(1.2 - 0.1) / (0.2 + 1) + expected_left_weight_0 = -0.9166666666666666 + + # expected_left_weight_0 * -(1.2 - 0.1) + expected_left_gain_0 = 1.008333333333333 + + # (-0.5 + 0.2 + 0.1) / (0.19 + 1) + expected_right_weight_0 = 0.1680672 + + # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_0 = 0.033613445378151252 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain_0 = 0.46043165467625896 + + left_child = oblivious_split_info.children_leaves[0].vector + right_child = oblivious_split_info.children_leaves[1].vector + + self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) + + # Check the split on partition 1. + expected_left_weight_1 = 0 + expected_left_gain_1 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight_1 = -3.4513274336283186 + # expected_right_weight_1 * -(4 - 0.1) + expected_right_gain_1 = 13.460176991150442 + # (-4 + 0.1) ** 2 / (0.13 + 1) + expected_bias_gain_1 = 13.460176991150442 + + left_child = oblivious_split_info.children_leaves[2].vector + right_child = oblivious_split_info.children_leaves[3].vector + + self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + + # The layer gain is the sum of the gains of each partition + layer_gain = ( + expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + self.assertAllClose(layer_gain, gains[0], 0.00001) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): with self.test_session() as sess: # The data looks like the following: diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index ca5c7f3d8c..9b68a9de96 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("tree_complexity_regularization: float") .Input("min_node_weight: float") .Input("multiclass_strategy: int32") + .Input("weak_learner_type: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child. be considered. multiclass_strategy: A scalar, specifying the multiclass handling strategy. See LearnerConfig.MultiClassStrategy for valid values. +weak_learner_type: A scalar, specifying the weak learner type to use. + See LearnerConfig.WeakLearnerType for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index d84ba7438e..c49cb48cde 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -108,6 +108,11 @@ message LearnerConfig { DIAGONAL_HESSIAN = 3; } + enum WeakLearnerType { + NORMAL_DECISION_TREE = 0; + OBLIVIOUS_DECISION_TREE = 1; + } + // Number of classes. uint32 num_classes = 1; @@ -141,4 +146,7 @@ message LearnerConfig { // If you want to average the ensembles (for regularization), provide the // config below. AveragingConfig averaging_config = 11; + + // By default we use NORMAL_DECISION_TREE as weak learner. + WeakLearnerType weak_learner_type = 12; } diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index a300c24c8e..850340f5c2 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -17,3 +17,10 @@ message SplitInfo { // Right Leaf node. tensorflow.boosted_trees.trees.Leaf right_child = 3; } + +message ObliviousSplitInfo { + // The split node with the feature_column and threshold defined. + tensorflow.boosted_trees.trees.TreeNode split_node = 1; + // The new leaves of the tree. + repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2; +} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5cd37ec67e..2589504762 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) # .assertEmpty doesn't exist on ubuntu-contrib self.assertEqual(0, len(partitions)) |