aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar Peng Yu <peng.yu@shopify.com>2018-05-18 17:05:45 -0400
committerGravatar Peng Yu <peng.yu@shopify.com>2018-05-21 13:34:57 -0400
commitcea25b276d3a975a1ae5152aa4809d2c73a5f9d4 (patch)
treef3f31a2a7856c4648a04b064ec0da9120567ce24 /tensorflow/contrib/tensor_forest
parent2d2f0455d215c7a3f4353ab879cbb513b7946a78 (diff)
address lint
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py13
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py17
2 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 0feca52b0f..ba1755eddd 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -335,7 +335,8 @@ class ForestVariables(object):
"""
def __init__(self, params, device_assigner, training=True,
- tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None):
+ tree_variables_class=TreeVariables,
+ tree_configs=None, tree_stats=None):
self.variables = []
# Set up some scalar variables to run through the device assigner, then
# we can use those to colocate everything related to a tree.
@@ -349,10 +350,11 @@ class ForestVariables(object):
with ops.device(self.device_dummies[i].device):
kwargs = {}
if tree_configs is not None:
- kwargs.update(dict(tree_config=tree_configs[i]))
+ kwargs.update(dict(tree_config=tree_configs[i]))
if tree_stats is not None:
- kwargs.update(dict(tree_stat=tree_stats[i]))
- self.variables.append(tree_variables_class(params, i, training, **kwargs))
+ kwargs.update(dict(tree_stat=tree_stats[i]))
+ self.variables.append(tree_variables_class(
+ params, i, training, **kwargs))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -380,7 +382,8 @@ class RandomForestGraphs(object):
logging.info(self.params.__dict__)
self.variables = variables or ForestVariables(
self.params, device_assigner=self.device_assigner, training=training,
- tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats)
+ tree_variables_class=tree_variables_class,
+ tree_configs=tree_configs, tree_stats=tree_stats)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(self.variables[i], self.params, i)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index cf50ba25f6..7c5883d447 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -125,7 +125,22 @@ class TensorForestTest(test_util.TensorFlowTestCase):
num_trees=1,
max_nodes=1000,
split_after_samples=25).fill()
- tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
+ tree_weight = {'decisionTree':
+ {'nodes':
+ [{'binaryNode':
+ {'rightChildId': 2,
+ 'leftChildId': 1,
+ 'inequalityLeftChildTest':
+ {'featureId': {'id': '0'},
+ 'threshold': {'floatValue': 0}}}},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 1},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 2}]}}
restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)