diff options
author | 2018-01-12 00:24:57 -0800 | |
---|---|---|
committer | 2018-01-12 00:28:21 -0800 | |
commit | 4149938e5e1aa72d794caa127fdcd741a6abf90d (patch) | |
tree | b65c301bc89ff514ac33cd09e1e8ebe81ca9fcb7 /tensorflow/contrib/boosted_trees/estimator_batch | |
parent | 769adf0c2fac4e3de40f42506b8cc85901f8569c (diff) |
Adjust feature name in `features` field of generic tree proto for Sparse Float Features
PiperOrigin-RevId: 181712201
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch')
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py | 4 |
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 6ebc7d7911..31f5c44481 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -155,6 +155,9 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.feature_id.id.value = ( _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -169,6 +172,9 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.feature_id.id.value = ( _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 492d9ca40c..67ec0e16bf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -150,8 +150,8 @@ class ConvertModelTest(test_util.TensorFlowTestCase): dtec, feature_columns, 1, 2, 1) # Features a and a_m are sparse float features, a_m is multidimensional. expected_tree = """ - features { key: "feature_a" } - features { key: "feature_a_m" } + features { key: "feature_a_0" } + features { key: "feature_a_m_3" } features { key: "feature_b" } features { key: "feature_d" } model { |