diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py index 13b804875e..d55240297a 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py @@ -139,6 +139,49 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(new_stamp, 1) self.assertProtoEquals(expected_result, tree_ensemble) + def testBiasCenteringOnEmptyEnsemble(self): + """Test growing with bias centering on an empty ensemble.""" + with self.test_session() as session: + # Create empty ensemble. + tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + gradients = np.array([[5.]], dtype=np.float32) + hessians = np.array([[24.]], dtype=np.float32) + + # Grow tree ensemble. + grow_op = boosted_trees_ops.center_bias( + tree_ensemble_handle, + mean_gradients=gradients, + mean_hessians=hessians, + l1=0.0, + l2=1.0 + ) + session.run(grow_op) + + new_stamp, serialized = session.run(tree_ensemble.serialize()) + + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + expected_result = """ + trees { + nodes { + leaf { + scalar: -0.2 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + def testGrowExistingEnsembleTreeNotFinalized(self): """Test growing an existing ensemble with the last tree not finalized.""" with self.test_session() as session: @@ -666,7 +709,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_layers_attempted: 1 last_layer_node_start: 1 last_layer_node_end: 3 - } """, tree_ensemble_config) |