diff options
author | 2016-11-17 10:48:44 -0800 | |
---|---|---|
committer | 2016-11-17 11:05:15 -0800 | |
commit | b499ead59653ffd97038445a6c32310bd2455392 (patch) | |
tree | 308b463940770b710e74ca6ddc0e659e4e9bac9a | |
parent | 7ac49f72682327a25c342bd5e4d07e9d80dedb2d (diff) |
Default to more fertile nodes in TensorForest.
Change: 139479566
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest.py | 15 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest_test.py | 4 |
2 files changed, 4 insertions, 15 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index c20279bc50..9d4e97cc7d 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -118,19 +118,8 @@ class ForestHParams(object): self.num_splits_to_consider or max(10, int(math.ceil(math.sqrt(self.num_features))))) - # max_fertile_nodes doesn't effect performance, only training speed. - # We therefore set it primarily based upon space considerations. - # Each fertile node takes up num_splits_to_consider times as much - # as space as a non-fertile node. We want the fertile nodes to in - # total only take up as much space as the non-fertile nodes, so - num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider)) - # But always use at least 1000 accumulate slots. - num_fertile = max(num_fertile, 1000) - self.max_fertile_nodes = self.max_fertile_nodes or num_fertile - # But it also never needs to be larger than the number of leaves, - # which is max_nodes / 2. - self.max_fertile_nodes = min(self.max_fertile_nodes, - int(math.ceil(self.max_nodes / 2.0))) + self.max_fertile_nodes = (self.max_fertile_nodes or + int(math.ceil(self.max_nodes / 2.0))) # We have num_splits_to_consider slots to fill, and we want to spend # approximately split_after_samples samples initializing them. diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index 986d6ba1ea..75b00aa990 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -50,8 +50,8 @@ class TensorForestTest(test_util.TensorFlowTestCase): num_features=1000).fill() # sqrt(1000) = 31.63... self.assertEquals(32, hparams.num_splits_to_consider) - # 1000000 / 32 = 31250 - self.assertEquals(31250, hparams.max_fertile_nodes) + # 1000000 / 2 = 500000 + self.assertEquals(500000, hparams.max_fertile_nodes) # floor(31.63 / 25) = 1 self.assertEquals(1, hparams.split_initializations_per_input) |