aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client/random_forest.py
blob: 0042d37acdb5bd56b736702c7aa125e1b6ee040a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.learn implementation of online extremely random forests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.estimator.python.estimator import head as core_head_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util

# Presumably the conventional name of an example-keys feature; it is not
# referenced elsewhere in this file's visible code.
KEYS_NAME = 'keys'
# Name of the graph operation whose first output TensorForestLossHook monitors
# when it is not given an explicit loss tensor.
LOSS_NAME = 'rf_training_loss'
# Extra entries added to the predictions dict by the generated model_fn.
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
VARIANCE_PREDICTION_KEY = 'prediction_variance'
# Key under which the complete predictions dict is exported for serving when
# include_all_in_serving is set.
ALL_SERVING_KEY = 'tensorforest_all'
# Floor applied when turning binary-class probabilities into logits, so that
# probabilities of exactly 0 or 1 do not produce log(0) or a division by 0.
EPSILON = 1e-6


class ModelBuilderOutputType(object):
  """Selects which kind of object the generated model function returns.

  MODEL_FN_OPS: the head's `create_model_fn_ops` result (contrib.learn).
  ESTIMATOR_SPEC: the head's `create_estimator_spec` result (core estimator).
  """
  MODEL_FN_OPS, ESTIMATOR_SPEC = 0, 1


class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
  """Runs a named set of ops once, when the session ends, and logs results.

  Ops run in sorted order of their names, so callers can control the order
  by choosing names that sort appropriately (e.g. with a numeric prefix).
  """

  def __init__(self, op_dict):
    """Ops is a dict of {name: op} to run before the session is destroyed."""
    self._ops = op_dict

  def end(self, session):
    for op_name in sorted(self._ops):
      result = session.run(self._ops[op_name])
      logging.info('{0}: {1}'.format(op_name, result))


class TensorForestLossHook(session_run_hook.SessionRunHook):
  """Monitor to request stop when loss stops decreasing.

  A step counts as an improvement when its loss is below
  `min_loss * (1 - early_stopping_loss_threshold)`, where `min_loss` is the
  best loss recorded so far. A stop is requested once more than
  `early_stopping_rounds` steps have passed without an improvement.
  """

  def __init__(self,
               early_stopping_rounds,
               early_stopping_loss_threshold=None,
               loss_op=None):
    """Creates the hook.

    Args:
      early_stopping_rounds: Number of non-improving steps tolerated before a
        stop is requested.
      early_stopping_loss_threshold: Fraction of the best loss by which a new
        loss must improve on it to count as an improvement. `None` (the
        default) is treated as 0, i.e. any strict decrease counts.
      loss_op: Tensor holding the loss to monitor. If None, the first output
        of the graph operation named `LOSS_NAME` is used.
    """
    self.early_stopping_rounds = early_stopping_rounds
    self.early_stopping_loss_threshold = early_stopping_loss_threshold
    self.loss_op = loss_op
    self.min_loss = None
    self.last_step = -1
    # self.steps records the number of steps for which the loss has been
    # non-decreasing
    self.steps = 0

  def before_run(self, run_context):
    """Requests the global step and the current loss on every run."""
    loss = (self.loss_op if self.loss_op is not None else
            run_context.session.graph.get_operation_by_name(
                LOSS_NAME).outputs[0])
    return session_run_hook.SessionRunArgs(
        {'global_step': training_util.get_global_step(),
         'current_loss': loss})

  def after_run(self, run_context, run_values):
    """Tracks the best loss and requests a stop once it has stalled."""
    current_loss = run_values.results['current_loss']
    current_step = run_values.results['global_step']
    self.steps += 1
    # Guard against the global step going backwards, which might happen
    # if we recover from something.
    if self.last_step == -1 or self.last_step > current_step:
      logging.info('TensorForestLossHook resetting last_step.')
      self.last_step = current_step
      self.steps = 0
      self.min_loss = None
      return

    self.last_step = current_step
    # The constructor's default threshold is None; multiplying by it would
    # raise a TypeError as soon as min_loss is set, so treat it as 0.
    threshold = self.early_stopping_loss_threshold or 0.0
    if (self.min_loss is None or
        current_loss < self.min_loss - self.min_loss * threshold):
      self.min_loss = current_loss
      self.steps = 0
    if self.steps > self.early_stopping_rounds:
      logging.info('TensorForestLossHook requesting stop.')
      run_context.request_stop()


