aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-12 00:24:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-12 00:28:21 -0800
commit4149938e5e1aa72d794caa127fdcd741a6abf90d (patch)
treeb65c301bc89ff514ac33cd09e1e8ebe81ca9fcb7 /tensorflow/contrib/boosted_trees/estimator_batch
parent769adf0c2fac4e3de40f42506b8cc85901f8569c (diff)
Adjust feature name in `features` field of generic tree proto for Sparse Float Features
PiperOrigin-RevId: 181712201
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py6
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py4
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 6ebc7d7911..31f5c44481 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -155,6 +155,9 @@ def convert_to_universal_format(dtec, sorted_feature_names,
inequality_test.feature_id.id.value = (
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
(sorted_feature_names[feature_id], split.dimension_id))
+ model_and_features.features.pop(sorted_feature_names[feature_id])
+ (model_and_features.features[inequality_test.feature_id.id.value]
+ .SetInParent())
inequality_test.type = (
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
@@ -169,6 +172,9 @@ def convert_to_universal_format(dtec, sorted_feature_names,
inequality_test.feature_id.id.value = (
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
(sorted_feature_names[feature_id], split.dimension_id))
+ model_and_features.features.pop(sorted_feature_names[feature_id])
+ (model_and_features.features[inequality_test.feature_id.id.value]
+ .SetInParent())
inequality_test.type = (
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py
index 492d9ca40c..67ec0e16bf 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py
@@ -150,8 +150,8 @@ class ConvertModelTest(test_util.TensorFlowTestCase):
dtec, feature_columns, 1, 2, 1)
# Features a and a_m are sparse float features, a_m is multidimensional.
expected_tree = """
- features { key: "feature_a" }
- features { key: "feature_a_m" }
+ features { key: "feature_a_0" }
+ features { key: "feature_a_m_3" }
features { key: "feature_b" }
features { key: "feature_d" }
model {