diff options
author | 2018-07-23 14:51:53 -0700 | |
---|---|---|
committer | 2018-07-23 15:10:56 -0700 | |
commit | 32d121be1b105ef44fd5d4b421b78eb74dc94870 (patch) | |
tree | fe08311d30eb6a4229e1eeaa5dff2a30f1e3c47f /tensorflow/contrib/boosted_trees | |
parent | 3baa7b63edf7890b5489cf2085a79598f13af2c6 (diff) |
Adding core interface to a contrib version
PiperOrigin-RevId: 205728990
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
3 files changed, 124 insertions, 22 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 59a78515c6..38fa8c3834 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -22,6 +22,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops @@ -354,3 +355,45 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): + """An estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False): + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 2c2dcb039d..f787d3cdb8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -182,7 +182,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): config = run_config.RunConfig() head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( - loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) model = estimator.GradientBoostedDecisionTreeRanker( head=head_fn, @@ -203,5 +203,32 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model.predict(input_fn=_infer_ranking_train_input_fn) +class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase): + + def testTrainEvaluateInferDoesNotThrowError(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreGradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + est.evaluate(input_fn=_eval_input_fn, steps=1) + est.predict(input_fn=_eval_input_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 0e8a56e6e9..2fbe72951a 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -29,7 +29,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util -def model_builder(features, labels, mode, params, config): +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 + + +def model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model. Args: @@ -115,31 +125,53 @@ def model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] + if num_trees: if center_bias: num_trees += 1 finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - model_fn_ops.training_hooks.append( + training_hooks = [ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + finalized_trees) + ] + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops |