aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-05 14:32:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 14:36:01 -0800
commitd58eabfbe3570dd47ae3d1e3d5520c3dbbaca3c8 (patch)
tree6b1245b59b4dca22b95b0c5ac96ebd90607ba8db
parentc483e7b63913fb35817e1ba0dd6dc0d200cf5061 (diff)
Output variance over tree predictions for classifications.
PiperOrigin-RevId: 180976319
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py3
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py16
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py2
3 files changed, 9 insertions, 12 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index cddd62851b..a998ac1e11 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -237,8 +237,7 @@ def get_model_fn(params,
if params.inference_tree_paths:
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
- if params.regression:
- model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+ model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
return model_ops
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index eb938763f1..3650b5d52f 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -478,8 +478,7 @@ class RandomForestGraphs(object):
**inference_args: Keyword arguments to pass through to each tree.
Returns:
- A tuple of (probabilities, tree_paths, variance), where variance
- is the variance over all the trees for regression problems only.
+ A tuple of (probabilities, tree_paths, variance).
Raises:
NotImplementedError: If trying to use feature bagging with sparse
@@ -513,13 +512,12 @@ class RandomForestGraphs(object):
self.params.num_trees,
name='probabilities')
tree_paths = array_ops.stack(paths, axis=1)
- regression_variance = None
- if self.params.regression:
- expected_squares = math_ops.div(
- math_ops.reduce_sum(all_predict * all_predict, 1),
- self.params.num_trees)
- regression_variance = math_ops.maximum(
- 0., expected_squares - average_values * average_values)
+
+ expected_squares = math_ops.div(
+ math_ops.reduce_sum(all_predict * all_predict, 1),
+ self.params.num_trees)
+ regression_variance = math_ops.maximum(
+ 0., expected_squares - average_values * average_values)
return average_values, tree_paths, regression_variance
def average_size(self):
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 113dfb85d3..bbe627b157 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -108,7 +108,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
- self.assertIsNone(var)
+ self.assertTrue(isinstance(var, ops.Tensor))
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(