aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar Peng Yu <peng.yu@shopify.com>2018-04-12 15:38:48 -0400
committerGravatar Peng Yu <peng.yu@shopify.com>2018-05-21 13:34:57 -0400
commit2d2f0455d215c7a3f4353ab879cbb513b7946a78 (patch)
tree8897477459c071f51978b69775ad62a136af5389 /tensorflow/contrib/tensor_forest
parent2041cb05f8b03ac94daf9dcfd1387d7954862557 (diff)
address comments
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index b6e70d457e..cf50ba25f6 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -114,7 +114,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
- def testInfrenceWithPreTrainedParams(self):
+ def testInfrenceFromRestoredModel(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
expected_prediction = [[0.0, 1.0], [0.0, 1.0],
@@ -126,8 +126,8 @@ class TensorForestTest(test_util.TensorFlowTestCase):
max_nodes=1000,
split_after_samples=25).fill()
tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
- tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
- graph_builder = tensor_forest.RandomForestGraphs(hparams, [tree_param])
+ restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
+ graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))