aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-10 18:40:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-10 18:45:14 -0700
commit9a7e849472c954470de889cc8873223e4db1e4df (patch)
tree3e37c08c3de7fbd2435c4372b73b7a229e53e80d /tensorflow/contrib/boosted_trees/estimator_batch
parentd4d5e1510f2404ff1dafaa83171b0dcaec5fdfeb (diff)
* Passing `training_features` (without weight column) instead of `features` into GradientBoostedDecisionTreeModel.
* Export GTFlow model into generic format with features defined in proto. PiperOrigin-RevId: 171766066
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py9
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py2
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 7773125c16..a800c3ddc7 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -96,7 +96,8 @@ def make_custom_export_strategy(name,
def convert_to_universal_format(dtec, sorted_feature_names,
num_dense, num_sparse_float,
- num_sparse_int):
+ num_sparse_int,
+ feature_name_to_proto=None):
"""Convert GTFlow trees to universal format."""
del num_sparse_int # unused.
model_and_features = generic_tree_model_pb2.ModelAndFeatures()
@@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names,
# feature is processed before it's fed to the model (e.g. bucketing
# information). As of now, this serves as a list of features the model uses.
for feature_name in sorted_feature_names:
- model_and_features.features[feature_name].SetInParent()
+ if not feature_name_to_proto:
+ model_and_features.features[feature_name].SetInParent()
+ else:
+ model_and_features.features[feature_name].CopyFrom(
+ feature_name_to_proto[feature_name])
model = model_and_features.model
model.ensemble.summation_combination_technique.SetInParent()
for tree_idx in range(len(dtec.trees)):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 8cda5c8f2b..c6455a7ea3 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config):
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
- features=features)
+ features=training_features)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]