diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 09:12:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 09:17:42 -0700 |
commit | 32140ae87fd86398ac4fa45cb67bd2f29a93090d (patch) | |
tree | f7a4d7701a64defd1f0d30a0460a2ce3ea571695 /tensorflow/python/kernel_tests | |
parent | 588787ff7572208285cb471c76f4f8c83ad9d7ec (diff) |
Boosted trees: Adding categorical split support to prediction ops.
PiperOrigin-RevId: 214448656
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py index 467e33ec87..7cdc67f83f 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py @@ -445,6 +445,78 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase): # change= 0.1(1.14+7.0-7.0) self.assertAllClose([[1], [0.114]], logits_updates) + def testCategoricalSplits(self): + """Tests the training prediction work for categorical splits.""" + with self.cached_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + categorical_split { + feature_id: 1 + value: 2 + left_id: 1 + right_id: 2 + } + } + nodes { + categorical_split { + feature_id: 0 + value: 13 + left_id: 3 + right_id: 4 + } + } + nodes { + leaf { + scalar: 7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + nodes { + leaf { + scalar: 6.0 + } + } + } + tree_weights: 1.0 + tree_metadata { + is_finalized: true + } + """, tree_ensemble_config) + + # Create existing ensemble with one root split + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [13, 1, 3] + feature_1_values = [2, 2, 1] + + # No previous cached values. + cached_tree_ids = [0, 0, 0] + cached_node_ids = [0, 0, 0] + + # Grow tree ensemble. + predict_op = boosted_trees_ops.training_predict( + tree_ensemble_handle, + cached_tree_ids=cached_tree_ids, + cached_node_ids=cached_node_ids, + bucketized_features=[feature_0_values, feature_1_values], + logits_dimension=1) + + logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) + + self.assertAllClose([0, 0, 0], new_tree_ids) + self.assertAllClose([3, 4, 2], new_node_ids) + self.assertAllClose([[5.], [6.], [7.]], logits_updates) + def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self): """Tests that prediction based on previous node in the tree works.""" with self.cached_session() as session: @@ -924,6 +996,68 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): logits = session.run(predict_op) self.assertAllClose(expected_logits, logits) + def testCategoricalSplits(self): + """Tests the predictions work for categorical splits.""" + with self.cached_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + categorical_split { + feature_id: 1 + value: 2 + left_id: 1 + right_id: 2 + } + } + nodes { + categorical_split { + feature_id: 0 + value: 13 + left_id: 3 + right_id: 4 + } + } + nodes { + leaf { + scalar: 7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + nodes { + leaf { + scalar: 6.0 + } + } + } + tree_weights: 1.0 + """, tree_ensemble_config) + + # Create existing ensemble with one root split + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [13, 1, 3] + feature_1_values = [2, 2, 1] + + expected_logits = [[5.], [6.], [7.]] + + # Prediction should work fine. + predict_op = boosted_trees_ops.predict( + tree_ensemble_handle, + bucketized_features=[feature_0_values, feature_1_values], + logits_dimension=1) + + logits = session.run(predict_op) + self.assertAllClose(expected_logits, logits) + class FeatureContribsOpsTest(test_util.TensorFlowTestCase): """Tests feature contribs ops for model understanding.""" |