From 35caff957424a60bd7d7e4e92a1ec87f617781c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Oct 2018 14:25:27 -0700 Subject: Export feature importance for oblivious tree nodes. PiperOrigin-RevId: 216422334 --- .../boosted_trees/estimator_batch/custom_export_strategy.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 48f12a64f9..a3df272e69 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -196,6 +196,10 @@ def convert_to_universal_format(dtec, sorted_feature_names, matching_id = categorical_test.value.add() matching_id.int64_value = split.feature_id node.custom_left_child_test.Pack(categorical_test) + elif (node_type == "oblivious_dense_float_binary_split" or + node_type == "oblivious_categorical_id_binary_split"): + raise ValueError("Universal tree format doesn't support oblivious " + "trees") else: raise ValueError("Unexpected node type %s" % node_type) node.left_child_id.value = split.left_id @@ -229,6 +233,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + num_sparse_float] + elif node_type == "oblivious_dense_float_binary_split": + split = tree_node.oblivious_dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "oblivious_categorical_id_binary_split": + split = tree_node.oblivious_categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] elif node_type == "categorical_id_set_membership_binary_split": split = tree_node.categorical_id_set_membership_binary_split split_column = feature_names[split.feature_column + num_dense_floats + -- cgit v1.2.3