From d58eabfbe3570dd47ae3d1e3d5520c3dbbaca3c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Jan 2018 14:32:31 -0800 Subject: Output variance over tree predictions for classifications. PiperOrigin-RevId: 180976319 --- tensorflow/contrib/tensor_forest/client/random_forest.py | 3 +-- tensorflow/contrib/tensor_forest/python/tensor_forest.py | 16 +++++++--------- .../contrib/tensor_forest/python/tensor_forest_test.py | 2 +- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index cddd62851b..a998ac1e11 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -237,8 +237,7 @@ def get_model_fn(params, if params.inference_tree_paths: model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths - if params.regression: - model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance return model_ops diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index eb938763f1..3650b5d52f 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -478,8 +478,7 @@ class RandomForestGraphs(object): **inference_args: Keyword arguments to pass through to each tree. Returns: - A tuple of (probabilities, tree_paths, variance), where variance - is the variance over all the trees for regression problems only. + A tuple of (probabilities, tree_paths, variance). Raises: NotImplementedError: If trying to use feature bagging with sparse @@ -513,13 +512,12 @@ class RandomForestGraphs(object): self.params.num_trees, name='probabilities') tree_paths = array_ops.stack(paths, axis=1) - regression_variance = None - if self.params.regression: - expected_squares = math_ops.div( - math_ops.reduce_sum(all_predict * all_predict, 1), - self.params.num_trees) - regression_variance = math_ops.maximum( - 0., expected_squares - average_values * average_values) + + expected_squares = math_ops.div( + math_ops.reduce_sum(all_predict * all_predict, 1), + self.params.num_trees) + regression_variance = math_ops.maximum( + 0., expected_squares - average_values * average_values) return average_values, tree_paths, regression_variance def average_size(self): diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index 113dfb85d3..bbe627b157 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -108,7 +108,7 @@ class TensorForestTest(test_util.TensorFlowTestCase): probs, paths, var = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(probs, ops.Tensor)) self.assertTrue(isinstance(paths, ops.Tensor)) - self.assertIsNone(var) + self.assertTrue(isinstance(var, ops.Tensor)) def testTrainingConstructionClassificationSparse(self): input_data = sparse_tensor.SparseTensor( -- cgit v1.2.3