aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
blob: 48f12a64f94c7bd0531488ef537b199558e17e3e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy to export custom proto formats."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.util import compat

# Name given to a single dimension of a sparse float feature column:
# "<column name>_<dimension id>". Sparse float splits test one dimension of a
# column, so both the universal-format export and the feature-importance
# report key such splits by this per-dimension name.
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"


def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn,
                                use_core_columns=False):
  """Makes custom exporter of GTFlow tree format.

  The returned strategy first runs a regular SavedModel export. It then
  reloads that SavedModel into a fresh graph and reads the serialized GTFlow
  tree ensemble out of it. After that it optionally hands the ensemble to
  `convert_fn`. Finally it writes per-feature importances, most important
  first, to `<export dir>/assets.extra/feature_importances`.

  Args:
    name: A string, for the name of the export strategy.
    convert_fn: A function that converts the tree proto to desired format and
      saves it to the desired location. Can be None to skip conversion.
      When set, it is called as `convert_fn(dtec, sorted_feature_names,
      num_dense_floats, num_sparse_floats, num_sparse_ints, export_dir,
      eval_result)`.
    feature_columns: A list of feature columns.
    export_input_fn: A function that takes no arguments and returns an
      `InputFnOps`. Besides serving as the serving input fn, it is called
      once, immediately, to determine the feature layout used for conversion
      and feature importances.
    use_core_columns: A boolean, whether core feature columns were used.

  Returns:
    An `ExportStrategy`.
  """
  base_strategy = saved_model_export_utils.make_export_strategy(
      serving_input_fn=export_input_fn, strip_default_attrs=True)
  # The feature layout (sorted names and per-type counts) is resolved once,
  # here, and shared by every export performed through this strategy.
  serving_input = export_input_fn()
  (sorted_feature_names, dense_floats, sparse_float_indices,
   _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features(
       serving_input.features, feature_columns, use_core_columns)

  def _write_feature_importances(ensemble_config, saved_dir, num_dense,
                                 num_sparse_float, num_sparse_int):
    """Writes "<feature>, <importance>" lines, most important first."""
    importances = _get_feature_importances(
        ensemble_config, sorted_feature_names, num_dense, num_sparse_float,
        num_sparse_int)
    ranked = sorted(importances.items(), key=lambda item: -item[1])
    extra_assets_dir = os.path.join(
        compat.as_bytes(saved_dir), compat.as_bytes("assets.extra"))
    gfile.MakeDirs(extra_assets_dir)
    report_path = os.path.join(
        compat.as_bytes(extra_assets_dir),
        compat.as_bytes("feature_importances"))
    with gfile.GFile(report_path, "w") as report:
      report.write(
          "\n".join("%s, %f" % (feature, score) for feature, score in ranked))

  def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
    """A wrapper to export to SavedModel, and convert it to other formats."""
    saved_dir = base_strategy.export(
        estimator, export_dir, checkpoint_path, eval_result)
    num_dense = len(dense_floats)
    num_sparse_float = len(sparse_float_indices)
    num_sparse_int = len(sparse_int_indices)
    with ops.Graph().as_default() as graph:
      with tf_session.Session(graph=graph) as sess:
        saved_model_loader.load(sess, [tag_constants.SERVING], saved_dir)
        # Note: This is GTFlow internal API and might change.
        serialize_op = graph.get_operation_by_name(
            "ensemble_model/TreeEnsembleSerialize")
        _, serialized_ensemble = sess.run(serialize_op.outputs)
        ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
        ensemble_config.ParseFromString(serialized_ensemble)
        # Everything derived from the ensemble lands next to the SavedModel.
        if convert_fn:
          convert_fn(ensemble_config, sorted_feature_names, num_dense,
                     num_sparse_float, num_sparse_int, saved_dir, eval_result)
        _write_feature_importances(ensemble_config, saved_dir, num_dense,
                                   num_sparse_float, num_sparse_int)
    return saved_dir

  return export_strategy.ExportStrategy(
      name, export_fn, strip_default_attrs=True)


def convert_to_universal_format(dtec, sorted_feature_names,
                                num_dense, num_sparse_float,
                                num_sparse_int,
                                feature_name_to_proto=None):
  """Convert GTFlow trees to universal format.

  Every GTFlow tree becomes one ensemble member, and members are combined by
  summation. Leaf values are pre-multiplied by the weight of their tree. Every
  GTFlow split becomes a binary node whose left-child test mirrors the split
  rule.

  Args:
    dtec: A `DecisionTreeEnsembleConfig` proto holding the GTFlow trees.
    sorted_feature_names: Feature column names, with dense float columns
      first, then sparse float columns, then sparse int (categorical) columns.
      A split's `feature_column` is an offset into its own group.
    num_dense: Number of dense float feature columns.
    num_sparse_float: Number of sparse float feature columns.
    num_sparse_int: Number of sparse int feature columns (unused).
    feature_name_to_proto: Optional dict from feature name to a feature
      description proto, copied into the output `features` map. When falsy,
      empty feature entries are created instead.

  Returns:
    A `generic_tree_model_pb2.ModelAndFeatures` proto.

  Raises:
    ValueError: If a tree contains a node type that cannot be converted.
  """
  del num_sparse_int  # unused.
  model_and_features = generic_tree_model_pb2.ModelAndFeatures()
  # TODO(jonasz): Feature descriptions should contain information about how each
  # feature is processed before it's fed to the model (e.g. bucketing
  # information). As of now, this serves as a list of features the model uses.
  for feature_name in sorted_feature_names:
    if not feature_name_to_proto:
      model_and_features.features[feature_name].SetInParent()
    else:
      model_and_features.features[feature_name].CopyFrom(
          feature_name_to_proto[feature_name])
  model = model_and_features.model
  model.ensemble.summation_combination_technique.SetInParent()
  for tree_idx, gtflow_tree in enumerate(dtec.trees):
    tree_weight = dtec.tree_weights[tree_idx]
    member = model.ensemble.members.add()
    member.submodel_id.value = tree_idx
    tree = member.submodel.decision_tree
    for node_idx, gtflow_node in enumerate(gtflow_tree.nodes):
      node = tree.nodes.add()
      node.node_id.value = node_idx
      node_type = gtflow_node.WhichOneof("node")

      if node_type == "leaf":
        # Fold the tree weight into the leaf values so members can be summed.
        leaf = gtflow_node.leaf
        if leaf.HasField("vector"):
          for weight in leaf.vector.value:
            node.leaf.vector.value.add().float_value = weight * tree_weight
        else:
          for index, weight in zip(
              leaf.sparse_vector.index, leaf.sparse_vector.value):
            node.leaf.sparse_vector.sparse_value[index].float_value = (
                weight * tree_weight)
        continue

      # Binary nodes here.
      binary_node = node.binary_node
      if node_type == "dense_float_binary_split":
        split = gtflow_node.dense_float_binary_split
        inequality_test = binary_node.inequality_left_child_test
        inequality_test.feature_id.id.value = sorted_feature_names[
            split.feature_column]
        inequality_test.type = (
            generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
        inequality_test.threshold.float_value = split.threshold
      elif node_type in ("sparse_float_binary_split_default_left",
                         "sparse_float_binary_split_default_right"):
        split = getattr(gtflow_node, node_type).split
        if node_type == "sparse_float_binary_split_default_left":
          binary_node.default_direction = generic_tree_model_pb2.BinaryNode.LEFT
        else:
          binary_node.default_direction = (
              generic_tree_model_pb2.BinaryNode.RIGHT)
        # TODO(nponomareva): adjust this id assignement when we allow multi-
        # column sparse tensors.
        column_name = sorted_feature_names[split.feature_column + num_dense]
        dimension_name = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            column_name, split.dimension_id)
        inequality_test = binary_node.inequality_left_child_test
        inequality_test.feature_id.id.value = dimension_name
        inequality_test.type = (
            generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
        inequality_test.threshold.float_value = split.threshold
        # The split tests a single dimension, so the whole-column feature entry
        # is replaced by a per-dimension one. An earlier split on the same
        # column may already have removed the column entry, hence the default.
        model_and_features.features.pop(column_name, None)
        model_and_features.features[dimension_name].SetInParent()
      elif node_type in ("categorical_id_binary_split",
                         "categorical_id_set_membership_binary_split"):
        split = getattr(gtflow_node, node_type)
        binary_node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT
        categorical_test = (
            generic_tree_model_extensions_pb2.MatchingValuesTest())
        categorical_test.feature_id.id.value = sorted_feature_names[
            split.feature_column + num_dense + num_sparse_float]
        # A single-id split matches one value. A set-membership split matches
        # any of its ids. Either way a match goes to the left child.
        if node_type == "categorical_id_binary_split":
          matching_ids = [split.feature_id]
        else:
          matching_ids = split.feature_ids
        for matching_id in matching_ids:
          categorical_test.value.add().int64_value = matching_id
        binary_node.custom_left_child_test.Pack(categorical_test)
      else:
        raise ValueError("Unexpected node type %s" % node_type)
      binary_node.left_child_id.value = split.left_id
      binary_node.right_child_id.value = split.right_id
  return model_and_features


def _get_feature_importances(dtec, feature_names, num_dense_floats,
                             num_sparse_float, num_sparse_int):
  """Export the feature importance per feature column."""
  del num_sparse_int    # Unused.
  sums = collections.defaultdict(lambda: 0)
  for tree_idx in range(len(dtec.trees)):
    tree = dtec.trees[tree_idx]
    for tree_node in tree.nodes:
      node_type = tree_node.WhichOneof("node")
      if node_type == "dense_float_binary_split":
        split = tree_node.dense_float_binary_split
        split_column = feature_names[split.feature_column]
      elif node_type == "sparse_float_binary_split_default_left":
        split = tree_node.sparse_float_binary_split_default_left.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "sparse_float_binary_split_default_right":
        split = tree_node.sparse_float_binary_split_default_right.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "categorical_id_binary_split":
        split = tree_node.categorical_id_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "categorical_id_set_membership_binary_split":
        split = tree_node.categorical_id_set_membership_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "leaf":
        assert tree_node.node_metadata.gain == 0
        continue
      else:
        raise ValueError("Unexpected split type %s" % node_type)
      # Apply shrinkage factor. It is important since it is not always uniform
      # across different trees.
      sums[split_column] += (
          tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
  return dict(sums)