def _get_default_head(params, weights_name, output_type, name=None):
  """Creates a default head based on a type of a problem.

  Args:
    params: Filled ForestHParams; `regression`, `num_outputs` and
      `num_classes` decide which head is built.
    weights_name: Name of the example-weight feature, or None.
    output_type: A `ModelBuilderOutputType` value. MODEL_FN_OPS yields a
      contrib.learn head; any other value yields a core estimator head.
    name: Optional head name.

  Returns:
    A head object appropriate for `output_type`.
  """
  if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
    if params.regression:
      return head_lib.regression_head(
          weight_column_name=weights_name,
          label_dimension=params.num_outputs,
          enable_centered_bias=False,
          head_name=name)
    return head_lib.multi_class_head(
        params.num_classes,
        weight_column_name=weights_name,
        enable_centered_bias=False,
        head_name=name)

  # Every core head below uses the same SUM_OVER_BATCH_SIZE loss reduction.
  reduction = losses.Reduction.SUM_OVER_BATCH_SIZE
  if params.regression:
    return core_head_lib.regression_head(
        weight_column=weights_name,
        label_dimension=params.num_outputs,
        name=name,
        loss_reduction=reduction)
  if params.num_classes == 2:
    return core_head_lib.binary_classification_head(
        weight_column=weights_name,
        name=name,
        loss_reduction=reduction)
  return core_head_lib.multi_class_head(
      n_classes=params.num_classes,
      weight_column=weights_name,
      name=name,
      loss_reduction=reduction)

