diff options
author | Peng Yu <peng.yu@shopify.com> | 2018-04-12 15:38:48 -0400 |
---|---|---|
committer | Peng Yu <peng.yu@shopify.com> | 2018-05-21 13:34:57 -0400 |
commit | 2d2f0455d215c7a3f4353ab879cbb513b7946a78 (patch) | |
tree | 8897477459c071f51978b69775ad62a136af5389 /tensorflow/contrib/tensor_forest/python | |
parent | 2041cb05f8b03ac94daf9dcfd1387d7954862557 (diff) |
address comments
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest_test.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index b6e70d457e..cf50ba25f6 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -114,7 +114,7 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) - def testInfrenceWithPreTrainedParams(self): + def testInfrenceFromRestoredModel(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 expected_prediction = [[0.0, 1.0], [0.0, 1.0], @@ -126,8 +126,8 @@ class TensorForestTest(test_util.TensorFlowTestCase): max_nodes=1000, split_after_samples=25).fill() tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}} - tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() - graph_builder = tensor_forest.RandomForestGraphs(hparams, [tree_param]) + restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param]) probs, paths, var = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(probs, ops.Tensor)) self.assertTrue(isinstance(paths, ops.Tensor)) |