diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 14:43:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 14:47:19 -0700 |
commit | e3108ea446b8b07d6a4aaca9667aff6ff5151a51 (patch) | |
tree | 981cf13de2bc5e957f8fb165c8bac347cbc74a29 /tensorflow/contrib/boosted_trees | |
parent | d0caa5a700dd36b7ac92be2722deaca9a4e23ef4 (diff) |
Fix bias feature being selected for splitting. The previous logic was broken for cases where all the examples in the last partition just had missing values. In those cases, the range that was selected for the leaf previous to the last included the bias value for the last leaf.
PiperOrigin-RevId: 214046965
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r-- | tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc | 23 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py | 86 |
2 files changed, 98 insertions, 11 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index af7006bff2..8edb5d6c64 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -739,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { // Find the number of unique partitions before we allocate the output. std::vector<int32> partition_boundaries; - std::vector<int32> non_empty_partitions; - for (int i = 0; i < partition_ids.size() - 1; ++i) { + partition_boundaries.push_back(0); + for (int i = 1; i < partition_ids.size(); ++i) { // Make sure the input is sorted by partition_ids; - CHECK_LE(partition_ids(i), partition_ids(i + 1)); - if (i == 0 || partition_ids(i) != partition_ids(i - 1)) { + OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i), + errors::InvalidArgument("Partition IDs must be sorted.")); + if (partition_ids(i) != partition_ids(i - 1)) { partition_boundaries.push_back(i); - // Some partitions might only have bias feature. We don't want to split - // those so check that the partition has at least 2 features. - if (partition_ids(i) == partition_ids(i + 1)) { - non_empty_partitions.push_back(partition_boundaries.size() - 1); - } } } - if (partition_ids.size() > 0) { - partition_boundaries.push_back(partition_ids.size()); + std::vector<int32> non_empty_partitions; + partition_boundaries.push_back(partition_ids.size()); + for (int i = 0; i < partition_boundaries.size() - 1; ++i) { + // We want to ignore partitions with only the bias term. + if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) { + non_empty_partitions.push_back(i); + } } int num_elements = non_empty_partitions.size(); Tensor* output_partition_ids_t = nullptr; diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 94ea7bc2eb..c050c2ed7f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testLastOneEmpty(self): + with self.cached_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 0 | 1,2 | + # i1 | (-0.5, 0.07) | 0 | | + # i2 | (1.2, 0.2) | 0 | 2 | + # i3 | (4.0, 0.13) | 1 | | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [0, 0, 0, 1] + indices = [[0, 0], [0, 1], [2, 0]] + values = array_ops.constant([1, 2, 2], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([0], partitions) + + # Check the split on partition 0. + # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1) + expected_left_weight = -0.9848484848484846 + + # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1) + expected_left_gain = 1.2803030303030298 + + # -(-0.5 + 0.1) / (0.07 + 1) + expected_right_weight = 0.37383177570093457 + + # (-0.5 + 0.1) ** 2 / (0.07 + 1) + expected_right_gain = 0.14953271028037385 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain = 0.46043165467625885 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(2, split_node.feature_id) + + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + if __name__ == "__main__": googletest.main() |