aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py9
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py2
2 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 7773125c16..a800c3ddc7 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -96,7 +96,8 @@ def make_custom_export_strategy(name,
def convert_to_universal_format(dtec, sorted_feature_names,
num_dense, num_sparse_float,
- num_sparse_int):
+ num_sparse_int,
+ feature_name_to_proto=None):
"""Convert GTFlow trees to universal format."""
del num_sparse_int # unused.
model_and_features = generic_tree_model_pb2.ModelAndFeatures()
@@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names,
# feature is processed before it's fed to the model (e.g. bucketing
# information). As of now, this serves as a list of features the model uses.
for feature_name in sorted_feature_names:
- model_and_features.features[feature_name].SetInParent()
+ if not feature_name_to_proto:
+ model_and_features.features[feature_name].SetInParent()
+ else:
+ model_and_features.features[feature_name].CopyFrom(
+ feature_name_to_proto[feature_name])
model = model_and_features.model
model.ensemble.summation_combination_technique.SetInParent()
for tree_idx in range(len(dtec.trees)):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 8cda5c8f2b..c6455a7ea3 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config):
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
- features=features)
+ features=training_features)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]