diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 14:25:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 14:32:09 -0700 |
commit | 35caff957424a60bd7d7e4e92a1ec87f617781c6 (patch) | |
tree | decad6ad4e0a6807e63ee3a43a751f03012f3d66 | |
parent | fa1542234857acf56af6e7f0dbe8d2084a18fa00 (diff) |
Export feature importance for oblivious tree nodes.
PiperOrigin-RevId: 216422334
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 48f12a64f9..a3df272e69 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -196,6 +196,10 @@ def convert_to_universal_format(dtec, sorted_feature_names, matching_id = categorical_test.value.add() matching_id.int64_value = split.feature_id node.custom_left_child_test.Pack(categorical_test) + elif (node_type == "oblivious_dense_float_binary_split" or + node_type == "oblivious_categorical_id_binary_split"): + raise ValueError("Universal tree format doesn't support oblivious " + "trees") else: raise ValueError("Unexpected node type %s" % node_type) node.left_child_id.value = split.left_id @@ -229,6 +233,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + num_sparse_float] + elif node_type == "oblivious_dense_float_binary_split": + split = tree_node.oblivious_dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "oblivious_categorical_id_binary_split": + split = tree_node.oblivious_categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] elif node_type == "categorical_id_set_membership_binary_split": split = tree_node.categorical_id_set_membership_binary_split split_column = feature_names[split.feature_column + num_dense_floats + |