diff options
author | 2016-06-30 14:40:02 -0800 | |
---|---|---|
committer | 2016-06-30 15:48:03 -0700 | |
commit | b70103502b41df370906e8988b6593e55caf69cf (patch) | |
tree | 3455ed439430bb6c0e739bb974a52a99a7bc6626 /tensorflow/contrib/tensor_forest/python/tensor_forest.py | |
parent | d3067c338425bdf97fa782d834399b89bce18309 (diff) |
Improvements to tensor_forest, including support for sparse and categorical inputs.
Add tf.learn.Estimator for random forests.
Change: 126352221
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/tensor_forest.py | 510 |
1 files changed, 300 insertions, 210 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index f48efaa5db..791954c51f 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,14 +21,22 @@ from __future__ import print_function import math import random -import tensorflow as tf - +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.contrib.tensor_forest.python.ops import training_ops - -# If tree[i][0] equals this value, then i is a leaf node. -LEAF_NODE = -1 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import tf_logging as logging # A convenience class for holding random forest hyperparameters. @@ -49,6 +58,7 @@ class ForestHParams(object): max_depth=0, num_splits_to_consider=0, feature_bagging_fraction=1.0, max_fertile_nodes=0, split_after_samples=250, + min_split_samples=5, valid_leaf_threshold=1, **kwargs): self.num_trees = num_trees self.max_nodes = max_nodes @@ -58,6 +68,7 @@ class ForestHParams(object): self.num_splits_to_consider = num_splits_to_consider self.max_fertile_nodes = max_fertile_nodes self.split_after_samples = split_after_samples + self.min_split_samples = min_split_samples self.valid_leaf_threshold = valid_leaf_threshold for name, value in kwargs.items(): @@ -72,11 +83,6 @@ class ForestHParams(object): _ = getattr(self, 'num_classes') _ = getattr(self, 'num_features') - self.training_library_base_dir = getattr( - self, 'training_library_base_dir', '') - self.inference_library_base_dir = getattr( - self, 'inference_library_base_dir', '') - self.bagged_num_features = int(self.feature_bagging_fraction * self.num_features) @@ -147,92 +153,86 @@ class TreeTrainingVariables(object): """ def __init__(self, params, tree_num, training): - self.tree = tf.get_variable( - name=self.get_tree_name('tree', tree_num), dtype=tf.int32, - initializer=tf.constant( - [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1))) - self.tree_thresholds = tf.get_variable( + self.tree = variable_scope.get_variable( + name=self.get_tree_name('tree', tree_num), dtype=dtypes.int32, + shape=[params.max_nodes, 2], + initializer=init_ops.constant_initializer(-2)) + self.tree_thresholds = variable_scope.get_variable( name=self.get_tree_name('tree_thresholds', tree_num), shape=[params.max_nodes], - initializer=tf.constant_initializer(-1.0)) - self.tree_depths = tf.get_variable( + initializer=init_ops.constant_initializer(-1.0)) + self.tree_depths = variable_scope.get_variable( name=self.get_tree_name('tree_depths', tree_num), shape=[params.max_nodes], - dtype=tf.int32, - initializer=tf.constant_initializer(1)) - self.end_of_tree = tf.get_variable( + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(1)) + self.end_of_tree = variable_scope.get_variable( name=self.get_tree_name('end_of_tree', tree_num), - dtype=tf.int32, - initializer=tf.constant([1])) + dtype=dtypes.int32, + initializer=constant_op.constant([1])) + self.start_epoch = tf_variables.Variable( + [0] * (params.max_nodes), name='start_epoch') if training: - self.non_fertile_leaves = tf.get_variable( - name=self.get_tree_name('non_fertile_leaves', tree_num), - dtype=tf.int32, - initializer=tf.constant([0])) - self.non_fertile_leaf_scores = tf.get_variable( - name=self.get_tree_name('non_fertile_leaf_scores', tree_num), - initializer=tf.constant([1.0])) - - self.node_to_accumulator_map = tf.get_variable( + self.node_to_accumulator_map = variable_scope.get_variable( name=self.get_tree_name('node_to_accumulator_map', tree_num), shape=[params.max_nodes], - dtype=tf.int32, - initializer=tf.constant_initializer(-1)) + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(-1)) - self.candidate_split_features = tf.get_variable( + self.candidate_split_features = variable_scope.get_variable( name=self.get_tree_name('candidate_split_features', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider], - dtype=tf.int32, - initializer=tf.constant_initializer(-1)) - self.candidate_split_thresholds = tf.get_variable( + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(-1)) + self.candidate_split_thresholds = variable_scope.get_variable( name=self.get_tree_name('candidate_split_thresholds', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) # Statistics shared by classification and regression. - self.node_sums = tf.get_variable( + self.node_sums = variable_scope.get_variable( name=self.get_tree_name('node_sums', tree_num), shape=[params.max_nodes, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) if training: - self.candidate_split_sums = tf.get_variable( + self.candidate_split_sums = variable_scope.get_variable( name=self.get_tree_name('candidate_split_sums', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) - self.accumulator_sums = tf.get_variable( + initializer=init_ops.constant_initializer(0.0)) + self.accumulator_sums = variable_scope.get_variable( name=self.get_tree_name('accumulator_sums', tree_num), shape=[params.max_fertile_nodes, params.num_output_columns], - initializer=tf.constant_initializer(-1.0)) + initializer=init_ops.constant_initializer(-1.0)) # Regression also tracks second order stats. if params.regression: - self.node_squares = tf.get_variable( + self.node_squares = variable_scope.get_variable( name=self.get_tree_name('node_squares', tree_num), shape=[params.max_nodes, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) - self.candidate_split_squares = tf.get_variable( + self.candidate_split_squares = variable_scope.get_variable( name=self.get_tree_name('candidate_split_squares', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) - self.accumulator_squares = tf.get_variable( + self.accumulator_squares = variable_scope.get_variable( name=self.get_tree_name('accumulator_squares', tree_num), shape=[params.max_fertile_nodes, params.num_output_columns], - initializer=tf.constant_initializer(-1.0)) + initializer=init_ops.constant_initializer(-1.0)) else: - self.node_squares = tf.constant( + self.node_squares = constant_op.constant( 0.0, name=self.get_tree_name('node_squares', tree_num)) - self.candidate_split_squares = tf.constant( + self.candidate_split_squares = constant_op.constant( 0.0, name=self.get_tree_name('candidate_split_squares', tree_num)) - self.accumulator_squares = tf.constant( + self.accumulator_squares = constant_op.constant( 0.0, name=self.get_tree_name('accumulator_squares', tree_num)) def get_tree_name(self, name, num): @@ -273,11 +273,11 @@ class ForestTrainingVariables(object): """ def __init__(self, params, device_assigner, training=True, - tree_variable_class=TreeTrainingVariables): + tree_variables_class=TreeTrainingVariables): self.variables = [] for i in range(params.num_trees): - with tf.device(device_assigner.get_device(i)): - self.variables.append(tree_variable_class(params, i, training)) + with ops.device(device_assigner.get_device(i)): + self.variables.append(tree_variables_class(params, i, training)) def __setitem__(self, t, val): self.variables[t] = val @@ -299,7 +299,7 @@ class RandomForestDeviceAssigner(object): def get_device(self, unused_tree_num): if not self.cached: - dummy = tf.constant(0) + dummy = constant_op.constant(0) self.cached = dummy.device return self.cached @@ -308,43 +308,51 @@ class RandomForestDeviceAssigner(object): class RandomForestGraphs(object): """Builds TF graphs for random forest training and inference.""" - def __init__(self, params, device_assigner=None, variables=None, - tree_graphs=None, + def __init__(self, params, device_assigner=None, + variables=None, tree_variables_class=TreeTrainingVariables, + tree_graphs=None, training=True, t_ops=training_ops, i_ops=inference_ops): self.params = params self.device_assigner = device_assigner or RandomForestDeviceAssigner() - tf.logging.info('Constructing forest with params = ') - tf.logging.info(self.params.__dict__) + logging.info('Constructing forest with params = ') + logging.info(self.params.__dict__) self.variables = variables or ForestTrainingVariables( - self.params, device_assigner=self.device_assigner) + self.params, device_assigner=self.device_assigner, training=training, + tree_variables_class=tree_variables_class) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class( self.variables[i], self.params, - t_ops.Load(self.params.training_library_base_dir), - i_ops.Load(self.params.inference_library_base_dir), i) + t_ops.Load(), i_ops.Load(), i) for i in range(self.params.num_trees)] def _bag_features(self, tree_num, input_data): - split_data = tf.split(1, self.params.num_features, input_data) - return tf.concat(1, [split_data[ind] - for ind in self.params.bagged_features[tree_num]]) + split_data = array_ops.split(1, self.params.num_features, input_data) + return array_ops.concat( + 1, [split_data[ind] for ind in self.params.bagged_features[tree_num]]) - def training_graph(self, input_data, input_labels): + def training_graph(self, input_data, input_labels, data_spec=None, + epoch=None, **tree_kwargs): """Constructs a TF graph for training a random forest. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. input_labels: A tensor or placeholder for labels associated with input_data. + data_spec: A list of tf.dtype values specifying the original types of + each column. + epoch: A tensor or placeholder for the epoch the training data comes from. + **tree_kwargs: Keyword arguments passed to each tree's training_graph. Returns: The last op in the random forest training graph. """ + data_spec = ([constants.DATA_FLOAT] * self.params.num_features + if data_spec is None else data_spec) tree_graphs = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): seed = self.params.base_random_seed if seed != 0: seed += i @@ -354,40 +362,54 @@ class RandomForestGraphs(object): if self.params.bagging_fraction < 1.0: # TODO(thomaswc): This does sampling without replacment. Consider # also allowing sampling with replacement as an option. - batch_size = tf.slice(tf.shape(input_data), [0], [1]) - r = tf.random_uniform(batch_size, seed=seed) - mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction) - gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1]) + batch_size = array_ops.slice(array_ops.shape(input_data), [0], [1]) + r = random_ops.random_uniform(batch_size, seed=seed) + mask = math_ops.less( + r, array_ops.ones_like(r) * self.params.bagging_fraction) + gather_indices = array_ops.squeeze( + array_ops.where(mask), squeeze_dims=[1]) # TODO(thomaswc): Calculate out-of-bag data and labels, and store # them for use in calculating statistics later. - tree_data = tf.gather(input_data, gather_indices) - tree_labels = tf.gather(input_labels, gather_indices) + tree_data = array_ops.gather(input_data, gather_indices) + tree_labels = array_ops.gather(input_labels, gather_indices) if self.params.bagged_features: tree_data = self._bag_features(i, tree_data) - tree_graphs.append( - self.trees[i].training_graph(tree_data, tree_labels, seed)) - return tf.group(*tree_graphs) + initialization = self.trees[i].tree_initialization() + + with ops.control_dependencies([initialization]): + tree_graphs.append( + self.trees[i].training_graph( + tree_data, tree_labels, seed, data_spec=data_spec, + epoch=([0] if epoch is None else epoch), + **tree_kwargs)) - def inference_graph(self, input_data): + return control_flow_ops.group(*tree_graphs) + + def inference_graph(self, input_data, data_spec=None): """Constructs a TF graph for evaluating a random forest. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. + data_spec: A list of tf.dtype values specifying the original types of + each column. Returns: The last op in the random forest inference graph. """ + data_spec = ([constants.DATA_FLOAT] * self.params.num_features + if data_spec is None else data_spec) probabilities = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): tree_data = input_data if self.params.bagged_features: tree_data = self._bag_features(i, input_data) - probabilities.append(self.trees[i].inference_graph(tree_data)) - with tf.device(self.device_assigner.get_device(0)): - all_predict = tf.pack(probabilities) - return tf.reduce_sum(all_predict, 0) / self.params.num_trees + probabilities.append(self.trees[i].inference_graph(tree_data, + data_spec)) + with ops.device(self.device_assigner.get_device(0)): + all_predict = array_ops.pack(probabilities) + return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees def average_size(self): """Constructs a TF graph for evaluating the average size of a forest. @@ -397,9 +419,16 @@ class RandomForestGraphs(object): """ sizes = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): sizes.append(self.trees[i].size()) - return tf.reduce_mean(tf.pack(sizes)) + return math_ops.reduce_mean(array_ops.pack(sizes)) + + def training_loss(self): + return math_ops.neg(self.average_size()) + + # pylint: disable=unused-argument + def validation_loss(self, features, labels): + return math_ops.neg(self.average_size()) def average_impurity(self): """Constructs a TF graph for evaluating the leaf impurity of a forest. @@ -409,14 +438,14 @@ class RandomForestGraphs(object): """ impurities = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): impurities.append(self.trees[i].average_impurity()) - return tf.reduce_mean(tf.pack(impurities)) + return math_ops.reduce_mean(array_ops.pack(impurities)) def get_stats(self, session): tree_stats = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): tree_stats.append(self.trees[i].get_stats(session)) return ForestStats(tree_stats, self.params) @@ -431,6 +460,18 @@ class RandomTreeGraphs(object): self.params = params self.tree_num = tree_num + def tree_initialization(self): + def _init_tree(): + return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op + + def _nothing(): + return control_flow_ops.no_op() + + return control_flow_ops.cond( + math_ops.equal(array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [1, 1])), -2), + _init_tree, _nothing) + def _gini(self, class_counts): """Calculate the Gini impurity. @@ -444,9 +485,9 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the Gini impurities for each row in the input. """ - smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1]) - sums = tf.reduce_sum(smoothed, 1) - sum_squares = tf.reduce_sum(tf.square(smoothed), 1) + smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1]) + sums = math_ops.reduce_sum(smoothed, 1) + sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1) return 1.0 - sum_squares / (sums * sums) @@ -463,9 +504,9 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the Gini impurities for each row in the input. """ - smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1]) - sums = tf.reduce_sum(smoothed, 1) - sum_squares = tf.reduce_sum(tf.square(smoothed), 1) + smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1]) + sums = math_ops.reduce_sum(smoothed, 1) + sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1) return sums - sum_squares / sums @@ -483,40 +524,58 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the variances for each row in the input. """ - total_count = tf.slice(sums, [0, 0], [-1, 1]) + total_count = array_ops.slice(sums, [0, 0], [-1, 1]) e_x = sums / total_count e_x2 = squares / total_count - return tf.reduce_sum(e_x2 - tf.square(e_x), 1) + return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1) + + def training_graph(self, input_data, input_labels, random_seed, + data_spec, epoch=None): - def training_graph(self, input_data, input_labels, random_seed): """Constructs a TF graph for training a random tree. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. input_labels: A tensor or placeholder for labels associated with input_data. random_seed: The random number generator seed to use for this tree. 0 means use the current time as the seed. + data_spec: A list of tf.dtype values specifying the original types of + each column. + epoch: A tensor or placeholder for the epoch the training data comes from. Returns: The last op in the random tree training graph. """ + epoch = [0] if epoch is None else epoch + + sparse_indices = [] + sparse_values = [] + sparse_shape = [] + if isinstance(input_data, ops.SparseTensor): + sparse_indices = input_data.indices + sparse_values = input_data.values + sparse_shape = input_data.shape + input_data = [] + # Count extremely random stats. (node_sums, node_squares, splits_indices, splits_sums, splits_squares, totals_indices, totals_sums, totals_squares, input_leaves) = ( self.training_ops.count_extremely_random_stats( - input_data, input_labels, self.variables.tree, + input_data, sparse_indices, sparse_values, sparse_shape, + data_spec, input_labels, self.variables.tree, self.variables.tree_thresholds, self.variables.node_to_accumulator_map, self.variables.candidate_split_features, self.variables.candidate_split_thresholds, + self.variables.start_epoch, epoch, num_classes=self.params.num_output_columns, regression=self.params.regression)) node_update_ops = [] node_update_ops.append( - tf.assign_add(self.variables.node_sums, node_sums)) + state_ops.assign_add(self.variables.node_sums, node_sums)) splits_update_ops = [] splits_update_ops.append(self.training_ops.scatter_add_ndim( @@ -527,8 +586,8 @@ class RandomTreeGraphs(object): totals_sums)) if self.params.regression: - node_update_ops.append(tf.assign_add(self.variables.node_squares, - node_squares)) + node_update_ops.append(state_ops.assign_add(self.variables.node_squares, + node_squares)) splits_update_ops.append(self.training_ops.scatter_add_ndim( self.variables.candidate_split_squares, splits_indices, splits_squares)) @@ -539,63 +598,56 @@ class RandomTreeGraphs(object): # Sample inputs. update_indices, feature_updates, threshold_updates = ( self.training_ops.sample_inputs( - input_data, self.variables.node_to_accumulator_map, + input_data, sparse_indices, sparse_values, sparse_shape, + self.variables.node_to_accumulator_map, input_leaves, self.variables.candidate_split_features, self.variables.candidate_split_thresholds, split_initializations_per_input=( self.params.split_initializations_per_input), split_sampling_random_seed=random_seed)) - update_features_op = tf.scatter_update( + update_features_op = state_ops.scatter_update( self.variables.candidate_split_features, update_indices, feature_updates) - update_thresholds_op = tf.scatter_update( + update_thresholds_op = state_ops.scatter_update( self.variables.candidate_split_thresholds, update_indices, threshold_updates) # Calculate finished nodes. - with tf.control_dependencies(splits_update_ops): - children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]), - squeeze_dims=[1]) - is_leaf = tf.equal(LEAF_NODE, children) - leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1])) - finished = self.training_ops.finished_nodes( + with ops.control_dependencies(splits_update_ops): + children = array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1]) + is_leaf = math_ops.equal(constants.LEAF_NODE, children) + leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf), + squeeze_dims=[1])) + finished, stale = self.training_ops.finished_nodes( leaves, self.variables.node_to_accumulator_map, + self.variables.candidate_split_sums, + self.variables.candidate_split_squares, self.variables.accumulator_sums, - num_split_after_samples=self.params.split_after_samples) + self.variables.accumulator_squares, + self.variables.start_epoch, epoch, + num_split_after_samples=self.params.split_after_samples, + min_split_samples=self.params.min_split_samples) # Update leaf scores. - # TODO(gilberth): Optimize this. It currently calculates counts for - # every non-fertile leaf. - with tf.control_dependencies(node_update_ops): - def dont_update_leaf_scores(): - return self.variables.non_fertile_leaf_scores - - def update_leaf_scores_regression(): - sums = tf.gather(self.variables.node_sums, - self.variables.non_fertile_leaves) - squares = tf.gather(self.variables.node_squares, - self.variables.non_fertile_leaves) - new_scores = self._variance(sums, squares) - return tf.assign(self.variables.non_fertile_leaf_scores, new_scores) - - def update_leaf_scores_classification(): - counts = tf.gather(self.variables.node_sums, - self.variables.non_fertile_leaves) - new_scores = self._weighted_gini(counts) - return tf.assign(self.variables.non_fertile_leaf_scores, new_scores) - - # Because we can't have tf.self.variables of size 0, we have to put in a - # garbage value of -1 in there. Here we check for that so we don't - # try to index into node_per_class_weights in a tf.gather with a negative - # number. - update_nonfertile_leaves_scores_op = tf.cond( - tf.less(self.variables.non_fertile_leaves[0], 0), - dont_update_leaf_scores, - update_leaf_scores_regression if self.params.regression else - update_leaf_scores_classification) + non_fertile_leaves = array_ops.boolean_mask( + leaves, math_ops.less(array_ops.gather( + self.variables.node_to_accumulator_map, leaves), 0)) + + # TODO(gilberth): It should be possible to limit the number of non + # fertile leaves we calculate scores for, especially since we can only take + # at most array_ops.shape(finished)[0] of them. + with ops.control_dependencies(node_update_ops): + sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves) + if self.params.regression: + squares = array_ops.gather(self.variables.node_squares, + non_fertile_leaves) + non_fertile_leaf_scores = self._variance(sums, squares) + else: + non_fertile_leaf_scores = self._weighted_gini(sums) # Calculate best splits. - with tf.control_dependencies(splits_update_ops): + with ops.control_dependencies(splits_update_ops): split_indices = self.training_ops.best_splits( finished, self.variables.node_to_accumulator_map, self.variables.candidate_split_sums, @@ -605,7 +657,7 @@ class RandomTreeGraphs(object): regression=self.params.regression) # Grow tree. - with tf.control_dependencies([update_features_op, update_thresholds_op]): + with ops.control_dependencies([update_features_op, update_thresholds_op]): (tree_update_indices, tree_children_updates, tree_threshold_updates, tree_depth_updates, new_eot) = ( self.training_ops.grow_tree( @@ -613,110 +665,138 @@ class RandomTreeGraphs(object): self.variables.node_to_accumulator_map, finished, split_indices, self.variables.candidate_split_features, self.variables.candidate_split_thresholds)) - tree_update_op = tf.scatter_update( + tree_update_op = state_ops.scatter_update( self.variables.tree, tree_update_indices, tree_children_updates) - threhsolds_update_op = tf.scatter_update( + thresholds_update_op = state_ops.scatter_update( self.variables.tree_thresholds, tree_update_indices, tree_threshold_updates) - depth_update_op = tf.scatter_update( + depth_update_op = state_ops.scatter_update( self.variables.tree_depths, tree_update_indices, tree_depth_updates) + # TODO(thomaswc): Only update the epoch on the new leaves. + new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates) + epoch_update_op = state_ops.scatter_update( + self.variables.start_epoch, tree_update_indices, + new_epoch_updates) # Update fertile slots. - with tf.control_dependencies([update_nonfertile_leaves_scores_op, - depth_update_op]): - (node_map_updates, accumulators_cleared, accumulators_allocated, - new_nonfertile_leaves, new_nonfertile_leaves_scores) = ( - self.training_ops.update_fertile_slots( - finished, self.variables.non_fertile_leaves, - self.variables.non_fertile_leaf_scores, - self.variables.end_of_tree, self.variables.tree_depths, - self.variables.accumulator_sums, - self.variables.node_to_accumulator_map, - max_depth=self.params.max_depth, - regression=self.params.regression)) + with ops.control_dependencies([depth_update_op]): + (node_map_updates, accumulators_cleared, accumulators_allocated) = ( + self.training_ops.update_fertile_slots( + finished, non_fertile_leaves, + non_fertile_leaf_scores, + self.variables.end_of_tree, self.variables.tree_depths, + self.variables.accumulator_sums, + self.variables.node_to_accumulator_map, + stale, + max_depth=self.params.max_depth, + regression=self.params.regression)) # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has # used it to calculate new leaves. - gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves]) - eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot) + gated_new_eot, = control_flow_ops.tuple([new_eot], + control_inputs=[node_map_updates]) + eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot) updates = [] updates.append(eot_update_op) updates.append(tree_update_op) - updates.append(threhsolds_update_op) - updates.append(tf.assign( - self.variables.non_fertile_leaves, new_nonfertile_leaves, - validate_shape=False)) - updates.append(tf.assign( - self.variables.non_fertile_leaf_scores, - new_nonfertile_leaves_scores, validate_shape=False)) - - updates.append(tf.scatter_update( + updates.append(thresholds_update_op) + updates.append(epoch_update_op) + + updates.append(state_ops.scatter_update( self.variables.node_to_accumulator_map, - tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]), - squeeze_dims=[0]), - tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]), - squeeze_dims=[0]))) + array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]), + squeeze_dims=[0]), + array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]), + squeeze_dims=[0]))) - cleared_and_allocated_accumulators = tf.concat( + cleared_and_allocated_accumulators = array_ops.concat( 0, [accumulators_cleared, accumulators_allocated]) # Calculate values to put into scatter update for candidate counts. # Candidate split counts are always reset back to 0 for both cleared # and allocated accumulators. This means some accumulators might be doubly # reset to 0 if the were released and not allocated, then later allocated. - split_values = tf.tile( - tf.expand_dims(tf.expand_dims( - tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32), - 1), 2), + split_values = array_ops.tile( + array_ops.expand_dims(array_ops.expand_dims( + array_ops.zeros_like(cleared_and_allocated_accumulators, + dtype=dtypes.float32), 1), 2), [1, self.params.num_splits_to_consider, self.params.num_output_columns]) - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.candidate_split_sums, cleared_and_allocated_accumulators, split_values)) if self.params.regression: - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.candidate_split_squares, cleared_and_allocated_accumulators, split_values)) # Calculate values to put into scatter update for total counts. - total_cleared = tf.tile( - tf.expand_dims( - tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1), + total_cleared = array_ops.tile( + array_ops.expand_dims( + math_ops.neg(array_ops.ones_like(accumulators_cleared, + dtype=dtypes.float32)), 1), [1, self.params.num_output_columns]) - total_reset = tf.tile( - tf.expand_dims( - tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1), + total_reset = array_ops.tile( + array_ops.expand_dims( + array_ops.zeros_like(accumulators_allocated, + dtype=dtypes.float32), 1), [1, self.params.num_output_columns]) - accumulator_updates = tf.concat(0, [total_cleared, total_reset]) - updates.append(tf.scatter_update( + accumulator_updates = array_ops.concat(0, [total_cleared, total_reset]) + updates.append(state_ops.scatter_update( self.variables.accumulator_sums, cleared_and_allocated_accumulators, accumulator_updates)) if self.params.regression: - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.accumulator_squares, cleared_and_allocated_accumulators, accumulator_updates)) # Calculate values to put into scatter update for candidate splits. - split_features_updates = tf.tile( - tf.expand_dims( - tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1), + split_features_updates = array_ops.tile( + array_ops.expand_dims( + math_ops.neg(array_ops.ones_like( + cleared_and_allocated_accumulators)), 1), [1, self.params.num_splits_to_consider]) - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.candidate_split_features, cleared_and_allocated_accumulators, split_features_updates)) - return tf.group(*updates) + updates += self.finish_iteration() + + return control_flow_ops.group(*updates) + + def finish_iteration(self): + """Perform any operations that should be done at the end of an iteration. + + This is mostly useful for subclasses that need to reset variables after + an iteration, such as ones that are used to finish nodes. + + Returns: + A list of operations. + """ + return [] - def inference_graph(self, input_data): + def inference_graph(self, input_data, data_spec): """Constructs a TF graph for evaluating a random tree. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. + data_spec: A list of tf.dtype values specifying the original types of + each column. Returns: The last op in the random tree inference graph. """ + sparse_indices = [] + sparse_values = [] + sparse_shape = [] + if isinstance(input_data, ops.SparseTensor): + sparse_indices = input_data.indices + sparse_values = input_data.values + sparse_shape = input_data.shape + input_data = [] return self.inference_ops.tree_predictions( - input_data, self.variables.tree, self.variables.tree_thresholds, + input_data, sparse_indices, sparse_values, sparse_shape, data_spec, + self.variables.tree, + self.variables.tree_thresholds, self.variables.node_sums, valid_leaf_threshold=self.params.valid_leaf_threshold) @@ -729,13 +809,22 @@ class RandomTreeGraphs(object): Returns: The last op in the graph. """ - children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]), - squeeze_dims=[1]) - is_leaf = tf.equal(LEAF_NODE, children) - leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1])) - counts = tf.gather(self.variables.node_sums, leaves) - impurity = self._weighted_gini(counts) - return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) + children = array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1]) + is_leaf = math_ops.equal(constants.LEAF_NODE, children) + leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf), + squeeze_dims=[1])) + counts = array_ops.gather(self.variables.node_sums, leaves) + gini = self._weighted_gini(counts) + # Guard against step 1, when there often are no leaves yet. + def impurity(): + return gini + # Since average impurity can be used for loss, when there's no data just + # return a big number so that loss always decreases. + def big(): + return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000. + return control_flow_ops.cond(math_ops.greater( + array_ops.shape(leaves)[0], 0), impurity, big) def size(self): """Constructs a TF graph for evaluating the current number of nodes. @@ -747,7 +836,8 @@ class RandomTreeGraphs(object): def get_stats(self, session): num_nodes = self.variables.end_of_tree.eval(session=session) - 1 - num_leaves = tf.where( - tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])), - LEAF_NODE)).eval(session=session).shape[0] + num_leaves = array_ops.where( + math_ops.equal(array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1])), constants.LEAF_NODE) + ).eval(session=session).shape[0] return TreeStats(num_nodes, num_leaves) |