aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/tensor_forest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py34
1 files changed, 22 insertions, 12 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 7a35a70bbe..6f62cd11a9 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -295,7 +295,7 @@ def get_epoch_variable():
# A simple container to hold the training variables for a single tree.
-class TreeTrainingVariables(object):
+class TreeVariables(object):
"""Stores tf.Variables for training a single random tree.
Uses tf.get_variable to get tree-specific names so that this can be used
@@ -303,7 +303,7 @@ class TreeTrainingVariables(object):
then relies on restoring that model to evaluate).
"""
- def __init__(self, params, tree_num, training):
+ def __init__(self, params, tree_num, training, tree_config='', tree_stat=''):
if (not hasattr(params, 'params_proto') or
not isinstance(params.params_proto,
_params_proto.TensorForestParams)):
@@ -315,27 +315,28 @@ class TreeTrainingVariables(object):
# TODO(gilberth): Manually shard this to be able to fit it on
# multiple machines.
self.stats = stats_ops.fertile_stats_variable(
- params, '', self.get_tree_name('stats', tree_num))
+ params, tree_stat, self.get_tree_name('stats', tree_num))
self.tree = model_ops.tree_variable(
- params, '', self.stats, self.get_tree_name('tree', tree_num))
+ params, tree_config, self.stats, self.get_tree_name('tree', tree_num))
def get_tree_name(self, name, num):
return '{0}-{1}'.format(name, num)
-class ForestTrainingVariables(object):
+class ForestVariables(object):
"""A container for a forests training data, consisting of multiple trees.
- Instantiates a TreeTrainingVariables object for each tree. We override the
+ Instantiates a TreeVariables object for each tree. We override the
__getitem__ and __setitem__ function so that usage looks like this:
- forest_variables = ForestTrainingVariables(params)
+ forest_variables = ForestVariables(params)
... forest_variables.tree ...
"""
def __init__(self, params, device_assigner, training=True,
- tree_variables_class=TreeTrainingVariables):
+ tree_variables_class=TreeVariables,
+ tree_configs=None, tree_stats=None):
self.variables = []
# Set up some scalar variables to run through the device assigner, then
# we can use those to colocate everything related to a tree.
@@ -347,7 +348,13 @@ class ForestTrainingVariables(object):
for i in range(params.num_trees):
with ops.device(self.device_dummies[i].device):
- self.variables.append(tree_variables_class(params, i, training))
+ kwargs = {}
+ if tree_configs is not None:
+ kwargs.update(dict(tree_config=tree_configs[i]))
+ if tree_stats is not None:
+ kwargs.update(dict(tree_stat=tree_stats[i]))
+ self.variables.append(tree_variables_class(
+ params, i, training, **kwargs))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -361,9 +368,11 @@ class RandomForestGraphs(object):
def __init__(self,
params,
+ tree_configs=None,
+ tree_stats=None,
device_assigner=None,
variables=None,
- tree_variables_class=TreeTrainingVariables,
+ tree_variables_class=TreeVariables,
tree_graphs=None,
training=True):
self.params = params
@@ -371,9 +380,10 @@ class RandomForestGraphs(object):
device_assigner or framework_variables.VariableDeviceChooser())
logging.info('Constructing forest with params = ')
logging.info(self.params.__dict__)
- self.variables = variables or ForestTrainingVariables(
+ self.variables = variables or ForestVariables(
self.params, device_assigner=self.device_assigner, training=training,
- tree_variables_class=tree_variables_class)
+ tree_variables_class=tree_variables_class,
+ tree_configs=tree_configs, tree_stats=tree_stats)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(self.variables[i], self.params, i)