diff options
Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees_test.py')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 675 |
1 files changed, 664 insertions, 11 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 33e9e69b04..f807641057 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -554,14 +554,6 @@ class ModelFnTests(test_util.TensorFlowTestCase): feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), BUCKET_BOUNDARIES) for i in range(NUM_FEATURES) } - self._tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access - n_trees=2, - max_depth=2, - learning_rate=0.1, - l1=0., - l2=0.01, - tree_complexity=0., - min_node_weight=0.) def _get_expected_ensembles_for_classification(self): first_round = """ @@ -790,6 +782,245 @@ class ModelFnTests(test_util.TensorFlowTestCase): """ return (first_round, second_round, third_round) + def _get_expected_ensembles_for_classification_with_bias(self): + first_round = """ + trees { + nodes { + leaf { + scalar: -0.405086 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.301233 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf { + scalar: -0.556054 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.09876 + original_leaf { + scalar: -0.301233 + } + } + } + nodes { + leaf { + scalar: -0.698072 + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.106016 + } + } + nodes { + leaf { + scalar: -0.27349 + } + } + } + trees { + nodes { + leaf { + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_end: 1 + } + """ + forth_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.4077113 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + bucketized_split { + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf { + scalar: -0.556054 + } + } + } + nodes { + bucketized_split { + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.09876 + original_leaf { + scalar: -0.301233 + } + } + } + nodes { + leaf { + scalar: -0.698072 + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.106016 + } + } + nodes { + leaf { + scalar: -0.27349 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.289927 + } + } + nodes { + leaf { + scalar: -0.134588 + } + } + nodes { + leaf { + scalar: 0.083838 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round, forth_round) + def _get_expected_ensembles_for_regression(self): first_round = """ trees { @@ -1017,17 +1248,275 @@ class ModelFnTests(test_util.TensorFlowTestCase): """ return (first_round, second_round, third_round) - def _get_train_op_and_ensemble(self, head, config, is_classification, - train_in_memory): + def _get_expected_ensembles_for_regression_with_bias(self): + first_round = """ + trees { + nodes { + leaf { + scalar: 1.799974 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + leaf { + scalar: 1.862786 + } + } + nodes { + leaf { + scalar: 1.706149 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.683594 + original_leaf { + scalar: 1.862786 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.322693 + original_leaf { + scalar: 1.706149 + } + } + } + nodes { + leaf { + scalar: 2.024487 + } + } + nodes { + leaf { + scalar: 1.710319 + } + } + nodes { + leaf { + scalar: 1.559208 + } + } + nodes { + leaf { + scalar: 1.686037 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + forth_round = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } + } + } + nodes { + bucketized_split { + threshold: 1 + left_id: 3 + right_id: 4 + } + metadata { + gain: 2.683594 + original_leaf { + scalar: 1.8627863 + } + } + } + nodes { + bucketized_split { + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.322693 + original_leaf { + scalar: 1.706149 + } + } + } + nodes { + leaf { + scalar: 2.024487 + } + } + nodes { + leaf { + scalar: 1.710319 + } + } + nodes { + leaf { + scalar: 1.5592078 + } + } + nodes { + leaf { + scalar: 1.686037 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.972589 + } + } + nodes { + leaf { + scalar: -0.137592 + } + } + nodes { + leaf { + scalar: 0.034926 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round, forth_round) + + def _get_train_op_and_ensemble(self, + head, + config, + is_classification, + train_in_memory, + center_bias=False): """Calls bt_model_fn() and returns the train_op and ensemble_serialzed.""" features, labels = _make_train_input_fn(is_classification)() + + tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access + n_trees=2, + max_depth=2, + learning_rate=0.1, + l1=0., + l2=0.01, + tree_complexity=0., + min_node_weight=0., + center_bias=center_bias) + estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access features=features, labels=labels, mode=model_fn.ModeKeys.TRAIN, head=head, feature_columns=self._feature_columns, - tree_hparams=self._tree_hparams, + tree_hparams=tree_hparams, example_id_column_name=EXAMPLE_ID_COLUMN, n_batches_per_layer=1, config=config, @@ -1076,6 +1565,49 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainClassifierWithCenterBiasInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=True, + center_bias=True) + + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainClassifierNonInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1106,6 +1638,47 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainClassifierWithCenterBiasNonInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=False, + center_bias=True) + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainRegressorInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1136,6 +1709,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=True, + center_bias=True) + # 3 iterations to center bias. + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainRegressorNonInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1166,6 +1779,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorNotInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=False, + center_bias=True) + # 3 iterations to center the bias (because we are using regularization). + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + if __name__ == '__main__': googletest.main() |