diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-20 11:20:19 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 12:37:44 -0700 |
commit | 5c21d55000c03a90281880f0eb55b12fcaa528fe (patch) | |
tree | 0507a7e04b34e9e56adc1cb7ed666be292fcca60 | |
parent | 45938092a25872ee35d89d60cbf82b90aa4e08c9 (diff) |
Return TensorForest feature importances in inference dict instead of using tf.Print.
Change: 150659388
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/eval_metrics.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 10 |
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index b070413c15..c99f9b7c12 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -30,6 +30,8 @@ from tensorflow.python.ops import nn INFERENCE_PROB_NAME = prediction_key.PredictionKey.CLASSES INFERENCE_PRED_NAME = prediction_key.PredictionKey.PROBABILITIES +FEATURE_IMPORTANCE_NAME = 'global_feature_importance' + def _top_k_generator(k): def _top_k(probabilities, targets): diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index f55602d8b8..12a4ca2c0e 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -28,7 +28,6 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging @@ -130,6 +129,10 @@ def get_model_fn(params, inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax( inference[eval_metrics.INFERENCE_PROB_NAME], 1) + if report_feature_importances: + inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = ( + graph_builder.feature_importances()) + # labels might be None if we're doing prediction (which brings up the # question of why we force everything to adhere to a single model_fn). loss_deps = [] @@ -149,10 +152,7 @@ def get_model_fn(params, with ops.control_dependencies(loss_deps): training_loss = graph_builder.training_loss( features, labels, name=LOSS_NAME) - if report_feature_importances and mode == model_fn_lib.ModeKeys.EVAL: - training_loss = logging_ops.Print(training_loss, - [graph_builder.feature_importances()], - summarize=1000) + # Put weights back in if weights is not None: features[weights_name] = weights |