diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py | 70 |
1 files changed, 62 insertions, 8 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py index e61085657a..aaead5610f 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import tensorflow # pylint: disable=unused-import +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.python.framework import test_util @@ -29,6 +30,7 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): def setUp(self): self.ops = inference_ops.Load() + self.data_spec = [constants.DATA_FLOAT] * 2 def testSimple(self): input_data = [[-1., 0.], [-1., 2.], # node 1 @@ -41,13 +43,65 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=1) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=1) self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25], [0.5, 0.25, 0.25]], predictions.eval()) + def testSparseInput(self): + sparse_shape = [3, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0] + sparse_data_spec = [constants.DATA_FLOAT] * 10 + + tree = [[1, 0], [-1, 0], [-1, 0]] + tree_thresholds = [0., 0., 0.] + node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8], + [1.0, 0.5, 0.25, 0.25]] + + with self.test_session(): + predictions = self.ops.tree_predictions( + [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec, + tree, tree_thresholds, node_pcw, + valid_leaf_threshold=1) + + self.assertAllClose([[0.5, 0.25, 0.25], + [0.5, 0.25, 0.25], + [0.1, 0.1, 0.8]], + predictions.eval()) + + def testSparseInputDefaultIsZero(self): + sparse_shape = [3, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0] + sparse_data_spec = [constants.DATA_FLOAT] * 10 + + tree = [[1, 7], [-1, 0], [-1, 0]] + tree_thresholds = [3.0, 0., 0.] + node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8], + [1.0, 0.5, 0.25, 0.25]] + + with self.test_session(): + predictions = self.ops.tree_predictions( + [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec, + tree, tree_thresholds, node_pcw, + valid_leaf_threshold=1) + + self.assertAllClose([[0.1, 0.1, 0.8], + [0.5, 0.25, 0.25], + [0.1, 0.1, 0.8]], + predictions.eval()) + def testBackoffToParent(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 @@ -59,8 +113,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) # Node 2 has enough data, but Node 1 needs to combine with the parent # counts. @@ -78,8 +132,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) self.assertEquals((0, 3), predictions.eval().shape) @@ -97,8 +151,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): 'Number of nodes should be the same in tree, tree_thresholds ' 'and node_pcw.'): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) self.assertEquals((0, 3), predictions.eval().shape) |