aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 18:05:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 18:13:42 -0700
commitc4099e6ee8ba3846f2b7e70445806bc3055c5624 (patch)
tree930b9c6c49304383cc1c528899140500be750bb0 /tensorflow/contrib/boosted_trees
parent6eabd59b16c8eb873d7dc5bb8c5fe55677290844 (diff)
Added support for categorical features.
Ops are now interconnected to support oblivious decision trees. PiperOrigin-RevId: 210642692
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc195
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py7
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py111
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc22
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto12
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py9
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py9
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py157
9 files changed, 504 insertions, 21 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index d0fd39fa30..3b28ed77f3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -739,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
context->input("bias_feature_id", &bias_feature_id_t));
int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -767,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer.
+ int size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+
Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -790,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -801,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -813,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(state.feature_column_group_id());
+ equality_split->set_feature_column(state->feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
}
+
+ void ComputeObliviousDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats per each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ // First feature ID in each partition should be the bias feature.
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
+ errors::InvalidArgument("Bias feature ID missing."));
+ GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_feature_id = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_feature_id = std::numeric_limits<int64>::max();
+ int64 last_feature_id = -1;
+    // Find the lowest feature id; this is going to be the first feature id to
+    // try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (feature_ids(start_index + 1, 0) < current_feature_id) {
+ current_feature_id = feature_ids(start_index + 1, 0);
+ }
+ }
+    // Index offsets for each of the partitions that can be used to access
+    // gradients of a partition for the current feature we consider. Start at
+    // one because the zero index is for the bias.
+ std::vector<int> current_layer_offsets(num_elements, 1);
+ // The idea is to try every feature id in increasing order. In each
+ // iteration we calculate the gain of the layer using the current feature id
+ // as split value, and we also obtain the following feature id to try.
+ while (current_feature_id > last_feature_id) {
+ last_feature_id = current_feature_id;
+ int64 next_feature_id = -1;
+ // Left gradient stats per node.
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] = g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
+ next_feature_id = feature_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ best_feature_id = current_feature_id;
+ }
+ current_feature_id = next_feature_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* equality_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_categorical_id_binary_split();
+ equality_split->set_feature_column(state->feature_column_group_id());
+ equality_split->set_feature_id(best_feature_id);
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
+ }
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
+ }
};
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index efe29216c2..e6407174b1 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
@@ -46,6 +47,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -66,6 +68,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(EqualitySplitHandler, self).__init__(
@@ -85,6 +88,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
name="StatsAccumulator/{}".format(self._name))
self._sparse_int_column = sparse_int_column
+ self._weak_learner_type = weak_learner_type
def update_stats(self, stamp_token, example_partition_ids, gradients,
hessians, empty_gradients, empty_hessians, weights,
@@ -197,7 +201,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
tree_complexity_regularization=self._tree_complexity_regularization,
min_node_weight=self._min_node_weight,
bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
+ multiclass_strategy=self._multiclass_strategy,
+ weak_learner_type=self._weak_learner_type))
# There are no warm-up rounds needed in the equality column handler. So we
# always return ready.
are_splits_ready = constant_op.constant(True)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index ef253e7cec..d9f03c3840 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -169,6 +169,117 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
+ def testObliviousFeatureSplitGeneration(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 1 | 1 |
+ # i1 | (-0.5, 0.07) | 1 | 2 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [1, 1, 1, 2]
+ indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
+ values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+
+ with ops.control_dependencies([update_1, update_2]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([1, 2], partitions)
+
+ # For partition 1.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight1 = -0.9848484848484846
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain1 = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight1 = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain1 = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain1 = 0.46043165467625885
+
+ split_info = split_info_pb2.ObliviousSplitInfo()
+ split_info.ParseFromString(splits[0])
+ # Children of partition 1.
+ left_child = split_info.children[0].vector
+ right_child = split_info.children[1].vector
+ split_node = split_info.split_node.oblivious_categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+ self.assertEqual(1, split_node.feature_id)
+ self.assertAllClose([expected_left_weight1], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight1], right_child.value, 0.00001)
+
+      # For partition 2.
+ expected_left_weight2 = 0
+ expected_left_gain2 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight2 = -3.4513274336283186
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_right_gain2 = 13.460176991150442
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain2 = 13.460176991150442
+
+ # Children of partition 2.
+ left_child = split_info.children[2].vector
+ right_child = split_info.children[3].vector
+ self.assertAllClose([expected_left_weight2], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight2], right_child.value, 0.00001)
+
+ self.assertAllClose(
+ expected_left_gain1 + expected_right_gain1 - expected_bias_gain1 +
+ expected_left_gain2 + expected_right_gain2 - expected_bias_gain2,
+ gains[0], 0.00001)
+
def testGenerateFeatureSplitCandidatesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
index 3ed6c5c04d..64921faf81 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
@@ -111,6 +111,18 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
node_id++;
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ const auto& split =
+ current_node.oblivious_categorical_id_binary_split();
+ oblivious_leaf_idx <<= 1;
+ const auto& features =
+ example.sparse_int_features[split.feature_column()];
+ if (features.find(split.feature_id()) == features.end()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break;
@@ -181,6 +193,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children.";
break;
@@ -220,6 +237,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
return {};
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
return {};
}
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 9b68a9de96..f1e12a028a 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -179,6 +179,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -224,6 +225,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index 500909bf2a..520b4f8b11 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -16,6 +16,7 @@ message TreeNode {
CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6;
ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
+ ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8;
}
TreeNodeMetadata node_metadata = 777;
}
@@ -116,6 +117,17 @@ message ObliviousDenseFloatBinarySplit {
// leaves.
}
+// Split rule for categorical features with a single feature Id in the oblivious
+// case.
+message ObliviousCategoricalIdBinarySplit {
+ // Categorical feature column and Id describing the rule feature == Id.
+ int32 feature_column = 1;
+ int64 feature_id = 2;
+ // We don't store children ids, because either the next node represents the
+ // whole next layer of the tree or starting with the next node we only have
+ // leaves.
+}
+
// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5e62bad672..74917f7cde 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertEqual(0, len(partitions))
self.assertEqual(0, len(gains))
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 97743ba255..b008c6e534 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -762,7 +762,8 @@ class GradientBoostedDecisionTreeModel(object):
hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
- loss_uses_sum_reduction=loss_uses_sum_reduction))
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type))
fc_name_idx += 1
# Create ensemble stats variables.
@@ -1063,6 +1064,12 @@ class GradientBoostedDecisionTreeModel(object):
# Grow the ensemble given the current candidates.
sizes = array_ops.unstack(split_sizes)
partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
+    # When using the oblivious decision tree as weak learner, it produces
+    # one gain and one split per handler instead of one per partition.
+ if self._learner_config.weak_learner_type == (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE):
+ sizes = len(training_state.handlers)
+
gains_list = list(array_ops.split(gains, sizes, axis=0))
split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
return training_ops.grow_tree_ensemble(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index f7867d882d..73e41bc457 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from google.protobuf import text_format
from tensorflow.contrib import layers
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, output.trees[0])
+ def testObliviousDecisionTreeAsWeakLearner(self):
+ with self.test_session():
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.learning_rate_tuner.fixed.learning_rate = 1
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 2
+ learner_config.constraints.min_node_weight = 0
+ learner_config.weak_learner_type = (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE
+ learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
+ features = {}
+ features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]],
+ dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN)
+ predictions = predictions_dict["predictions"]
+ labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+
+ train_op = gbdt_model.train(
+ loss=math_ops.reduce_mean(
+ _squared_loss(labels, weights, predictions)),
+ predictions_dict=predictions_dict,
+ labels=labels)
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect no splits to be chosen because the quantile
+ # buckets will not be ready.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 1)
+
+ # Second run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 2)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+ # Third run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 3)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -2.0
+ }
+ node_metadata {
+ gain: 0.25
+ original_oblivious_leaves {
+ vector {
+ value: -1.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -2.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
with self.test_session() as sess: