about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/estimator/canned/boosted_trees_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees_test.py')
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees_test.py 675
1 files changed, 664 insertions, 11 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 33e9e69b04..f807641057 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -554,14 +554,6 @@ class ModelFnTests(test_util.TensorFlowTestCase):
feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
}
- self._tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access
- n_trees=2,
- max_depth=2,
- learning_rate=0.1,
- l1=0.,
- l2=0.01,
- tree_complexity=0.,
- min_node_weight=0.)
def _get_expected_ensembles_for_classification(self):
first_round = """
@@ -790,6 +782,245 @@ class ModelFnTests(test_util.TensorFlowTestCase):
"""
return (first_round, second_round, third_round)
+ def _get_expected_ensembles_for_classification_with_bias(self):
+ first_round = """
+ trees {
+ nodes {
+ leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.407711
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.407711
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf {
+ scalar: -0.556054
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.09876
+ original_leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.698072
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.106016
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.27349
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ last_layer_node_end: 1
+ }
+ """
+ forth_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.4077113
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf {
+ scalar: -0.556054
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.09876
+ original_leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.698072
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.106016
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.27349
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.289927
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.134588
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.083838
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ return (first_round, second_round, third_round, forth_round)
+
def _get_expected_ensembles_for_regression(self):
first_round = """
trees {
@@ -1017,17 +1248,275 @@ class ModelFnTests(test_util.TensorFlowTestCase):
"""
return (first_round, second_round, third_round)
- def _get_train_op_and_ensemble(self, head, config, is_classification,
- train_in_memory):
+ def _get_expected_ensembles_for_regression_with_bias(self):
+ first_round = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.862786
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.683594
+ original_leaf {
+ scalar: 1.862786
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.322693
+ original_leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 2.024487
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.710319
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.559208
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.686037
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ last_layer_node_start: 0
+ last_layer_node_end: 1
+ }
+ """
+ forth_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.683594
+ original_leaf {
+ scalar: 1.8627863
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.322693
+ original_leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 2.024487
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.710319
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.5592078
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.686037
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.972589
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.137592
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.034926
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ return (first_round, second_round, third_round, forth_round)
+
+ def _get_train_op_and_ensemble(self,
+ head,
+ config,
+ is_classification,
+ train_in_memory,
+ center_bias=False):
"""Calls bt_model_fn() and returns the train_op and ensemble_serialzed."""
features, labels = _make_train_input_fn(is_classification)()
+
+ tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access
+ n_trees=2,
+ max_depth=2,
+ learning_rate=0.1,
+ l1=0.,
+ l2=0.01,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ center_bias=center_bias)
+
estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access
features=features,
labels=labels,
mode=model_fn.ModeKeys.TRAIN,
head=head,
feature_columns=self._feature_columns,
- tree_hparams=self._tree_hparams,
+ tree_hparams=tree_hparams,
example_id_column_name=EXAMPLE_ID_COLUMN,
n_batches_per_layer=1,
config=config,
@@ -1076,6 +1565,49 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainClassifierWithCenterBiasInMemory(self):
+ ops.reset_default_graph()
+
+ # When bias centering is on, we expect the very first node to have the bias.
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_classification_with_bias())
+
+ with self.test_session() as sess:
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=True,
+ center_bias=True)
+
+ # 4 iterations to center bias.
+ for _ in range(4):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Fourth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainClassifierNonInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1106,6 +1638,47 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainClassifierWithCenterBiasNonInMemory(self):
+ ops.reset_default_graph()
+
+ # When bias centering is on, we expect the very first node to have the bias.
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_classification_with_bias())
+
+ with self.test_session() as sess:
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=False,
+ center_bias=True)
+ # 4 iterations to center bias.
+ for _ in range(4):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Fourth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainRegressorInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1136,6 +1709,46 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainRegressorInMemoryWithCenterBias(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_regression_with_bias())
+ with self.test_session() as sess:
+ # Train with train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=True,
+ center_bias=True)
+ # 3 iterations to center bias.
+ for _ in range(3):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Fourth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainRegressorNonInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1166,6 +1779,46 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainRegressorNotInMemoryWithCenterBias(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_regression_with_bias())
+ with self.test_session() as sess:
+ # Train without train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=False,
+ center_bias=True)
+ # 3 iterations to center the bias (because we are using regularization).
+ for _ in range(3):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Fourth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
if __name__ == '__main__':
googletest.main()