# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extremely random forest graph builder. go/brain-tree."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numbers
import random

from google.protobuf import text_format

from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.framework.python.ops import variables as framework_variables
from tensorflow.contrib.tensor_forest.proto import tensor_forest_params_pb2 as _params_proto
from tensorflow.contrib.tensor_forest.python.ops import data_ops
from tensorflow.contrib.tensor_forest.python.ops import model_ops
from tensorflow.contrib.tensor_forest.python.ops import stats_ops

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging


# Stores tuples of (leaf model type, stats model type)
CLASSIFICATION_LEAF_MODEL_TYPES = {
    'all_dense': (_params_proto.MODEL_DENSE_CLASSIFICATION,
                  _params_proto.STATS_DENSE_GINI),
    'all_sparse': (_params_proto.MODEL_SPARSE_CLASSIFICATION,
                   _params_proto.STATS_SPARSE_GINI),
    'sparse_then_dense':
        (_params_proto.MODEL_SPARSE_OR_DENSE_CLASSIFICATION,
         _params_proto.STATS_SPARSE_THEN_DENSE_GINI),
}
REGRESSION_MODEL_TYPE = (
    _params_proto.MODEL_REGRESSION,
    _params_proto.STATS_LEAST_SQUARES_REGRESSION,
    _params_proto.COLLECTION_BASIC)

FINISH_TYPES = {
    'basic': _params_proto.SPLIT_FINISH_BASIC,
    'hoeffding': _params_proto.SPLIT_FINISH_DOMINATE_HOEFFDING,
    'bootstrap': _params_proto.SPLIT_FINISH_DOMINATE_BOOTSTRAP
}
PRUNING_TYPES = {
    'none': _params_proto.SPLIT_PRUNE_NONE,
    'half': _params_proto.SPLIT_PRUNE_HALF,
    'quarter': _params_proto.SPLIT_PRUNE_QUARTER,
    '10_percent': _params_proto.SPLIT_PRUNE_10_PERCENT,
    'hoeffding': _params_proto.SPLIT_PRUNE_HOEFFDING,
}
SPLIT_TYPES = {
    'less_or_equal': _tree_proto.InequalityTest.LESS_OR_EQUAL,
    'less': _tree_proto.InequalityTest.LESS_THAN
}


def parse_number_or_string_to_proto(proto, param):
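  """Sets `proto` from a number or a string.

  A number (or a string of digits) becomes proto.constant_value; any other
  string is assumed to hold a text-format proto (e.g. a depth-dependent
  parameter spec) and is merged into `proto`.
  """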
  if isinstance(param, numbers.Number):
    proto.constant_value = param
  else:  # assume it's a string
    if param.isdigit():
      proto.constant_value = int(param)
    else:
      text_format.Merge(param, proto)


def build_params_proto(params):
  """Build a TensorForestParams proto out of the V4ForestHParams object."""
  proto = _params_proto.TensorForestParams()
  proto.num_trees = params.num_trees
  proto.max_nodes = params.max_nodes
  proto.is_regression = params.regression
  proto.num_outputs = params.num_classes
  proto.num_features = params.num_features

  proto.leaf_type = params.leaf_model_type
  proto.stats_type = params.stats_model_type
  proto.collection_type = _params_proto.COLLECTION_BASIC
  proto.pruning_type.type = params.pruning_type
  proto.finish_type.type = params.finish_type

  proto.inequality_test_type = params.split_type

  proto.drop_final_class = False
  proto.collate_examples = params.collate_examples
  proto.checkpoint_stats = params.checkpoint_stats
  proto.use_running_stats_method = params.use_running_stats_method
  proto.initialize_average_splits = params.initialize_average_splits
  proto.inference_tree_paths = params.inference_tree_paths

  parse_number_or_string_to_proto(proto.pruning_type.prune_every_samples,
                                  params.prune_every_samples)
  parse_number_or_string_to_proto(proto.finish_type.check_every_steps,
                                  params.early_finish_check_every_samples)
  parse_number_or_string_to_proto(proto.split_after_samples,
                                  params.split_after_samples)
  parse_number_or_string_to_proto(proto.num_splits_to_consider,
                                  params.num_splits_to_consider)

  proto.dominate_fraction.constant_value = params.dominate_fraction

  if params.param_file:
    with open(params.param_file) as f:
      text_format.Merge(f.read(), proto)

  return proto


# A convenience class for holding random forest hyperparameters.
#
# To just get some good default parameters, use:
#   hparams = ForestHParams(num_classes=2, num_features=40).fill()
#
# Note that num_classes cannot be inferred and so must always be specified.
# Also, either num_splits_to_consider or num_features should be set.
#
# To override specific values, pass them to the constructor:
#   hparams = ForestHParams(num_classes=5, num_trees=10, num_features=5).fill()
#
# TODO(thomaswc): Inherit from tf.HParams when that is publicly available.
class ForestHParams(object):
  """A base class for holding hyperparameters and calculating good defaults."""

  def __init__(
      self,
      num_trees=100,
      max_nodes=10000,
      bagging_fraction=1.0,
      num_splits_to_consider=0,
      feature_bagging_fraction=1.0,
      max_fertile_nodes=0,  # deprecated, unused.
      split_after_samples=250,
      valid_leaf_threshold=1,
      dominate_method='bootstrap',
      dominate_fraction=0.99,
      model_name='all_dense',
      split_finish_name='basic',
      split_pruning_name='none',
      prune_every_samples=0,
      early_finish_check_every_samples=0,
      collate_examples=False,
      checkpoint_stats=False,
      use_running_stats_method=False,
      initialize_average_splits=False,
      inference_tree_paths=False,
      param_file=None,
      split_name='less_or_equal',
      **kwargs):
    self.num_trees = num_trees
    self.max_nodes = max_nodes
    self.bagging_fraction = bagging_fraction
    self.feature_bagging_fraction = feature_bagging_fraction
    self.num_splits_to_consider = num_splits_to_consider
    self.max_fertile_nodes = max_fertile_nodes
    self.split_after_samples = split_after_samples
    self.valid_leaf_threshold = valid_leaf_threshold
    self.dominate_method = dominate_method
    self.dominate_fraction = dominate_fraction
    self.model_name = model_name
    self.split_finish_name = split_finish_name
    self.split_pruning_name = split_pruning_name
    self.collate_examples = collate_examples
    self.checkpoint_stats = checkpoint_stats
    self.use_running_stats_method = use_running_stats_method
    self.initialize_average_splits = initialize_average_splits
    self.inference_tree_paths = inference_tree_paths
    self.param_file = param_file
    self.split_name = split_name
    self.early_finish_check_every_samples = early_finish_check_every_samples
    self.prune_every_samples = prune_every_samples

    for name, value in kwargs.items():
      setattr(self, name, value)

  def values(self):
    return self.__dict__

  def fill(self):
    """Intelligently sets any non-specific parameters."""
    # Fail fast if num_classes or num_features isn't set.
    _ = getattr(self, 'num_classes')
    _ = getattr(self, 'num_features')

    self.bagged_num_features = int(self.feature_bagging_fraction *
                                   self.num_features)

    self.bagged_features = None
    if self.feature_bagging_fraction < 1.0:
      self.bagged_features = [random.sample(
          range(self.num_features),
          self.bagged_num_features) for _ in range(self.num_trees)]

    self.regression = getattr(self, 'regression', False)

    # Num_outputs is the actual number of outputs (a single prediction for
    # classification, an N-dimensional point for regression).
    self.num_outputs = self.num_classes if self.regression else 1

    # Add an extra column to classes for storing counts, which is needed for
    # regression and avoids having to recompute sums for classification.
    self.num_output_columns = self.num_classes + 1

    # Default num_splits_to_consider to sqrt(num_features), clamped to the
    # range [10, 1000].  For example, num_features = 784 gives
    # floor(sqrt(784)) = 28 candidate splits per node.
    self.num_splits_to_consider = self.num_splits_to_consider or min(
        max(10, math.floor(math.sqrt(self.num_features))), 1000)

    # If base_random_seed is 0, the current time will be used to seed the
    # random number generators for each tree.  If non-zero, the i-th tree
    # will be seeded with base_random_seed + i.
    self.base_random_seed = getattr(self, 'base_random_seed', 0)

    # How to store leaf models.
    self.leaf_model_type = (
        REGRESSION_MODEL_TYPE[0] if self.regression else
        CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][0])

    # How to store stats objects.
    self.stats_model_type = (
        REGRESSION_MODEL_TYPE[1] if self.regression else
        CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][1])

    self.finish_type = (
        _params_proto.SPLIT_FINISH_BASIC if self.regression else
        FINISH_TYPES[self.split_finish_name])

    self.pruning_type = PRUNING_TYPES[self.split_pruning_name]

    if self.pruning_type == _params_proto.SPLIT_PRUNE_NONE:
      self.prune_every_samples = 0
    else:
      if (not self.prune_every_samples and
          not (isinstance(self.split_after_samples, numbers.Number) or
               self.split_after_samples.isdigit())):
        logging.error(
            'Must specify prune_every_samples if using a depth-dependent '
            'split_after_samples')
      # Pruning half-way through split_after_samples seems like a decent
      # default, making it easy to select the number being pruned with
      # pruning_type while not paying the cost of pruning too often.  Note that
      # this only holds if not using a depth-dependent split_after_samples.
      self.prune_every_samples = (self.prune_every_samples or
                                  int(self.split_after_samples) / 2)

    if self.finish_type == _params_proto.SPLIT_FINISH_BASIC:
      self.early_finish_check_every_samples = 0
    else:
      if (not self.early_finish_check_every_samples and
          not (isinstance(self.split_after_samples, numbers.Number) or
               self.split_after_samples.isdigit())):
        logging.error(
            'Must specify early_finish_check_every_samples if using a '
            'depth-dependent split_after_samples')
      # Checking for early finish every quarter through split_after_samples
      # seems like a decent default. We don't want to incur the checking cost
      # too often, but (at least for hoeffding) it's lower than the cost of
      # pruning so we can do it a little more frequently.
      self.early_finish_check_every_samples = (
          self.early_finish_check_every_samples or
          int(self.split_after_samples) / 4)

    self.split_type = SPLIT_TYPES[self.split_name]

    return self


def get_epoch_variable():
  """Returns the epoch variable, or [0] if not defined."""
  # Grab epoch variable defined in
  # //third_party/tensorflow/python/training/input.py::limit_epochs
  for v in tf_variables.local_variables():
    if 'limit_epochs/epoch' in v.op.name:
      return array_ops.reshape(v, [1])
  # TODO(thomaswc): Access epoch from the data feeder.
  return [0]


# A simple container to hold the training variables for a single tree.
class TreeTrainingVariables(object):
  """Stores tf.Variables for training a single random tree.

  Uses tf.get_variable to get tree-specific names so that this can be used
  with a tf.learn-style implementation (one that trains a model, saves it,
  then relies on restoring that model to evaluate).
  """

  def __init__(self, params, tree_num, training):
    if (not hasattr(params, 'params_proto') or
        not isinstance(params.params_proto,
                       _params_proto.TensorForestParams)):
      params.params_proto = build_params_proto(params)

    params.serialized_params_proto = params.params_proto.SerializeToString()
    self.stats = None
    if training:
      # TODO(gilberth): Manually shard this to be able to fit it on
      # multiple machines.
      self.stats = stats_ops.fertile_stats_variable(
          params, '', self.get_tree_name('stats', tree_num))
    self.tree = model_ops.tree_variable(
        params, '', self.stats, self.get_tree_name('tree', tree_num))

  def get_tree_name(self, name, num):
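    # For example, get_tree_name('tree', 3) returns 'tree-3'.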
    return '{0}-{1}'.format(name, num)


class ForestTrainingVariables(object):
  """A container for a forests training data, consisting of multiple trees.

  Instantiates a TreeTrainingVariables object for each tree. We override the
  __getitem__ and __setitem__ function so that usage looks like this:

    forest_variables = ForestTrainingVariables(params)

    ... forest_variables.tree ...
  """

  def __init__(self, params, device_assigner, training=True,
               tree_variables_class=TreeTrainingVariables):
    self.variables = []
    # Set up some scalar variables to run through the device assigner, then
    # we can use those to colocate everything related to a tree.
    self.device_dummies = []
    with ops.device(device_assigner):
      for i in range(params.num_trees):
        self.device_dummies.append(variable_scope.get_variable(
            name='device_dummy_%d' % i, shape=0))

    for i in range(params.num_trees):
      with ops.device(self.device_dummies[i].device):
        self.variables.append(tree_variables_class(params, i, training))

  def __setitem__(self, t, val):
    self.variables[t] = val

  def __getitem__(self, t):
    return self.variables[t]


class RandomForestGraphs(object):
  """Builds TF graphs for random forest training and inference."""

  def __init__(self,
               params,
               device_assigner=None,
               variables=None,
               tree_variables_class=TreeTrainingVariables,
               tree_graphs=None,
               training=True):
    self.params = params
    self.device_assigner = (
        device_assigner or framework_variables.VariableDeviceChooser())
    logging.info('Constructing forest with params = ')
    logging.info(self.params.__dict__)
    self.variables = variables or ForestTrainingVariables(
        self.params, device_assigner=self.device_assigner, training=training,
        tree_variables_class=tree_variables_class)
    tree_graph_class = tree_graphs or RandomTreeGraphs
    self.trees = [
        tree_graph_class(self.variables[i], self.params, i)
        for i in range(self.params.num_trees)
    ]

  def _bag_features(self, tree_num, input_data):
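    # Split the dense input into one column per feature, then re-concatenate
    # only the columns sampled for this tree in ForestHParams.fill().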
    split_data = array_ops.split(
        value=input_data, num_or_size_splits=self.params.num_features, axis=1)
    return array_ops.concat(
        [split_data[ind] for ind in self.params.bagged_features[tree_num]], 1)

  def get_all_resource_handles(self):
    return ([self.variables[i].tree for i in range(len(self.trees))] +
            [self.variables[i].stats for i in range(len(self.trees))])

  def training_graph(self,
                     input_data,
                     input_labels,
                     num_trainers=1,
                     trainer_id=0,
                     **tree_kwargs):
    """Constructs a TF graph for training a random forest.

    Args:
      input_data: A tensor or dict of string->Tensor for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      num_trainers: Number of parallel trainers to split trees among.
      trainer_id: Which trainer this instance is.
      **tree_kwargs: Keyword arguments passed to each tree's training_graph.

    Returns:
      The last op in the random forest training graph.

    Raises:
      NotImplementedError: If trying to use bagging with sparse features.
    """
    processed_dense_features, processed_sparse_features, data_spec = (
        data_ops.ParseDataTensorOrDict(input_data))

    labels = None
    if input_labels is not None:
      labels = data_ops.ParseLabelTensorOrDict(input_labels)

    data_spec = data_spec or self.get_default_data_spec(input_data)

    tree_graphs = []
    trees_per_trainer = self.params.num_trees / num_trainers
    tree_start = int(trainer_id * trees_per_trainer)
    tree_end = int((trainer_id + 1) * trees_per_trainer)
    for i in range(tree_start, tree_end):
      with ops.device(self.variables.device_dummies[i].device):
        seed = self.params.base_random_seed
        if seed != 0:
          seed += i
        # If using bagging, randomly select some of the input.
        tree_data = processed_dense_features
        tree_labels = labels
        if self.params.bagging_fraction < 1.0:
          # TODO(gilberth): Support bagging for sparse features.
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Bagging not supported with sparse features.')
          # TODO(thomaswc): This does sampling without replacement.  Consider
          # also allowing sampling with replacement as an option.
          batch_size = array_ops.strided_slice(
              array_ops.shape(processed_dense_features), [0], [1])
          r = random_ops.random_uniform(batch_size, seed=seed)
          mask = math_ops.less(
              r, array_ops.ones_like(r) * self.params.bagging_fraction)
          gather_indices = array_ops.squeeze(
              array_ops.where(mask), squeeze_dims=[1])
          # TODO(thomaswc): Calculate out-of-bag data and labels, and store
          # them for use in calculating statistics later.
          tree_data = array_ops.gather(processed_dense_features, gather_indices)
          tree_labels = array_ops.gather(labels, gather_indices)
        if self.params.bagged_features:
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Feature bagging not supported with sparse features.')
          tree_data = self._bag_features(i, tree_data)

        tree_graphs.append(self.trees[i].training_graph(
            tree_data,
            tree_labels,
            seed,
            data_spec=data_spec,
            sparse_features=processed_sparse_features,
            **tree_kwargs))

    return control_flow_ops.group(*tree_graphs, name='train')

  def inference_graph(self, input_data, **inference_args):
    """Constructs a TF graph for evaluating a random forest.

    Args:
      input_data: A tensor or dict of string->Tensor for the input data.
                  This input_data must generate the same spec as the
                  input_data used in training_graph:  the dict must have
                  the same keys, for example, and all tensors must have
                  the same size in their first dimension.
      **inference_args: Keyword arguments to pass through to each tree.

    Returns:
      A tuple of (probabilities, tree_paths, variance).

    Raises:
      NotImplementedError: If trying to use feature bagging with sparse
        features.
    """
    processed_dense_features, processed_sparse_features, data_spec = (
        data_ops.ParseDataTensorOrDict(input_data))

    probabilities = []
    paths = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        tree_data = processed_dense_features
        if self.params.bagged_features:
          if processed_sparse_features is not None:
            raise NotImplementedError(
                'Feature bagging not supported with sparse features.')
          tree_data = self._bag_features(i, tree_data)
        probs, path = self.trees[i].inference_graph(
            tree_data,
            data_spec,
            sparse_features=processed_sparse_features,
            **inference_args)
        probabilities.append(probs)
        paths.append(path)
    with ops.device(self.variables.device_dummies[0].device):
      # shape of all_predict should be [batch_size, num_trees, num_outputs]
      all_predict = array_ops.stack(probabilities, axis=1)
      average_values = math_ops.div(
          math_ops.reduce_sum(all_predict, 1),
          self.params.num_trees,
          name='probabilities')
      tree_paths = array_ops.stack(paths, axis=1)

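      # Per-example variance across trees via Var[X] = E[X^2] - (E[X])^2,
      # clamped at zero to absorb small negative values from floating-point
      # round-off.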
      expected_squares = math_ops.div(
          math_ops.reduce_sum(all_predict * all_predict, 1),
          self.params.num_trees)
      regression_variance = math_ops.maximum(
          0., expected_squares - average_values * average_values)
      return average_values, tree_paths, regression_variance

  def average_size(self):
    """Constructs a TF graph for evaluating the average size of a forest.

    Returns:
      The average number of nodes over the trees.
    """
    sizes = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        sizes.append(self.trees[i].size())
    return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes)))

  # pylint: disable=unused-argument
  def training_loss(self, features, labels, name='training_loss'):
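    # A forest has no natural training loss, so the negated average tree size
    # serves as a monotonically decreasing proxy (trees only grow during
    # training).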
    return math_ops.negative(self.average_size(), name=name)

  # pylint: disable=unused-argument
  def validation_loss(self, features, labels):
    return math_ops.negative(self.average_size())

  def average_impurity(self):
    """Constructs a TF graph for evaluating the leaf impurity of a forest.

    Returns:
      The last op in the graph.
    """
    impurities = []
    for i in range(self.params.num_trees):
      with ops.device(self.variables.device_dummies[i].device):
        impurities.append(self.trees[i].average_impurity())
    return math_ops.reduce_mean(array_ops.stack(impurities))

  def feature_importances(self):
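    """Constructs a TF graph for normalized feature usage counts.

    Returns:
      A tensor of per-feature usage counts, summed over all trees and
      normalized to sum to 1.
    """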
    tree_counts = [self.trees[i].feature_usage_counts()
                   for i in range(self.params.num_trees)]
    total_counts = math_ops.reduce_sum(array_ops.stack(tree_counts, 0), 0)
    return total_counts / math_ops.reduce_sum(total_counts)


class RandomTreeGraphs(object):
  """Builds TF graphs for random tree training and inference."""

  def __init__(self, variables, params, tree_num):
    self.variables = variables
    self.params = params
    self.tree_num = tree_num

  def training_graph(self,
                     input_data,
                     input_labels,
                     random_seed,
                     data_spec,
                     sparse_features=None,
                     input_weights=None):
    """Constructs a TF graph for training a random tree.

    Args:
      input_data: A tensor or placeholder for input data.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      random_seed: The random number generator seed to use for this tree.  0
        means use the current time as the seed.
      data_spec: A data_ops.TensorForestDataSpec object specifying the
        original feature/columns of the data.
      sparse_features: A tf.SparseTensor for sparse input data.
      input_weights: A float tensor or placeholder holding per-input weights,
        or None if all inputs are to be weighted equally.

    Returns:
      The last op in the random tree training graph.
    """
    # TODO(gilberth): Use this.
    unused_epoch = math_ops.to_int32(get_epoch_variable())

    if input_weights is None:
      input_weights = []

    sparse_indices = []
    sparse_values = []
    sparse_shape = []
    if sparse_features is not None:
      sparse_indices = sparse_features.indices
      sparse_values = sparse_features.values
      sparse_shape = sparse_features.dense_shape

    if input_data is None:
      input_data = []

    leaf_ids = model_ops.traverse_tree_v4(
        self.variables.tree,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_spec=data_spec.SerializeToString(),
        params=self.params.serialized_params_proto)

    update_model = model_ops.update_model_v4(
        self.variables.tree,
        leaf_ids,
        input_labels,
        input_weights,
        params=self.params.serialized_params_proto)

    finished_nodes = stats_ops.process_input_v4(
        self.variables.tree,
        self.variables.stats,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_labels,
        input_weights,
        leaf_ids,
        input_spec=data_spec.SerializeToString(),
        random_seed=random_seed,
        params=self.params.serialized_params_proto)

    with ops.control_dependencies([update_model]):
      return stats_ops.grow_tree_v4(
          self.variables.tree,
          self.variables.stats,
          finished_nodes,
          params=self.params.serialized_params_proto)

  def inference_graph(self, input_data, data_spec, sparse_features=None):
    """Constructs a TF graph for evaluating a random tree.

    Args:
      input_data: A tensor or placeholder for input data.
      data_spec: A TensorForestDataSpec proto specifying the original
        input columns.
      sparse_features: A tf.SparseTensor for sparse input data.

    Returns:
      A tuple of (probabilities, tree_paths).
    """
    sparse_indices = []
    sparse_values = []
    sparse_shape = []
    if sparse_features is not None:
      sparse_indices = sparse_features.indices
      sparse_values = sparse_features.values
      sparse_shape = sparse_features.dense_shape
    if input_data is None:
      input_data = []

    return model_ops.tree_predictions_v4(
        self.variables.tree,
        input_data,
        sparse_indices,
        sparse_values,
        sparse_shape,
        input_spec=data_spec.SerializeToString(),
        params=self.params.serialized_params_proto)

  def size(self):
    """Constructs a TF graph for evaluating the current number of nodes.

    Returns:
      The current number of nodes in the tree.
    """
    return model_ops.tree_size(self.variables.tree)

  def feature_usage_counts(self):
    return model_ops.feature_usage_counts(
        self.variables.tree, params=self.params.serialized_params_proto)