aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-08-21 14:48:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 15:00:13 -0700
commit0f02f05913e03889bbcb85e71a6d005a8519bfb9 (patch)
treec5b2bacb1b96d260b67cb56c208ce8f1b1025dae /tensorflow/contrib/boosted_trees
parent3f24f93c2a32b2eae8951e5b272c3b647c5b9611 (diff)
Merged commit includes the following changes:
209663919 by yifeif<yifeif@google.com>: Internal change. -- 209663914 by amitpatankar<amitpatankar@google.com>: Fix the topk_op_test for numpy>1.15. -- 209660476 by jdduke<jdduke@google.com>: Fix model lifetime for TensorFlow Lite C# bindings Ensure the model's existence for the duration of the interpreter, as per API requirements. -- 209655960 by scottzhu<scottzhu@google.com>: Unify RNN Cell interface between TF and Keras. -- 209655731 by A. Unique TensorFlower<gardener@tensorflow.org>: Added tests for PredictionOps and PartitionExamplesOps -- 209655291 by nolivia<nolivia@google.com>: adding rate class so that we can save global_step/sec using tf.contrib.summary. The function takes the rate in relation to any tensors provided that the numerator and denominator are broadcastable and have dtypes that can be cast to float64 -- 209654655 by kramerb<kramerb@google.com>: [XLA] Switch from tensorflow::gtl::InlinedVector to absl::InlinedVector This one comes with extra goodies like a move constructor. -- 209653851 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal build specification change -- PiperOrigin-RevId: 209663919
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py185
1 files changed, 153 insertions, 32 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index cf55759aaa..bef42fdf7f 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -96,6 +96,20 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None):
split.dimension_id = feature_dim_id
+def _set_float_oblivious_split(split, feat_col, thresh):
+ """Helper method for building tree float splits.
+
+ Sets split feature column and threshold.
+
+ Args:
+ split: split node to update.
+ feat_col: feature column for the split.
+ thresh: threshold to split on forming rule x <= thresh.
+ """
+ split.feature_column = feat_col
+ split.threshold = thresh
+
+
def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id):
"""Helper method for building tree categorical id splits.
@@ -119,15 +133,17 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float
+ Creates, a batch of two examples having three dense float, two sparse float
single valued, one sparse float multidimensional and one sparse int
features. The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM
- | 0 | 7 | -3 | | 9,1 | __, 5.0
- | 1 | -2 | | 4 | | 3, ___
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |SparseM
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 | __, 5.0
+ | 1 | -2 | 2 | 0.5 | | 4 | | 3, ___
"""
super(PredictionOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -153,7 +169,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
reduce_dim=False):
return prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor],
+ self._seed, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -165,6 +181,25 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
center_bias=center_bias,
reduce_dim=reduce_dim)
+ def _get_predictions_oblivious_case(self,
+ tree_ensemble_handle,
+ learner_config,
+ apply_dropout=False,
+ apply_averaging=False,
+ center_bias=False,
+ reduce_dim=False):
+ return prediction_ops.gradient_trees_prediction(
+ tree_ensemble_handle,
+ self._seed, [
+ self._dense_float_tensor1, self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [],
+ learner_config=learner_config,
+ apply_dropout=apply_dropout,
+ apply_averaging=apply_averaging,
+ center_bias=center_bias,
+ reduce_dim=reduce_dim)
+
def testEmptyEnsemble(self):
with self.test_session():
# Empty tree ensenble.
@@ -295,6 +330,53 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Empty dropout.
self.assertAllEqual([[], []], dropout_info.eval())
+ def testObliviousEnsemble(self):
+ with self.test_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Bias tree.
+ tree1 = tree_ensemble_config.trees.add()
+ tree_ensemble_config.tree_metadata.add().is_finalized = True
+ _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4)
+
+ # Depth 3 tree.
+ tree2 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree2.nodes.add().leaf, 0, i / 10.0)
+
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+
+ result, dropout_info = self._get_predictions_oblivious_case(
+ tree_ensemble_handle,
+ learner_config=learner_config.SerializeToString(),
+ reduce_dim=True)
+
+ # The first example will get bias -0.4 from first tree and 0.6 from
+ # the 5th leaf of the second tree corresponding to node_id = 8, hence a
+ # prediction of 0.2.
+ # The second example will get bias -0.4 and 0.1 from the 0th leaf of the
+ # second tree corresponding to node_id = 3, hence a prediction of -0.3
+ self.assertAllClose([[0.2], [-0.3]], result.eval())
+
+ # Empty dropout.
+ self.assertAllEqual([[], []], dropout_info.eval())
+
def testFullEnsembleWithMultidimensionalSparseSingleClass(self):
with self.test_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
@@ -358,7 +440,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
result, dropout_info = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor], [
+ self._seed, [self._dense_float_tensor1], [
self._sparse_float_indices1, self._sparse_float_indices2,
self._sparse_float_indices_m
], [
@@ -917,7 +999,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Different seed.
_, dropout_info_3 = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- 112314, [self._dense_float_tensor],
+ 112314, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -1204,15 +1286,18 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float and
- one sparse int features.
+ Create a batch of two examples having three dense float, two sparse float
+ and one sparse int features.
The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 |
- | 0 | 7 | -3 | | 9,1 |
- | 1 | -2 | | 4 | |
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 |
+ | 1 | -2 | 2 | 0.5 | | 4 | |
+
"""
super(PartitionExamplesOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -1234,12 +1319,12 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
@@ -1269,12 +1354,12 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([5, 3], result.eval())
@@ -1304,15 +1389,51 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
+ def testObliviousTreeNonFinalized(self):
+ with self.test_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Depth 3 tree.
+ tree1 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree1.nodes.add().leaf, 0, i / 10.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_metadata.add().is_finalized = False
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ result = prediction_ops.gradient_trees_partition_examples(
+ tree_ensemble_handle, [
+ self._dense_float_tensor1,
+ self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [])
+
+ # The first example goes right, left, right in the tree and the second
+ # example goes lef, left, left. Since the depth of the tree is 3, the
+ # partition id's are as follows:
+ # First example: 3 + 5 = 8
+ # Second exampel: 3 + 0 = 3
+ self.assertAllEqual([8, 3], result.eval())
+
if __name__ == "__main__":
googletest.main()