author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-20 11:20:19 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-03-20 12:37:44 -0700
commit     5c21d55000c03a90281880f0eb55b12fcaa528fe (patch)
tree       0507a7e04b34e9e56adc1cb7ed666be292fcca60
parent     45938092a25872ee35d89d60cbf82b90aa4e08c9 (diff)
Return TensorForest feature importances in inference dict instead of using tf.Print.
Change: 150659388
-rw-r--r--  tensorflow/contrib/tensor_forest/client/eval_metrics.py    2
-rw-r--r--  tensorflow/contrib/tensor_forest/client/random_forest.py  10
2 files changed, 7 insertions, 5 deletions
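For context, a minimal sketch of how a caller might pick up this change, assuming TensorForestEstimator forwards report_feature_importances to get_model_fn and surfaces the inference dict as its prediction output; the estimator construction, dummy data, and fit/predict calls below are illustrative and not part of this commit:

# Minimal sketch under the assumptions stated above; random dummy data
# stands in for a real dataset.
import numpy as np

from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest

x_train = np.random.rand(100, 10).astype(np.float32)
y_train = np.random.randint(0, 2, size=100).astype(np.int32)
x_test = np.random.rand(5, 10).astype(np.float32)

params = tensor_forest.ForestHParams(
    num_classes=2, num_features=10, num_trees=10, max_nodes=1000)

# report_feature_importances is the flag that get_model_fn checks in this diff.
estimator = random_forest.TensorForestEstimator(
    params, report_feature_importances=True)
estimator.fit(x=x_train, y=y_train, steps=100)

# After this change the importances come back in the prediction output under
# eval_metrics.FEATURE_IMPORTANCE_NAME ('global_feature_importance') rather
# than being written to the log by tf.Print during EVAL. Depending on the
# estimator version, predictions is a dict of arrays or an iterable of
# per-example dicts keyed the same way.
predictions = estimator.predict(x=x_test)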
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index b070413c15..c99f9b7c12 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -30,6 +30,8 @@ from tensorflow.python.ops import nn
 INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES
 INFERENCE_PRED_NAME = prediction_key.PredictionKey.CLASSES
 
+FEATURE_IMPORTANCE_NAME = 'global_feature_importance'
+
 
 def _top_k_generator(k):
   def _top_k(probabilities, targets):
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index f55602d8b8..12a4ca2c0e 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -28,7 +28,6 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -130,6 +129,10 @@ def get_model_fn(params,
         inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
             inference[eval_metrics.INFERENCE_PROB_NAME], 1)
 
+      if report_feature_importances:
+        inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
+            graph_builder.feature_importances())
+
     # labels might be None if we're doing prediction (which brings up the
     # question of why we force everything to adhere to a single model_fn).
     loss_deps = []
@@ -149,10 +152,7 @@ def get_model_fn(params,
       with ops.control_dependencies(loss_deps):
         training_loss = graph_builder.training_loss(
             features, labels, name=LOSS_NAME)
-        if report_feature_importances and mode == model_fn_lib.ModeKeys.EVAL:
-          training_loss = logging_ops.Print(training_loss,
-                                            [graph_builder.feature_importances()],
-                                            summarize=1000)
+
     # Put weights back in
     if weights is not None:
       features[weights_name] = weights