diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 18:05:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 18:13:42 -0700 |
commit | c4099e6ee8ba3846f2b7e70445806bc3055c5624 (patch) | |
tree | 930b9c6c49304383cc1c528899140500be750bb0 /tensorflow/contrib/boosted_trees/python | |
parent | 6eabd59b16c8eb873d7dc5bb8c5fe55677290844 (diff) |
Added support for categorical features.
Ops are now interconnected to support oblivious decision trees.
PiperOrigin-RevId: 210642692
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python')
3 files changed, 171 insertions, 4 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5e62bad672..74917f7cde 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): feature_column_group_id=0, bias_feature_id=-1, class_id=-1, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = (sess.run([partitions, gains, splits])) self.assertEqual(0, len(partitions)) self.assertEqual(0, len(gains)) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 97743ba255..b008c6e534 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -762,7 +762,8 @@ class GradientBoostedDecisionTreeModel(object): hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token, - loss_uses_sum_reduction=loss_uses_sum_reduction)) + loss_uses_sum_reduction=loss_uses_sum_reduction, + weak_learner_type=weak_learner_type)) fc_name_idx += 1 # Create ensemble stats variables. @@ -1063,6 +1064,12 @@ class GradientBoostedDecisionTreeModel(object): # Grow the ensemble given the current candidates. sizes = array_ops.unstack(split_sizes) partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) + # When using the oblivious decision tree as weak learner, it produces + # one gain and one split per handler and not number of partitions. + if self._learner_config.weak_learner_type == ( + learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE): + sizes = len(training_state.handlers) + gains_list = list(array_ops.split(gains, sizes, axis=0)) split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) return training_ops.grow_tree_ensemble( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f7867d882d..73e41bc457 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from google.protobuf import text_format from tensorflow.contrib import layers +from tensorflow.contrib import learn from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testObliviousDecisionTreeAsWeakLearner(self): + with self.test_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.learning_rate_tuner.fixed.learning_rate = 1 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 2 + learner_config.constraints.min_node_weight = 0 + learner_config.weak_learner_type = ( + learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + features = {} + features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN) + predictions = predictions_dict["predictions"] + labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Second run. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + oblivious_dense_float_binary_split { + threshold: -1.0 + } + node_metadata { + gain: 4.5 + original_oblivious_leaves { + } + } + } + nodes { + leaf { + vector { + value: -1.5 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + # Third run. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [1]) + self.assertEquals(stamp_token.eval(), 3) + expected_tree = """ + nodes { + oblivious_dense_float_binary_split { + threshold: -1.0 + } + node_metadata { + gain: 4.5 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + threshold: -2.0 + } + node_metadata { + gain: 0.25 + original_oblivious_leaves { + vector { + value: -1.5 + } + } + original_oblivious_leaves { + vector { + value: 1.5 + } + } + } + } + nodes { + leaf { + vector { + value: -2.0 + } + } + } + nodes { + leaf { + vector { + value: -1.0 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + } + nodes { + leaf { + vector { + value: 1.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): """Tests the train function with sparse and dense features.""" with self.test_session() as sess: |