diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest_test.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest_test.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index bbe627b157..1c9c81827e 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf.json_format import ParseDict +from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import resources +from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -110,6 +114,47 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) + def testInfrenceFromRestoredModel(self): + input_data = [[-1., 0.], [-1., 2.], # node 1 + [1., 0.], [1., -2.]] # node 2 + expected_prediction = [[0.0, 1.0], [0.0, 1.0], + [0.0, 1.0], [0.0, 1.0]] + hparams = tensor_forest.ForestHParams( + num_classes=2, + num_features=2, + num_trees=1, + max_nodes=1000, + split_after_samples=25).fill() + tree_weight = {'decisionTree': + {'nodes': + [{'binaryNode': + {'rightChildId': 2, + 'leftChildId': 1, + 'inequalityLeftChildTest': + {'featureId': {'id': '0'}, + 'threshold': {'floatValue': 0}}}}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 1}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 2}]}} + restored_tree_param = ParseDict(tree_weight, + _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, + [restored_tree_param]) + probs, paths, var = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(probs, ops.Tensor)) + self.assertTrue(isinstance(paths, ops.Tensor)) + self.assertTrue(isinstance(var, ops.Tensor)) + with self.test_session(): + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + self.assertEquals(probs.eval().shape, (4, 2)) + self.assertEquals(probs.eval().tolist(), expected_prediction) + def testTrainingConstructionClassificationSparse(self): input_data = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]], |