aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 5d82c4cae5..6572f2f414 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
+ def testObliviousFeatureSplitGeneration(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Dense Quantile |
+ # i0 | (0.2, 0.12) | 0 | 2 |
+ # i1 | (-0.5, 0.07) | 0 | 2 |
+ # i2 | (1.2, 0.2) | 0 | 0 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ class_id = -1
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ split_handler = ordinal_split_handler.DenseSplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
+ epsilon=0.001,
+ num_quantiles=10,
+ feature_column_group_id=0,
+ dense_float_column=dense_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
+ with ops.control_dependencies([are_splits_ready]):
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+ are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+ sess.run([
+ are_splits_ready, are_splits_ready2, partitions, gains, splits
+ ]))
+
+ # During the first iteration, inequality split handlers are not going to
+ # have any splits. Make sure that we return not_ready in that case.
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+
+ self.assertAllEqual([0, 1], partitions)
+
+ oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
+ oblivious_split_info.ParseFromString(splits[0])
+ split_node = oblivious_split_info.split_node.dense_float_binary_split
+
+ self.assertAllClose(0.3, split_node.threshold, 0.00001)
+ self.assertEqual(0, split_node.feature_column)
+
+ # Check the split on partition 0.
+ # -(1.2 - 0.1) / (0.2 + 1)
+ expected_left_weight_0 = -0.9166666666666666
+
+ # expected_left_weight_0 * -(1.2 - 0.1)
+ expected_left_gain_0 = 1.008333333333333
+
+ # (-0.5 + 0.2 + 0.1) / (0.19 + 1)
+ expected_right_weight_0 = 0.1680672
+
+ # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1))
+ expected_right_gain_0 = 0.033613445378151252
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain_0 = 0.46043165467625896
+
+ left_child = oblivious_split_info.children_leaves[0].vector
+ right_child = oblivious_split_info.children_leaves[1].vector
+
+ self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+
+ # Check the split on partition 1.
+ expected_left_weight_1 = 0
+ expected_left_gain_1 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight_1 = -3.4513274336283186
+ # expected_right_weight_1 * -(4 - 0.1)
+ expected_right_gain_1 = 13.460176991150442
+ # (-4 + 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain_1 = 13.460176991150442
+
+ left_child = oblivious_split_info.children_leaves[2].vector
+ right_child = oblivious_split_info.children_leaves[3].vector
+
+ self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+
+ # The layer gain is the sum of the gains of each partition
+ layer_gain = (
+ expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
+ expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+ self.assertAllClose(layer_gain, gains[0], 0.00001)
+
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following: