aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar Peng Yu <peng.yu@shopify.com>2018-02-16 10:59:14 -0500
committerGravatar Peng Yu <peng.yu@shopify.com>2018-05-21 13:34:57 -0400
commit049a8364211ca91d73e10b2002c18f10fe89b8b2 (patch)
tree119c835b9b2039b4c21f54a7293e44e1785e0f4d /tensorflow/contrib/tensor_forest
parentb84878e63e1166b5c8dc34a777f6ab6fc2517d74 (diff)
add inference support for tree and forest variables
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 7a35a70bbe..0feca52b0f 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -295,7 +295,7 @@ def get_epoch_variable():
# A simple container to hold the training variables for a single tree.
-class TreeTrainingVariables(object):
+class TreeVariables(object):
"""Stores tf.Variables for training a single random tree.
Uses tf.get_variable to get tree-specific names so that this can be used
@@ -303,7 +303,7 @@ class TreeTrainingVariables(object):
then relies on restoring that model to evaluate).
"""
- def __init__(self, params, tree_num, training):
+ def __init__(self, params, tree_num, training, tree_config='', tree_stat=''):
if (not hasattr(params, 'params_proto') or
not isinstance(params.params_proto,
_params_proto.TensorForestParams)):
@@ -315,27 +315,27 @@ class TreeTrainingVariables(object):
# TODO(gilberth): Manually shard this to be able to fit it on
# multiple machines.
self.stats = stats_ops.fertile_stats_variable(
- params, '', self.get_tree_name('stats', tree_num))
+ params, tree_stat, self.get_tree_name('stats', tree_num))
self.tree = model_ops.tree_variable(
- params, '', self.stats, self.get_tree_name('tree', tree_num))
+ params, tree_config, self.stats, self.get_tree_name('tree', tree_num))
def get_tree_name(self, name, num):
return '{0}-{1}'.format(name, num)
-class ForestTrainingVariables(object):
+class ForestVariables(object):
"""A container for a forests training data, consisting of multiple trees.
- Instantiates a TreeTrainingVariables object for each tree. We override the
+ Instantiates a TreeVariables object for each tree. We override the
__getitem__ and __setitem__ function so that usage looks like this:
- forest_variables = ForestTrainingVariables(params)
+ forest_variables = ForestVariables(params)
... forest_variables.tree ...
"""
def __init__(self, params, device_assigner, training=True,
- tree_variables_class=TreeTrainingVariables):
+ tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None):
self.variables = []
# Set up some scalar variables to run through the device assigner, then
# we can use those to colocate everything related to a tree.
@@ -347,7 +347,12 @@ class ForestTrainingVariables(object):
for i in range(params.num_trees):
with ops.device(self.device_dummies[i].device):
- self.variables.append(tree_variables_class(params, i, training))
+ kwargs = {}
+ if tree_configs is not None:
+ kwargs.update(dict(tree_config=tree_configs[i]))
+ if tree_stats is not None:
+ kwargs.update(dict(tree_stat=tree_stats[i]))
+ self.variables.append(tree_variables_class(params, i, training, **kwargs))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -361,9 +366,11 @@ class RandomForestGraphs(object):
def __init__(self,
params,
+ tree_configs=None,
+ tree_stats=None,
device_assigner=None,
variables=None,
- tree_variables_class=TreeTrainingVariables,
+ tree_variables_class=TreeVariables,
tree_graphs=None,
training=True):
self.params = params
@@ -371,9 +378,9 @@ class RandomForestGraphs(object):
device_assigner or framework_variables.VariableDeviceChooser())
logging.info('Constructing forest with params = ')
logging.info(self.params.__dict__)
- self.variables = variables or ForestTrainingVariables(
+ self.variables = variables or ForestVariables(
self.params, device_assigner=self.device_assigner, training=training,
- tree_variables_class=tree_variables_class)
+ tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(self.variables[i], self.params, i)