aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 14:43:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 14:47:19 -0700
commite3108ea446b8b07d6a4aaca9667aff6ff5151a51 (patch)
tree981cf13de2bc5e957f8fb165c8bac347cbc74a29 /tensorflow/contrib/boosted_trees
parentd0caa5a700dd36b7ac92be2722deaca9a4e23ef4 (diff)
Fix bias feature being selected for splitting. The previous logic was broken for cases where every example in the last partition had only missing values. In those cases, the range selected for the second-to-last leaf incorrectly included the bias value for the last leaf.
PiperOrigin-RevId: 214046965
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc23
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py86
2 files changed, 98 insertions, 11 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index af7006bff2..8edb5d6c64 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -739,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
- std::vector<int32> non_empty_partitions;
- for (int i = 0; i < partition_ids.size() - 1; ++i) {
+ partition_boundaries.push_back(0);
+ for (int i = 1; i < partition_ids.size(); ++i) {
// Make sure the input is sorted by partition_ids;
- CHECK_LE(partition_ids(i), partition_ids(i + 1));
- if (i == 0 || partition_ids(i) != partition_ids(i - 1)) {
+ OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i),
+ errors::InvalidArgument("Partition IDs must be sorted."));
+ if (partition_ids(i) != partition_ids(i - 1)) {
partition_boundaries.push_back(i);
- // Some partitions might only have bias feature. We don't want to split
- // those so check that the partition has at least 2 features.
- if (partition_ids(i) == partition_ids(i + 1)) {
- non_empty_partitions.push_back(partition_boundaries.size() - 1);
- }
}
}
- if (partition_ids.size() > 0) {
- partition_boundaries.push_back(partition_ids.size());
+ std::vector<int32> non_empty_partitions;
+ partition_boundaries.push_back(partition_ids.size());
+ for (int i = 0; i < partition_boundaries.size() - 1; ++i) {
+ // We want to ignore partitions with only the bias term.
+ if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) {
+ non_empty_partitions.push_back(i);
+ }
}
int num_elements = non_empty_partitions.size();
Tensor* output_partition_ids_t = nullptr;
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index 94ea7bc2eb..c050c2ed7f 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(gains), 0)
self.assertEqual(len(splits), 0)
+ def testLastOneEmpty(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 0 | 1,2 |
+ # i1 | (-0.5, 0.07) | 0 | |
+ # i2 | (1.2, 0.2) | 0 | 2 |
+ # i3 | (4.0, 0.13) | 1 | |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [0, 0, 0, 1]
+ indices = [[0, 0], [0, 1], [2, 0]]
+ values = array_ops.constant([1, 2, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([0], partitions)
+
+ # Check the split on partition 0.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight = -0.9848484848484846
+
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain = 0.46043165467625885
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(2, split_node.feature_id)
+
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
if __name__ == "__main__":
googletest.main()