def get_model_fn(params,
                 graph_builder_class,
                 device_assigner,
                 feature_columns=None,
                 weights_name=None,
                 model_head=None,
                 keys_name=None,
                 early_stopping_rounds=100,
                 early_stopping_loss_threshold=0.001,
                 num_trainers=1,
                 trainer_id=0,
                 report_feature_importances=False,
                 local_eval=False,
                 head_scope=None,
                 include_all_in_serving=False,
                 output_type=ModelBuilderOutputType.MODEL_FN_OPS):
  """Return a model function given a way to construct a graph builder.

  Args:
    params: Filled ForestHParams, passed to `graph_builder_class`; its
      `regression`, `num_classes` and `inference_tree_paths` fields are also
      read here.
    graph_builder_class: Called as
      `graph_builder_class(params, device_assigner=...)` to build the forest.
    device_assigner: Handed to the graph builder, except in INFER mode and
      (when `local_eval` is True) EVAL mode, where None is passed instead.
    feature_columns: Optional feature columns whose transformed outputs are
      added to the features dict before building the forest.
    weights_name: Optional name of a feature holding example weights.
    model_head: Head used to produce the model output. If None, a default
      head is created from `params` and `output_type`.
    keys_name: Optional name of a feature that is removed from the forest's
      inputs and passed through into the predictions dict.
    early_stopping_rounds: If truthy, a `TensorForestLossHook` is added to the
      training hooks.
    early_stopping_loss_threshold: Forwarded to `TensorForestLossHook`.
    num_trainers: Forwarded to `graph_builder.training_graph`.
    trainer_id: Forwarded to `graph_builder.training_graph`.
    report_feature_importances: If True, feature importances are logged when
      the training session ends.
    local_eval: If True, `device_assigner` is ignored in EVAL mode.
    head_scope: Scope passed to `create_model_fn_ops` (MODEL_FN_OPS only).
    include_all_in_serving: If True, the complete predictions dict is added
      under `ALL_SERVING_KEY` to the output alternatives (MODEL_FN_OPS) or to
      the export outputs (ESTIMATOR_SPEC).
    output_type: A `ModelBuilderOutputType` selecting whether the returned
      function produces the head's `ModelFnOps` or its `EstimatorSpec`.

  Returns:
    A `model_fn(features, labels, mode)` callable.
  """
  if model_head is None:
    model_head = _get_default_head(params, weights_name, output_type)

  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""

    if (isinstance(features, ops.Tensor) or
        isinstance(features, sparse_tensor.SparseTensor)):
      features = {'features': features}
    # Work on a shallow copy so that popping the weights/keys features below
    # never mutates the caller's dict (e.g. a serving receiver's tensors).
    features = features.copy()
    if feature_columns:
      if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
        features.update(layers.transform_features(features, feature_columns))
      else:
        for fc in feature_columns:
          tensor = fc_core._transform_features(features, [fc])[fc]  # pylint: disable=protected-access
          features[fc.name] = tensor

    weights = None
    if weights_name and weights_name in features:
      weights = features.pop(weights_name)

    keys = None
    if keys_name and keys_name in features:
      keys = features.pop(keys_name)

    # If we're doing eval, optionally ignore device_assigner.
    # Also ignore device assigner if we're exporting (mode == INFER)
    dev_assn = device_assigner
    if (mode == model_fn_lib.ModeKeys.INFER or
        (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
      dev_assn = None

    graph_builder = graph_builder_class(params,
                                        device_assigner=dev_assn)

    logits, tree_paths, regression_variance = graph_builder.inference_graph(
        features)

    summary.scalar('average_tree_size', graph_builder.average_size())
    # For binary classification problems, convert probabilities to logits.
    # Includes hack to get around the fact that a probability might be 0 or 1.
    if not params.regression and params.num_classes == 2:
      class_1_probs = array_ops.slice(logits, [0, 1], [-1, 1])
      logits = math_ops.log(
          math_ops.maximum(class_1_probs / math_ops.maximum(
              1.0 - class_1_probs, EPSILON), EPSILON))

    # labels might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    training_graph = None
    if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
      with ops.control_dependencies([logits.op]):
        training_graph = control_flow_ops.group(
            graph_builder.training_graph(
                features, labels, input_weights=weights,
                num_trainers=num_trainers,
                trainer_id=trainer_id),
            state_ops.assign_add(training_util.get_global_step(), 1))

    # Put weights back in so the head can find them.
    if weights is not None:
      features[weights_name] = weights

    # TensorForest's training graph isn't calculated directly from the loss
    # like many other models.
    def _train_fn(unused_loss):
      return training_graph

    # Ops are run in lexicographical order of their keys. Run the resource
    # clean-up op last.
    all_handles = graph_builder.get_all_resource_handles()
    ops_at_end = {
        '9: clean up resources':
            control_flow_ops.group(*[
                resource_variable_ops.destroy_resource_op(handle)
                for handle in all_handles
            ])
    }

    if report_feature_importances:
      ops_at_end['1: feature_importances'] = (
          graph_builder.feature_importances())

    training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]

    if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
      model_ops = model_head.create_model_fn_ops(
          features=features,
          labels=labels,
          mode=mode,
          train_op_fn=_train_fn,
          logits=logits,
          scope=head_scope)

      if early_stopping_rounds:
        training_hooks.append(
            TensorForestLossHook(
                early_stopping_rounds,
                early_stopping_loss_threshold=early_stopping_loss_threshold,
                loss_op=model_ops.loss))

      model_ops.training_hooks.extend(training_hooks)

      if keys is not None:
        model_ops.predictions[keys_name] = keys

      if params.inference_tree_paths:
        model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths

      model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance

      if include_all_in_serving:
        # In order to serve the variance we need to add the prediction dict
        # to output_alternatives dict. ModelFnOps fields cannot be assigned
        # (it is a namedtuple), so build a new instance via _replace, as the
        # EstimatorSpec branch below does.
        output_alternatives = dict(model_ops.output_alternatives or {})
        output_alternatives[ALL_SERVING_KEY] = (
            constants.ProblemType.UNSPECIFIED, model_ops.predictions)
        model_ops = model_ops._replace(  # pylint: disable=protected-access
            output_alternatives=output_alternatives)

      return model_ops

    else:
      # Estimator spec
      estimator_spec = model_head.create_estimator_spec(
          features=features,
          mode=mode,
          labels=labels,
          train_op_fn=_train_fn,
          logits=logits)

      if early_stopping_rounds:
        training_hooks.append(
            TensorForestLossHook(
                early_stopping_rounds,
                early_stopping_loss_threshold=early_stopping_loss_threshold,
                loss_op=estimator_spec.loss))

      estimator_spec = estimator_spec._replace(
          training_hooks=training_hooks + list(estimator_spec.training_hooks))
      if keys is not None:
        estimator_spec.predictions[keys_name] = keys
      if params.inference_tree_paths:
        estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
      estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance

      if include_all_in_serving:
        # In order to serve the variance we need to add the prediction dict
        # to the export outputs. Keep whatever outputs the head produced and
        # add ours alongside them instead of discarding them.
        outputs = dict(estimator_spec.export_outputs or {})
        outputs[ALL_SERVING_KEY] = PredictOutput(estimator_spec.predictions)
        estimator_spec = estimator_spec._replace(export_outputs=outputs)

      return estimator_spec

  return _model_fn


class TensorForestEstimator(estimator.Estimator):
  """An estimator that can train and evaluate a random forest.

  Example:

  ```python
  params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000)

  # Estimator using the default graph builder.
  estimator = TensorForestEstimator(params, model_dir=model_dir)

  # Or estimator using TrainingLossForest as the graph builder.
  estimator = TensorForestEstimator(
      params, graph_builder_class=tensor_forest.TrainingLossForest,
      model_dir=model_dir)

  # Input builders
  def input_fn_train():  # returns x, y
    ...
  def input_fn_eval():  # returns x, y
    ...
  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)

  # Predict returns an iterable of dicts.
  results = list(estimator.predict(x=x))
  prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
  prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
  ```
  """

  def __init__(self,
               params,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               early_stopping_loss_threshold=0.001,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False,
               version=None,
               head=None,
               include_all_in_serving=False):
    """Initializes a TensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        `params.fill()` is passed to the model function.
      device_assigner: An object that controls how trees get assigned to
        devices. It is passed to `graph_builder_class`; if `None`, the graph
        builder's own default is used. It is ignored when predicting or
        exporting, and during evaluation when `local_eval` is True.
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `_FeatureColumn`.
      graph_builder_class: A class that defines how TF graphs for random
        forest training and inference are built; it is called as
        `graph_builder_class(params, device_assigner=...)`. Defaults to
        `tensor_forest.RandomForestGraphs`.
      config: `RunConfig` object to configure the runtime settings.
      weight_column: A string defining feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_column: A string naming one of the features to strip out and
        pass through into the inference/eval results dict. Useful for
        associating specific examples with their prediction.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      early_stopping_rounds: Allows training to terminate early if the forest is
        no longer growing. 100 by default. Set to a Falsy value to disable
        the default training hook.
      early_stopping_loss_threshold: Percentage (as fraction) that loss must
        improve by within early_stopping_rounds steps, otherwise training will
        terminate.
      num_trainers: Number of training jobs, which will partition trees
        among them.
      trainer_id: Which trainer this instance is.
      report_feature_importances: If True, log feature importances when the
        training session ends.
      local_eval: If True, don't use a device assigner for eval. This is to
        support some common setups where eval is done on a single machine, even
        though training might be distributed.
      version: Unused; kept for backward compatibility.
      head: A heads_lib.Head object that calculates losses and such. If None,
        one will be automatically created based on params.
      include_all_in_serving: If True, allow preparation of the complete
        prediction dict, including the variance, to be exported for serving
        with the Servo lib. This also requires calling export_savedmodel with
        default_output_alternative_key=ALL_SERVING_KEY, i.e.
          estimator.export_savedmodel(export_dir_base=your_export_dir,
            serving_input_fn=your_export_input_fn,
            default_output_alternative_key=ALL_SERVING_KEY)
        If False, resort to the default behavior, i.e. export scores and
        probabilities but no variances. In this case
        default_output_alternative_key should be None when calling
        export_savedmodel().
        Note that, for backward compatibility, include_all_in_serving cannot
        always be True: calling export_savedmodel() without
        default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) would
        then make saved_model_export_utils.get_output_alternatives() raise
        ValueError.

    Returns:
      A `TensorForestEstimator` instance.
    """
    super(TensorForestEstimator, self).__init__(
        model_fn=get_model_fn(
            params.fill(),
            graph_builder_class,
            device_assigner,
            feature_columns=feature_columns,
            model_head=head,
            weights_name=weight_column,
            keys_name=keys_column,
            early_stopping_rounds=early_stopping_rounds,
            early_stopping_loss_threshold=early_stopping_loss_threshold,
            num_trainers=num_trainers,
            trainer_id=trainer_id,
            report_feature_importances=report_feature_importances,
            local_eval=local_eval,
            include_all_in_serving=include_all_in_serving,
        ),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)


def get_combined_model_fn(model_fns):
  """Get a combined model function given a list of other model fns.

  The model function returned will call the individual model functions and
  combine them appropriately.  For:

  training ops: tf.group them.
  loss: average them.
  predictions: concat probabilities such that predictions[*][0-C1] are the
    probabilities for output 1 (where C1 is the number of classes in output 1),
    predictions[*][C1-(C1+C2)] are the probabilities for output 2 (where C2
    is the number of classes in output 2), etc.  Also stack predictions such
    that predictions[i][j] is the class prediction for example i and output j.

  This assumes that labels are 2-dimensional, with labels[i][j] being the
  label for example i and output j, where forest j is trained using only
  output j. Labels may be None (e.g. when predicting); each model function
  then receives None as its labels.

  Args:
    model_fns: A list of model functions obtained from get_model_fn.

  Returns:
    A model function `(features, labels, mode)` that returns a ModelFnOps.
  """
  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    model_fn_ops = []
    for i, single_model_fn in enumerate(model_fns):
      with variable_scope.variable_scope('label_{0}'.format(i)):
        # Labels are absent in INFER mode; slicing None would fail.
        sliced_labels = None
        if labels is not None:
          sliced_labels = array_ops.slice(labels, [0, i], [-1, 1])
        model_fn_ops.append(single_model_fn(features, sliced_labels, mode))
    training_hooks = []
    for mops in model_fn_ops:
      training_hooks += mops.training_hooks
    predictions = {}
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.INFER):
      # Flatten the probabilities into one dimension.
      predictions[eval_metrics.INFERENCE_PROB_NAME] = array_ops.concat(
          [mops.predictions[eval_metrics.INFERENCE_PROB_NAME]
           for mops in model_fn_ops], axis=1)
      predictions[eval_metrics.INFERENCE_PRED_NAME] = array_ops.stack(
          [mops.predictions[eval_metrics.INFERENCE_PRED_NAME]
           for mops in model_fn_ops], axis=1)
    loss = None
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.TRAIN):
      loss = math_ops.reduce_sum(
          array_ops.stack(
              [mops.loss for mops in model_fn_ops])) / len(model_fn_ops)

    train_op = None
    if mode == model_fn_lib.ModeKeys.TRAIN:
      train_op = control_flow_ops.group(
          *[mops.train_op for mops in model_fn_ops])
    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=training_hooks,
        scaffold=None,
        output_alternatives=None)

  return _model_fn


