aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-17 10:48:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-17 11:05:15 -0800
commitb499ead59653ffd97038445a6c32310bd2455392 (patch)
tree308b463940770b710e74ca6ddc0e659e4e9bac9a
parent7ac49f72682327a25c342bd5e4d07e9d80dedb2d (diff)
Default to more fertile nodes in TensorForest.
Change: 139479566
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py15
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py4
2 files changed, 4 insertions, 15 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index c20279bc50..9d4e97cc7d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -118,19 +118,8 @@ class ForestHParams(object):
self.num_splits_to_consider or
max(10, int(math.ceil(math.sqrt(self.num_features)))))
- # max_fertile_nodes doesn't effect performance, only training speed.
- # We therefore set it primarily based upon space considerations.
- # Each fertile node takes up num_splits_to_consider times as much
- # as space as a non-fertile node. We want the fertile nodes to in
- # total only take up as much space as the non-fertile nodes, so
- num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
- # But always use at least 1000 accumulate slots.
- num_fertile = max(num_fertile, 1000)
- self.max_fertile_nodes = self.max_fertile_nodes or num_fertile
- # But it also never needs to be larger than the number of leaves,
- # which is max_nodes / 2.
- self.max_fertile_nodes = min(self.max_fertile_nodes,
- int(math.ceil(self.max_nodes / 2.0)))
+ self.max_fertile_nodes = (self.max_fertile_nodes or
+ int(math.ceil(self.max_nodes / 2.0)))
# We have num_splits_to_consider slots to fill, and we want to spend
# approximately split_after_samples samples initializing them.
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 986d6ba1ea..75b00aa990 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -50,8 +50,8 @@ class TensorForestTest(test_util.TensorFlowTestCase):
num_features=1000).fill()
# sqrt(1000) = 31.63...
self.assertEquals(32, hparams.num_splits_to_consider)
- # 1000000 / 32 = 31250
- self.assertEquals(31250, hparams.max_fertile_nodes)
+ # 1000000 / 2 = 500000
+ self.assertEquals(500000, hparams.max_fertile_nodes)
# floor(31.63 / 25) = 1
self.assertEquals(1, hparams.split_initializations_per_input)