aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py70
1 files changed, 62 insertions, 8 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
index e61085657a..aaead5610f 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import tensorflow # pylint: disable=unused-import
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.python.framework import test_util
@@ -29,6 +30,7 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
def setUp(self):
self.ops = inference_ops.Load()
+ self.data_spec = [constants.DATA_FLOAT] * 2
def testSimple(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
@@ -41,13 +43,65 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=1)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=1)
self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8],
[0.5, 0.25, 0.25], [0.5, 0.25, 0.25]],
predictions.eval())
+ def testSparseInput(self):
+ sparse_shape = [3, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0]
+ sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+ tree = [[1, 0], [-1, 0], [-1, 0]]
+ tree_thresholds = [0., 0., 0.]
+ node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+ [1.0, 0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+ tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=1)
+
+ self.assertAllClose([[0.5, 0.25, 0.25],
+ [0.5, 0.25, 0.25],
+ [0.1, 0.1, 0.8]],
+ predictions.eval())
+
+ def testSparseInputDefaultIsZero(self):
+ sparse_shape = [3, 10]
+ sparse_indices = [[0, 0], [0, 4], [0, 9],
+ [1, 0], [1, 7],
+ [2, 0]]
+ sparse_values = [3.0, -1.0, 0.5,
+ 1.5, 6.0,
+ -2.0]
+ sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+ tree = [[1, 7], [-1, 0], [-1, 0]]
+ tree_thresholds = [3.0, 0., 0.]
+ node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+ [1.0, 0.5, 0.25, 0.25]]
+
+ with self.test_session():
+ predictions = self.ops.tree_predictions(
+ [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+ tree, tree_thresholds, node_pcw,
+ valid_leaf_threshold=1)
+
+ self.assertAllClose([[0.1, 0.1, 0.8],
+ [0.5, 0.25, 0.25],
+ [0.1, 0.1, 0.8]],
+ predictions.eval())
+
def testBackoffToParent(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
@@ -59,8 +113,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
# Node 2 has enough data, but Node 1 needs to combine with the parent
# counts.
@@ -78,8 +132,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
with self.test_session():
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
self.assertEquals((0, 3), predictions.eval().shape)
@@ -97,8 +151,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
'Number of nodes should be the same in tree, tree_thresholds '
'and node_pcw.'):
predictions = self.ops.tree_predictions(
- input_data, tree, tree_thresholds, node_pcw,
- valid_leaf_threshold=10)
+ input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+ node_pcw, valid_leaf_threshold=10)
self.assertEquals((0, 3), predictions.eval().shape)