class MultiForestMultiHeadEstimator(estimator.Estimator):
  """An estimator that can train forests for multi-headed problems.

  This class essentially trains separate forests (each with their own
  ForestHParams) for each output.

  For multi-headed regression, a single-headed TensorForestEstimator can
  be used to train a single model that predicts all outputs.  This class can
  be used to train separate forests for each output.
  """

  def __init__(self,
               params_list,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False,
               early_stopping_loss_threshold=0.001):
    """See TensorForestEstimator.__init__.

    `params_list` holds one ForestHParams per output. `feature_columns` is
    forwarded to every per-output model function (it was previously accepted
    but ignored). `early_stopping_loss_threshold` is the last parameter so
    existing positional callers are unaffected.
    """
    model_fns = []
    for i, forest_params in enumerate(params_list):
      params = forest_params.fill()
      model_fns.append(
          get_model_fn(
              params,
              graph_builder_class,
              device_assigner,
              feature_columns=feature_columns,
              model_head=_get_default_head(
                  params,
                  weight_column,
                  name='head{0}'.format(i),
                  output_type=ModelBuilderOutputType.MODEL_FN_OPS),
              weights_name=weight_column,
              keys_name=keys_column,
              early_stopping_rounds=early_stopping_rounds,
              early_stopping_loss_threshold=early_stopping_loss_threshold,
              num_trainers=num_trainers,
              trainer_id=trainer_id,
              report_feature_importances=report_feature_importances,
              local_eval=local_eval,
              head_scope='output{0}'.format(i)))

    super(MultiForestMultiHeadEstimator, self).__init__(
        model_fn=get_combined_model_fn(model_fns),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)


