aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest_test.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index bbe627b157..1c9c81827e 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from google.protobuf.json_format import ParseDict
+from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -110,6 +114,47 @@ class TensorForestTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
+ def testInfrenceFromRestoredModel(self):
+ input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+ expected_prediction = [[0.0, 1.0], [0.0, 1.0],
+ [0.0, 1.0], [0.0, 1.0]]
+ hparams = tensor_forest.ForestHParams(
+ num_classes=2,
+ num_features=2,
+ num_trees=1,
+ max_nodes=1000,
+ split_after_samples=25).fill()
+ tree_weight = {'decisionTree':
+ {'nodes':
+ [{'binaryNode':
+ {'rightChildId': 2,
+ 'leftChildId': 1,
+ 'inequalityLeftChildTest':
+ {'featureId': {'id': '0'},
+ 'threshold': {'floatValue': 0}}}},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 1},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 2}]}}
+ restored_tree_param = ParseDict(tree_weight,
+ _tree_proto.Model()).SerializeToString()
+ graph_builder = tensor_forest.RandomForestGraphs(hparams,
+ [restored_tree_param])
+ probs, paths, var = graph_builder.inference_graph(input_data)
+ self.assertTrue(isinstance(probs, ops.Tensor))
+ self.assertTrue(isinstance(paths, ops.Tensor))
+ self.assertTrue(isinstance(var, ops.Tensor))
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+ self.assertEquals(probs.eval().shape, (4, 2))
+ self.assertEquals(probs.eval().tolist(), expected_prediction)
+
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],