diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 18:05:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 18:13:42 -0700 |
commit | c4099e6ee8ba3846f2b7e70445806bc3055c5624 (patch) | |
tree | 930b9c6c49304383cc1c528899140500be750bb0 /tensorflow/contrib/boosted_trees | |
parent | 6eabd59b16c8eb873d7dc5bb8c5fe55677290844 (diff) |
Added support for categorical features.
Ops are now interconnected to support oblivious decision trees.
PiperOrigin-RevId: 210642692
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
9 files changed, 504 insertions, 21 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index d0fd39fa30..3b28ed77f3 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -739,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { context->input("bias_feature_id", &bias_feature_id_t)); int64 bias_feature_id = bias_feature_id_t->scalar<int64>()(); + const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()(); + // Find the number of unique partitions before we allocate the output. std::vector<int32> partition_boundaries; std::vector<int32> non_empty_partitions; @@ -767,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { tensorflow::TTypes<int32>::Vec output_partition_ids = output_partition_ids_t->vec<int32>(); + // For a normal tree, we output a split per partition. For an oblivious + // tree, we output one split for all partitions of the layer. + int size_output = num_elements; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE && + num_elements > 0) { + size_output = 1; + } + Tensor* gains_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output("gains", TensorShape({num_elements}), - &gains_t)); + OP_REQUIRES_OK(context, context->allocate_output( + "gains", TensorShape({size_output}), &gains_t)); tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>(); Tensor* output_splits_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "split_infos", TensorShape({num_elements}), - &output_splits_t)); + OP_REQUIRES_OK(context, context->allocate_output("split_infos", + TensorShape({size_output}), + &output_splits_t)); tensorflow::TTypes<string>::Vec output_splits = output_splits_t->vec<string>(); + if (num_elements == 0) { + return; + } SplitBuilderState state(context); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + ComputeNormalDecisionTree( + context, &state, normalizer_ratio, num_elements, + partition_boundaries, non_empty_partitions, bias_feature_id, + partition_ids, feature_ids, gradients_t, hessians_t, + &output_partition_ids, &gains, &output_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + ComputeObliviousDecisionTree( + context, &state, normalizer_ratio, num_elements, + partition_boundaries, non_empty_partitions, bias_feature_id, + partition_ids, feature_ids, gradients_t, hessians_t, + &output_partition_ids, &gains, &output_splits); + break; + } + } + } + + private: + void ComputeNormalDecisionTree( + OpKernelContext* const context, SplitBuilderState* state, + const float normalizer_ratio, const int num_elements, + const std::vector<int32>& partition_boundaries, + const std::vector<int32>& non_empty_partitions, + const int64 bias_feature_id, + const tensorflow::TTypes<int32>::ConstVec& partition_ids, + const tensorflow::TTypes<int64>::ConstMatrix& feature_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes<int32>::Vec* output_partition_ids, + tensorflow::TTypes<float>::Vec* gains, + tensorflow::TTypes<string>::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits<float>::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; @@ -790,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -801,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -813,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(state.feature_column_group_id()); + equality_split->set_feature_column(state->feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - state.FillLeaf(best_left_node_stats, left_child); - state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); - gains(root_idx) = - best_gain - root_stats.gain - state.tree_complexity_regularization(); - output_partition_ids(root_idx) = partition_ids(start_index); + state->FillLeaf(best_left_node_stats, left_child); + state->FillLeaf(best_right_node_stats, right_child); + split_info.SerializeToString(&(*output_splits)(root_idx)); + (*gains)(root_idx) = + best_gain - root_stats.gain - state->tree_complexity_regularization(); + (*output_partition_ids)(root_idx) = partition_ids(start_index); } } + + void ComputeObliviousDecisionTree( + OpKernelContext* const context, SplitBuilderState* state, + const float normalizer_ratio, const int num_elements, + const std::vector<int32>& partition_boundaries, + const std::vector<int32>& non_empty_partitions, + const int64 bias_feature_id, + const tensorflow::TTypes<int32>::ConstVec& partition_ids, + const tensorflow::TTypes<int64>::ConstMatrix& feature_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes<int32>::Vec* output_partition_ids, + tensorflow::TTypes<float>::Vec* gains, + tensorflow::TTypes<string>::Vec* output_splits) { + // Holds the root stats per each node to be split. + std::vector<GradientStats> current_layer_stats; + current_layer_stats.reserve(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + // First feature ID in each partition should be the bias feature. + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, + errors::InvalidArgument("Bias feature ID missing.")); + GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); + root_gradient_stats *= normalizer_ratio; + current_layer_stats.push_back(root_gradient_stats); + } + float best_gain = std::numeric_limits<float>::lowest(); + int64 best_feature_id = 0; + std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0)); + std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0)); + int64 current_feature_id = std::numeric_limits<int64>::max(); + int64 last_feature_id = -1; + // Find the lowest feature id, this is going to be the first feature id to + // try. + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + if (feature_ids(start_index + 1, 0) < current_feature_id) { + current_feature_id = feature_ids(start_index + 1, 0); + } + } + // Indexes offsets for each of the partitions that can be used to access + // gradients of a partition for a current feature we consider. Start at one + // beacuse the zero index is for the bias. + std::vector<int> current_layer_offsets(num_elements, 1); + // The idea is to try every feature id in increasing order. In each + // iteration we calculate the gain of the layer using the current feature id + // as split value, and we also obtain the following feature id to try. + while (current_feature_id > last_feature_id) { + last_feature_id = current_feature_id; + int64 next_feature_id = -1; + // Left gradient stats per node. + std::vector<GradientStats> left_gradient_stats(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + int idx = + current_layer_offsets[root_idx] + partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + if (idx < end_index && feature_ids(idx, 0) == current_feature_id) { + GradientStats g(*gradients_t, *hessians_t, idx); + g *= normalizer_ratio; + left_gradient_stats[root_idx] = g; + current_layer_offsets[root_idx]++; + idx++; + } + if (idx < end_index && + (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) { + next_feature_id = feature_ids(idx, 0); + } + } + float gain_of_split = 0.0; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + GradientStats right_gradient_stats = + current_layer_stats[root_idx] - left_gradient_stats[root_idx]; + NodeStats left_stat = + state->ComputeNodeStats(left_gradient_stats[root_idx]); + NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats); + gain_of_split += left_stat.gain + right_stat.gain; + current_left_node_stats[root_idx] = left_stat; + current_right_node_stats[root_idx] = right_stat; + } + if (gain_of_split > best_gain) { + best_gain = gain_of_split; + best_left_node_stats = current_left_node_stats; + best_right_node_stats = current_right_node_stats; + best_feature_id = current_feature_id; + } + current_feature_id = next_feature_id; + } + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain; + } + best_gain -= num_elements * state->tree_complexity_regularization(); + + ObliviousSplitInfo oblivious_split_info; + auto* equality_split = + oblivious_split_info.mutable_split_node() + ->mutable_oblivious_categorical_id_binary_split(); + equality_split->set_feature_column(state->feature_column_group_id()); + equality_split->set_feature_id(best_feature_id); + (*gains)(0) = best_gain; + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + auto* left_child = oblivious_split_info.add_children(); + auto* right_child = oblivious_split_info.add_children(); + + state->FillLeaf(best_left_node_stats[root_idx], left_child); + state->FillLeaf(best_right_node_stats[root_idx], right_child); + + const int start_index = partition_boundaries[root_idx]; + (*output_partition_ids)(root_idx) = partition_ids(start_index); + oblivious_split_info.add_children_parent_id(partition_ids(start_index)); + } + oblivious_split_info.SerializeToString(&(*output_splits)(0)); + } }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index efe29216c2..e6407174b1 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops from tensorflow.python.framework import constant_op @@ -46,6 +47,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy, init_stamp_token=0, loss_uses_sum_reduction=False, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, name=None): """Initialize the internal state for this split handler. @@ -66,6 +68,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): stamped objects. loss_uses_sum_reduction: A scalar boolean tensor that specifies whether SUM or MEAN reduction was used for the loss. + weak_learner_type: Specifies the type of weak learner to use. name: An optional handler name. """ super(EqualitySplitHandler, self).__init__( @@ -85,6 +88,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, name="StatsAccumulator/{}".format(self._name)) self._sparse_int_column = sparse_int_column + self._weak_learner_type = weak_learner_type def update_stats(self, stamp_token, example_partition_ids, gradients, hessians, empty_gradients, empty_hessians, weights, @@ -197,7 +201,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): tree_complexity_regularization=self._tree_complexity_regularization, min_node_weight=self._min_node_weight, bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) + multiclass_strategy=self._multiclass_strategy, + weak_learner_type=self._weak_learner_type)) # There are no warm-up rounds needed in the equality column handler. So we # always return ready. are_splits_ready = constant_op.constant(True) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index ef253e7cec..d9f03c3840 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -169,6 +169,117 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) + def testObliviousFeatureSplitGeneration(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 1 | 1 | + # i1 | (-0.5, 0.07) | 1 | 2 | + # i2 | (1.2, 0.2) | 1 | 1 | + # i3 | (4.0, 0.13) | 2 | 2 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [1, 1, 1, 2] + indices = [[0, 0], [1, 0], [2, 0], [3, 0]] + values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + + with ops.control_dependencies([update_1, update_2]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([1, 2], partitions) + + # For partition 1. + # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1) + expected_left_weight1 = -0.9848484848484846 + # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1) + expected_left_gain1 = 1.2803030303030298 + + # -(-0.5 + 0.1) / (0.07 + 1) + expected_right_weight1 = 0.37383177570093457 + + # (-0.5 + 0.1) ** 2 / (0.07 + 1) + expected_right_gain1 = 0.14953271028037385 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain1 = 0.46043165467625885 + + split_info = split_info_pb2.ObliviousSplitInfo() + split_info.ParseFromString(splits[0]) + # Children of partition 1. + left_child = split_info.children[0].vector + right_child = split_info.children[1].vector + split_node = split_info.split_node.oblivious_categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + self.assertEqual(1, split_node.feature_id) + self.assertAllClose([expected_left_weight1], left_child.value, 0.00001) + self.assertAllClose([expected_right_weight1], right_child.value, 0.00001) + + # For partition2. + expected_left_weight2 = 0 + expected_left_gain2 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight2 = -3.4513274336283186 + # (4 - 0.1) ** 2 / (0.13 + 1) + expected_right_gain2 = 13.460176991150442 + # (4 - 0.1) ** 2 / (0.13 + 1) + expected_bias_gain2 = 13.460176991150442 + + # Children of partition 2. + left_child = split_info.children[2].vector + right_child = split_info.children[3].vector + self.assertAllClose([expected_left_weight2], left_child.value, 0.00001) + self.assertAllClose([expected_right_weight2], right_child.value, 0.00001) + + self.assertAllClose( + expected_left_gain1 + expected_right_gain1 - expected_bias_gain1 + + expected_left_gain2 + expected_right_gain2 - expected_bias_gain2, + gains[0], 0.00001) + def testGenerateFeatureSplitCandidatesSumReduction(self): with self.test_session() as sess: # The data looks like the following: diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index 3ed6c5c04d..64921faf81 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -111,6 +111,18 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, node_id++; break; } + case TreeNode::kObliviousCategoricalIdBinarySplit: { + const auto& split = + current_node.oblivious_categorical_id_binary_split(); + oblivious_leaf_idx <<= 1; + const auto& features = + example.sparse_int_features[split.feature_column()]; + if (features.find(split.feature_id()) == features.end()) { + oblivious_leaf_idx++; + } + node_id++; + break; + } case TreeNode::NODE_NOT_SET: { LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); break; @@ -181,6 +193,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children, << "Not implemented for the ObliviousDenseFloatBinarySplit case."; break; } + case TreeNode::kObliviousCategoricalIdBinarySplit: { + LOG(QFATAL) + << "Not implemented for the ObliviousCategoricalIdBinarySplit case."; + break; + } case TreeNode::NODE_NOT_SET: { LOG(QFATAL) << "A non-set node cannot have children."; break; @@ -220,6 +237,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) { << "Not implemented for the ObliviousDenseFloatBinarySplit case."; return {}; } + case TreeNode::kObliviousCategoricalIdBinarySplit: { + LOG(QFATAL) + << "Not implemented for the ObliviousCategoricalIdBinarySplit case."; + break; + } case TreeNode::NODE_NOT_SET: { return {}; } diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 9b68a9de96..f1e12a028a 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -179,6 +179,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits") .Input("tree_complexity_regularization: float") .Input("min_node_weight: float") .Input("multiclass_strategy: int32") + .Input("weak_learner_type: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -224,6 +225,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child. be considered. multiclass_strategy: A scalar, specifying the multiclass handling strategy. See LearnerConfig.MultiClassStrategy for valid values. +weak_learner_type: A scalar, specifying the weak learner type to use. + See LearnerConfig.WeakLearnerType for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 500909bf2a..520b4f8b11 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -16,6 +16,7 @@ message TreeNode { CategoricalIdSetMembershipBinarySplit categorical_id_set_membership_binary_split = 6; ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7; + ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8; } TreeNodeMetadata node_metadata = 777; } @@ -116,6 +117,17 @@ message ObliviousDenseFloatBinarySplit { // leaves. } +// Split rule for categorical features with a single feature Id in the oblivious +// case. +message ObliviousCategoricalIdBinarySplit { + // Categorical feature column and Id describing the rule feature == Id. + int32 feature_column = 1; + int64 feature_id = 2; + // We don't store children ids, because either the next node represents the + // whole next layer of the tree or starting with the next node we only have + // leaves. +} + // DecisionTreeConfig describes a list of connected nodes. // Node 0 must be the root and can carry any payload including a leaf // in the case of representing the bias. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5e62bad672..74917f7cde 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = (sess.run([partitions, gains, splits])) self.assertEqual(0, len(partitions)) self.assertEqual(0, len(gains)) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 97743ba255..b008c6e534 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -762,7 +762,8 @@ class GradientBoostedDecisionTreeModel(object): hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, - loss_uses_sum_reduction=loss_uses_sum_reduction)) + loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type)) fc_name_idx += 1 # Create ensemble stats variables. @@ -1063,6 +1064,12 @@ class GradientBoostedDecisionTreeModel(object): # Grow the ensemble given the current candidates. sizes = array_ops.unstack(split_sizes) partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) + # When using the oblivious decision tree as weak learner, it produces + # one gain and one split per handler and not number of partitions. + if self._learner_config.weak_learner_type == ( + learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE): + sizes = len(training_state.handlers) + gains_list = list(array_ops.split(gains, sizes, axis=0)) split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) return training_ops.grow_tree_ensemble( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f7867d882d..73e41bc457 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from google.protobuf import text_format from tensorflow.contrib import layers +from tensorflow.contrib import learn from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testObliviousDecisionTreeAsWeakLearner(self): + with self.test_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.learning_rate_tuner.fixed.learning_rate = 1 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 2 + learner_config.constraints.min_node_weight = 0 + learner_config.weak_learner_type = ( + learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + features = {} + features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN) + predictions = predictions_dict["predictions"] + labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Second run. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + oblivious_dense_float_binary_split { + threshold: -1.0 + } + node_metadata { + gain: 4.5 + original_oblivious_leaves { + } + } + } + nodes { + leaf { + vector { + value: -1.5 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + # Third run. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [1]) + self.assertEquals(stamp_token.eval(), 3) + expected_tree = """ + nodes { + oblivious_dense_float_binary_split { + threshold: -1.0 + } + node_metadata { + gain: 4.5 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + threshold: -2.0 + } + node_metadata { + gain: 0.25 + original_oblivious_leaves { + vector { + value: -1.5 + } + } + original_oblivious_leaves { + vector { + value: 1.5 + } + } + } + } + nodes { + leaf { + vector { + value: -2.0 + } + } + } + nodes { + leaf { + vector { + value: -1.0 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): """Tests the train function with sparse and dense features.""" with self.test_session() as sess: |