diff options
author | 2017-10-10 18:40:50 -0700 | |
---|---|---|
committer | 2017-10-10 18:45:14 -0700 | |
commit | 9a7e849472c954470de889cc8873223e4db1e4df (patch) | |
tree | 3e37c08c3de7fbd2435c4372b73b7a229e53e80d | |
parent | d4d5e1510f2404ff1dafaa83171b0dcaec5fdfeb (diff) |
* Passing `training_features` (without weight column) instead of `features` into GradientBoostedDecisionTreeModel.
* Export GTFlow model into generic format with features defined in proto.
PiperOrigin-RevId: 171766066
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/model.py | 2 |
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 7773125c16..a800c3ddc7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -96,7 +96,8 @@ def make_custom_export_strategy(name, def convert_to_universal_format(dtec, sorted_feature_names, num_dense, num_sparse_float, - num_sparse_int): + num_sparse_int, + feature_name_to_proto=None): """Convert GTFlow trees to universal format.""" del num_sparse_int # unused. model_and_features = generic_tree_model_pb2.ModelAndFeatures() @@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names, # feature is processed before it's fed to the model (e.g. bucketing # information). As of now, this serves as a list of features the model uses. for feature_name in sorted_feature_names: - model_and_features.features[feature_name].SetInParent() + if not feature_name_to_proto: + model_and_features.features[feature_name].SetInParent() + else: + model_and_features.features[feature_name].CopyFrom( + feature_name_to_proto[feature_name]) model = model_and_features.model model.ensemble.summation_combination_technique.SetInParent() for tree_idx in range(len(dtec.trees)): diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 8cda5c8f2b..c6455a7ea3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=features) + features=training_features) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] |