diff options
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 10 |
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index b070413c15..c99f9b7c12 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -30,6 +30,8 @@ from tensorflow.python.ops import nn INFERENCE_PROB_NAME = prediction_key.PredictionKey.CLASSES INFERENCE_PRED_NAME = prediction_key.PredictionKey.PROBABILITIES +FEATURE_IMPORTANCE_NAME = 'global_feature_importance' + def _top_k_generator(k): def _top_k(probabilities, targets): diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index f55602d8b8..12a4ca2c0e 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -28,7 +28,6 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging @@ -130,6 +129,10 @@ def get_model_fn(params, inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax( inference[eval_metrics.INFERENCE_PROB_NAME], 1) + if report_feature_importances: + inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = ( + graph_builder.feature_importances()) + # labels might be None if we're doing prediction (which brings up the # question of why we force everything to adhere to a single model_fn). loss_deps = [] @@ -149,10 +152,7 @@ def get_model_fn(params, with ops.control_dependencies(loss_deps): training_loss = graph_builder.training_loss( features, labels, name=LOSS_NAME) - if report_feature_importances and mode == model_fn_lib.ModeKeys.EVAL: - training_loss = logging_ops.Print(training_loss, - [graph_builder.feature_importances()], - summarize=1000) + # Put weights back in if weights is not None: features[weights_name] = weights |