aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 14:25:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 14:32:09 -0700
commit35caff957424a60bd7d7e4e92a1ec87f617781c6 (patch)
treedecad6ad4e0a6807e63ee3a43a751f03012f3d66
parentfa1542234857acf56af6e7f0dbe8d2084a18fa00 (diff)
Export feature importance for oblivious tree nodes.
PiperOrigin-RevId: 216422334
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 48f12a64f9..a3df272e69 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -196,6 +196,10 @@ def convert_to_universal_format(dtec, sorted_feature_names,
matching_id = categorical_test.value.add()
matching_id.int64_value = split.feature_id
node.custom_left_child_test.Pack(categorical_test)
+ elif (node_type == "oblivious_dense_float_binary_split" or
+ node_type == "oblivious_categorical_id_binary_split"):
+ raise ValueError("Universal tree format doesn't support oblivious "
+ "trees")
else:
raise ValueError("Unexpected node type %s" % node_type)
node.left_child_id.value = split.left_id
@@ -229,6 +233,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats,
split = tree_node.categorical_id_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
+ elif node_type == "oblivious_dense_float_binary_split":
+ split = tree_node.oblivious_dense_float_binary_split
+ split_column = feature_names[split.feature_column]
+ elif node_type == "oblivious_categorical_id_binary_split":
+ split = tree_node.oblivious_categorical_id_binary_split
+ split_column = feature_names[split.feature_column + num_dense_floats +
+ num_sparse_float]
elif node_type == "categorical_id_set_membership_binary_split":
split = tree_node.categorical_id_set_membership_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +