aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 09:12:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:17:42 -0700
commit32140ae87fd86398ac4fa45cb67bd2f29a93090d (patch)
treef7a4d7701a64defd1f0d30a0460a2ce3ea571695 /tensorflow/python/kernel_tests
parent588787ff7572208285cb471c76f4f8c83ad9d7ec (diff)
Boosted trees: Adding categorical split support to prediction ops.
PiperOrigin-RevId: 214448656
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 467e33ec87..7cdc67f83f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -445,6 +445,78 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# change= 0.1(1.14+7.0-7.0)
self.assertAllClose([[1], [0.114]], logits_updates)
+ def testCategoricalSplits(self):
+ """Tests the training prediction work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ is_finalized: true
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ # No previous cached values.
+ cached_tree_ids = [0, 0, 0]
+ cached_node_ids = [0, 0, 0]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ self.assertAllClose([0, 0, 0], new_tree_ids)
+ self.assertAllClose([3, 4, 2], new_node_ids)
+ self.assertAllClose([[5.], [6.], [7.]], logits_updates)
+
def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
with self.cached_session() as session:
@@ -924,6 +996,68 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
+ def testCategoricalSplits(self):
+ """Tests the predictions work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ expected_logits = [[5.], [6.], [7.]]
+
+ # Prediction should work fine.
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
"""Tests feature contribs ops for model understanding."""