aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py44
1 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index 13b804875e..d55240297a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -139,6 +139,49 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(new_stamp, 1)
self.assertProtoEquals(expected_result, tree_ensemble)
+ def testBiasCenteringOnEmptyEnsemble(self):
+ """Test growing with bias centering on an empty ensemble."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ gradients = np.array([[5.]], dtype=np.float32)
+ hessians = np.array([[24.]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.center_bias(
+ tree_ensemble_handle,
+ mean_gradients=gradients,
+ mean_hessians=hessians,
+ l1=0.0,
+ l2=1.0
+ )
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ leaf {
+ scalar: -0.2
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
with self.test_session() as session:
@@ -666,7 +709,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
num_layers_attempted: 1
last_layer_node_start: 1
last_layer_node_end: 3
-
}
""", tree_ensemble_config)