class CoreTensorForestEstimator(core_estimator.Estimator):
  """A CORE estimator that can train and evaluate a random forest.

  Example:

  ```python
  params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000)

  # Estimator using the default graph builder.
  estimator = CoreTensorForestEstimator(params, model_dir=model_dir)

  # Or estimator using TrainingLossForest as the graph builder.
  estimator = CoreTensorForestEstimator(
      params, graph_builder_class=tensor_forest.TrainingLossForest,
      model_dir=model_dir)

  # Input builders
  def input_fn_train():  # returns features, labels
    ...
  def input_fn_eval():  # returns features, labels
    ...
  def input_fn_predict():  # returns features
    ...
  estimator.train(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)

  # Predict returns an iterable of dicts.
  results = list(estimator.predict(input_fn=input_fn_predict))
  prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
  prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
  ```
  """

  def __init__(self,
               params,
               device_assigner=None,
               model_dir=None,
               feature_columns=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               config=None,
               weight_column=None,
               keys_column=None,
               feature_engineering_fn=None,
               early_stopping_rounds=100,
               early_stopping_loss_threshold=0.001,
               num_trainers=1,
               trainer_id=0,
               report_feature_importances=False,
               local_eval=False,
               version=None,
               head=None,
               include_all_in_serving=False):
    """Initializes a CoreTensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        `params.fill()` is passed to the model function.
      device_assigner: An object that controls how trees get assigned to
        devices. It is passed to `graph_builder_class`; if `None`, the graph
        builder's own default is used. It is ignored when predicting or
        exporting, and during evaluation when `local_eval` is True.
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `_FeatureColumn`.
      graph_builder_class: A class that defines how TF graphs for random
        forest training and inference are built; it is called as
        `graph_builder_class(params, device_assigner=...)`. Defaults to
        `tensor_forest.RandomForestGraphs`.
      config: `RunConfig` object to configure the runtime settings.
      weight_column: A string defining feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_column: A string naming one of the features to strip out and
        pass through into the inference/eval results dict. Useful for
        associating specific examples with their prediction.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model. The core Estimator has no
        such argument, so it is applied here, inside the model function.
      early_stopping_rounds: Allows training to terminate early if the forest is
        no longer growing. 100 by default. Set to a Falsy value to disable
        the default training hook.
      early_stopping_loss_threshold: Percentage (as fraction) that loss must
        improve by within early_stopping_rounds steps, otherwise training will
        terminate.
      num_trainers: Number of training jobs, which will partition trees
        among them.
      trainer_id: Which trainer this instance is.
      report_feature_importances: If True, log feature importances when the
        training session ends.
      local_eval: If True, don't use a device assigner for eval. This is to
        support some common setups where eval is done on a single machine, even
        though training might be distributed.
      version: Unused; kept for backward compatibility.
      head: A head object that calculates losses and such. If None, one will
        be automatically created based on params.
      include_all_in_serving: If True, the complete prediction dict, including
        the variance, is exported for serving under the signature key
        `ALL_SERVING_KEY`. If False, only the head's default export outputs
        are used (scores and probabilities, but no variances).

    Returns:
      A `CoreTensorForestEstimator` instance.
    """
    forest_model_fn = get_model_fn(
        params.fill(),
        graph_builder_class,
        device_assigner,
        feature_columns=feature_columns,
        model_head=head,
        weights_name=weight_column,
        keys_name=keys_column,
        early_stopping_rounds=early_stopping_rounds,
        early_stopping_loss_threshold=early_stopping_loss_threshold,
        num_trainers=num_trainers,
        trainer_id=trainer_id,
        report_feature_importances=report_feature_importances,
        local_eval=local_eval,
        include_all_in_serving=include_all_in_serving,
        output_type=ModelBuilderOutputType.ESTIMATOR_SPEC)

    def _model_fn(features, labels, mode):
      """Applies `feature_engineering_fn` (if any), then the forest model."""
      # The core Estimator cannot take feature_engineering_fn itself, so it
      # was previously accepted and silently dropped; honor it here instead.
      if feature_engineering_fn is not None:
        features, labels = feature_engineering_fn(features, labels)
      return forest_model_fn(features, labels, mode)

    super(CoreTensorForestEstimator, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config)