author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-06-30 14:40:02 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-06-30 15:48:03 -0700
commit    b70103502b41df370906e8988b6593e55caf69cf (patch)
tree      3455ed439430bb6c0e739bb974a52a99a7bc6626 /tensorflow/contrib/tensor_forest/python/tensor_forest.py
parent    d3067c338425bdf97fa782d834399b89bce18309 (diff)
Improvements to tensor_forest, including support for sparse and categorical inputs.
Add tf.learn.Estimator for random forests. Change: 126352221
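
A minimal usage sketch of the revised graph-level API in this file
(hyperparameter values are illustrative; the fill() helper is assumed from
elsewhere in this file):

    import tensorflow as tf
    from tensorflow.contrib.tensor_forest.python import tensor_forest

    # Illustrative hyperparameters; fill() derives the remaining fields
    # (num_output_columns, bagged feature lists, etc.).
    params = tensor_forest.ForestHParams(
        num_classes=2, num_features=4, num_trees=10, max_nodes=1000,
        min_split_samples=5).fill()
    graphs = tensor_forest.RandomForestGraphs(params)

    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.int32, [None])
    train_op = graphs.training_graph(x, y)   # data_spec/epoch are optional
    probabilities = graphs.inference_graph(x)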
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/tensor_forest.py')
-rw-r--r--  tensorflow/contrib/tensor_forest/python/tensor_forest.py | 510
1 file changed, 300 insertions(+), 210 deletions(-)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index f48efaa5db..791954c51f 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,14 +21,22 @@ from __future__ import print_function
import math
import random
-import tensorflow as tf
-
+from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
-
-# If tree[i][0] equals this value, then i is a leaf node.
-LEAF_NODE = -1
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import tf_logging as logging
# A convenience class for holding random forest hyperparameters.
@@ -49,6 +58,7 @@ class ForestHParams(object):
max_depth=0, num_splits_to_consider=0,
feature_bagging_fraction=1.0,
max_fertile_nodes=0, split_after_samples=250,
+ min_split_samples=5,
valid_leaf_threshold=1, **kwargs):
self.num_trees = num_trees
self.max_nodes = max_nodes
@@ -58,6 +68,7 @@ class ForestHParams(object):
self.num_splits_to_consider = num_splits_to_consider
self.max_fertile_nodes = max_fertile_nodes
self.split_after_samples = split_after_samples
+ self.min_split_samples = min_split_samples
self.valid_leaf_threshold = valid_leaf_threshold
for name, value in kwargs.items():
@@ -72,11 +83,6 @@ class ForestHParams(object):
_ = getattr(self, 'num_classes')
_ = getattr(self, 'num_features')
- self.training_library_base_dir = getattr(
- self, 'training_library_base_dir', '')
- self.inference_library_base_dir = getattr(
- self, 'inference_library_base_dir', '')
-
self.bagged_num_features = int(self.feature_bagging_fraction *
self.num_features)
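
The two context lines above derive the per-tree feature count by truncation;
for example:

    # A 0.75 bagging fraction over 10 features bags 7 columns per tree;
    # int() truncates rather than rounds.
    bagged_num_features = int(0.75 * 10)  # == 7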
@@ -147,92 +153,86 @@ class TreeTrainingVariables(object):
"""
def __init__(self, params, tree_num, training):
- self.tree = tf.get_variable(
- name=self.get_tree_name('tree', tree_num), dtype=tf.int32,
- initializer=tf.constant(
- [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1)))
- self.tree_thresholds = tf.get_variable(
+ self.tree = variable_scope.get_variable(
+ name=self.get_tree_name('tree', tree_num), dtype=dtypes.int32,
+ shape=[params.max_nodes, 2],
+ initializer=init_ops.constant_initializer(-2))
+ self.tree_thresholds = variable_scope.get_variable(
name=self.get_tree_name('tree_thresholds', tree_num),
shape=[params.max_nodes],
- initializer=tf.constant_initializer(-1.0))
- self.tree_depths = tf.get_variable(
+ initializer=init_ops.constant_initializer(-1.0))
+ self.tree_depths = variable_scope.get_variable(
name=self.get_tree_name('tree_depths', tree_num),
shape=[params.max_nodes],
- dtype=tf.int32,
- initializer=tf.constant_initializer(1))
- self.end_of_tree = tf.get_variable(
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(1))
+ self.end_of_tree = variable_scope.get_variable(
name=self.get_tree_name('end_of_tree', tree_num),
- dtype=tf.int32,
- initializer=tf.constant([1]))
+ dtype=dtypes.int32,
+ initializer=constant_op.constant([1]))
+ self.start_epoch = tf_variables.Variable(
+ [0] * (params.max_nodes), name='start_epoch')
if training:
- self.non_fertile_leaves = tf.get_variable(
- name=self.get_tree_name('non_fertile_leaves', tree_num),
- dtype=tf.int32,
- initializer=tf.constant([0]))
- self.non_fertile_leaf_scores = tf.get_variable(
- name=self.get_tree_name('non_fertile_leaf_scores', tree_num),
- initializer=tf.constant([1.0]))
-
- self.node_to_accumulator_map = tf.get_variable(
+ self.node_to_accumulator_map = variable_scope.get_variable(
name=self.get_tree_name('node_to_accumulator_map', tree_num),
shape=[params.max_nodes],
- dtype=tf.int32,
- initializer=tf.constant_initializer(-1))
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(-1))
- self.candidate_split_features = tf.get_variable(
+ self.candidate_split_features = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_features', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider],
- dtype=tf.int32,
- initializer=tf.constant_initializer(-1))
- self.candidate_split_thresholds = tf.get_variable(
+ dtype=dtypes.int32,
+ initializer=init_ops.constant_initializer(-1))
+ self.candidate_split_thresholds = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_thresholds', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
# Statistics shared by classification and regression.
- self.node_sums = tf.get_variable(
+ self.node_sums = variable_scope.get_variable(
name=self.get_tree_name('node_sums', tree_num),
shape=[params.max_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
if training:
- self.candidate_split_sums = tf.get_variable(
+ self.candidate_split_sums = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_sums', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider,
params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
- self.accumulator_sums = tf.get_variable(
+ initializer=init_ops.constant_initializer(0.0))
+ self.accumulator_sums = variable_scope.get_variable(
name=self.get_tree_name('accumulator_sums', tree_num),
shape=[params.max_fertile_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(-1.0))
+ initializer=init_ops.constant_initializer(-1.0))
# Regression also tracks second order stats.
if params.regression:
- self.node_squares = tf.get_variable(
+ self.node_squares = variable_scope.get_variable(
name=self.get_tree_name('node_squares', tree_num),
shape=[params.max_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
- self.candidate_split_squares = tf.get_variable(
+ self.candidate_split_squares = variable_scope.get_variable(
name=self.get_tree_name('candidate_split_squares', tree_num),
shape=[params.max_fertile_nodes, params.num_splits_to_consider,
params.num_output_columns],
- initializer=tf.constant_initializer(0.0))
+ initializer=init_ops.constant_initializer(0.0))
- self.accumulator_squares = tf.get_variable(
+ self.accumulator_squares = variable_scope.get_variable(
name=self.get_tree_name('accumulator_squares', tree_num),
shape=[params.max_fertile_nodes, params.num_output_columns],
- initializer=tf.constant_initializer(-1.0))
+ initializer=init_ops.constant_initializer(-1.0))
else:
- self.node_squares = tf.constant(
+ self.node_squares = constant_op.constant(
0.0, name=self.get_tree_name('node_squares', tree_num))
- self.candidate_split_squares = tf.constant(
+ self.candidate_split_squares = constant_op.constant(
0.0, name=self.get_tree_name('candidate_split_squares', tree_num))
- self.accumulator_squares = tf.constant(
+ self.accumulator_squares = constant_op.constant(
0.0, name=self.get_tree_name('accumulator_squares', tree_num))
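
For orientation, a NumPy sketch of the node table these initializers set up.
The blanket -2 fill plus the lazy tree_initialization() added further down
replaces the old eager [[-1, -1]] + [[-2, -1]] * (max_nodes - 1) initializer;
reading -2 as a free slot is inferred from that old initializer, and the
meaning of the second column is an assumption here:

    import numpy as np

    LEAF_NODE = -1  # constants.LEAF_NODE: tree[i][0] == -1 marks node i a leaf
    FREE_NODE = -2  # assumed name for an unallocated slot

    max_nodes = 5
    tree = np.full((max_nodes, 2), FREE_NODE, dtype=np.int32)  # new initializer
    tree[0] = [LEAF_NODE, -1]  # what tree_initialization() scatters into row 0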
def get_tree_name(self, name, num):
@@ -273,11 +273,11 @@ class ForestTrainingVariables(object):
"""
def __init__(self, params, device_assigner, training=True,
- tree_variable_class=TreeTrainingVariables):
+ tree_variables_class=TreeTrainingVariables):
self.variables = []
for i in range(params.num_trees):
- with tf.device(device_assigner.get_device(i)):
- self.variables.append(tree_variable_class(params, i, training))
+ with ops.device(device_assigner.get_device(i)):
+ self.variables.append(tree_variables_class(params, i, training))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -299,7 +299,7 @@ class RandomForestDeviceAssigner(object):
def get_device(self, unused_tree_num):
if not self.cached:
- dummy = tf.constant(0)
+ dummy = constant_op.constant(0)
self.cached = dummy.device
return self.cached
@@ -308,43 +308,51 @@ class RandomForestDeviceAssigner(object):
class RandomForestGraphs(object):
"""Builds TF graphs for random forest training and inference."""
- def __init__(self, params, device_assigner=None, variables=None,
- tree_graphs=None,
+ def __init__(self, params, device_assigner=None,
+ variables=None, tree_variables_class=TreeTrainingVariables,
+ tree_graphs=None, training=True,
t_ops=training_ops,
i_ops=inference_ops):
self.params = params
self.device_assigner = device_assigner or RandomForestDeviceAssigner()
- tf.logging.info('Constructing forest with params = ')
- tf.logging.info(self.params.__dict__)
+ logging.info('Constructing forest with params = ')
+ logging.info(self.params.__dict__)
self.variables = variables or ForestTrainingVariables(
- self.params, device_assigner=self.device_assigner)
+ self.params, device_assigner=self.device_assigner, training=training,
+ tree_variables_class=tree_variables_class)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(
self.variables[i], self.params,
- t_ops.Load(self.params.training_library_base_dir),
- i_ops.Load(self.params.inference_library_base_dir), i)
+ t_ops.Load(), i_ops.Load(), i)
for i in range(self.params.num_trees)]
def _bag_features(self, tree_num, input_data):
- split_data = tf.split(1, self.params.num_features, input_data)
- return tf.concat(1, [split_data[ind]
- for ind in self.params.bagged_features[tree_num]])
+ split_data = array_ops.split(1, self.params.num_features, input_data)
+ return array_ops.concat(
+ 1, [split_data[ind] for ind in self.params.bagged_features[tree_num]])
- def training_graph(self, input_data, input_labels):
+ def training_graph(self, input_data, input_labels, data_spec=None,
+ epoch=None, **tree_kwargs):
"""Constructs a TF graph for training a random forest.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
input_labels: A tensor or placeholder for labels associated with
input_data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
+ epoch: A tensor or placeholder for the epoch the training data comes from.
+ **tree_kwargs: Keyword arguments passed to each tree's training_graph.
Returns:
The last op in the random forest training graph.
"""
+ data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+ if data_spec is None else data_spec)
tree_graphs = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
seed = self.params.base_random_seed
if seed != 0:
seed += i
@@ -354,40 +362,54 @@ class RandomForestGraphs(object):
if self.params.bagging_fraction < 1.0:
# TODO(thomaswc): This does sampling without replacement. Consider
# also allowing sampling with replacement as an option.
- batch_size = tf.slice(tf.shape(input_data), [0], [1])
- r = tf.random_uniform(batch_size, seed=seed)
- mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
- gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
+ batch_size = array_ops.slice(array_ops.shape(input_data), [0], [1])
+ r = random_ops.random_uniform(batch_size, seed=seed)
+ mask = math_ops.less(
+ r, array_ops.ones_like(r) * self.params.bagging_fraction)
+ gather_indices = array_ops.squeeze(
+ array_ops.where(mask), squeeze_dims=[1])
# TODO(thomaswc): Calculate out-of-bag data and labels, and store
# them for use in calculating statistics later.
- tree_data = tf.gather(input_data, gather_indices)
- tree_labels = tf.gather(input_labels, gather_indices)
+ tree_data = array_ops.gather(input_data, gather_indices)
+ tree_labels = array_ops.gather(input_labels, gather_indices)
if self.params.bagged_features:
tree_data = self._bag_features(i, tree_data)
- tree_graphs.append(
- self.trees[i].training_graph(tree_data, tree_labels, seed))
- return tf.group(*tree_graphs)
+ initialization = self.trees[i].tree_initialization()
+
+ with ops.control_dependencies([initialization]):
+ tree_graphs.append(
+ self.trees[i].training_graph(
+ tree_data, tree_labels, seed, data_spec=data_spec,
+ epoch=([0] if epoch is None else epoch),
+ **tree_kwargs))
+ return control_flow_ops.group(*tree_graphs)
+
- def inference_graph(self, input_data):
+ def inference_graph(self, input_data, data_spec=None):
"""Constructs a TF graph for evaluating a random forest.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
Returns:
The last op in the random forest inference graph.
"""
+ data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+ if data_spec is None else data_spec)
probabilities = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
tree_data = input_data
if self.params.bagged_features:
tree_data = self._bag_features(i, input_data)
- probabilities.append(self.trees[i].inference_graph(tree_data))
- with tf.device(self.device_assigner.get_device(0)):
- all_predict = tf.pack(probabilities)
- return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+ probabilities.append(self.trees[i].inference_graph(tree_data,
+ data_spec))
+ with ops.device(self.device_assigner.get_device(0)):
+ all_predict = array_ops.pack(probabilities)
+ return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees
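
The forest prediction above is a plain average of per-tree class
probabilities; the same reduction in NumPy, with made-up numbers:

    import numpy as np

    # Two trees, three examples, two classes.
    tree_probs = [np.array([[0.9, 0.1], [0.4, 0.6], [0.5, 0.5]]),
                  np.array([[0.7, 0.3], [0.2, 0.8], [0.5, 0.5]])]
    all_predict = np.stack(tree_probs)             # array_ops.pack equivalent
    forest = all_predict.sum(axis=0) / len(tree_probs)
    # -> [[0.8, 0.2], [0.3, 0.7], [0.5, 0.5]]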
def average_size(self):
"""Constructs a TF graph for evaluating the average size of a forest.
@@ -397,9 +419,16 @@ class RandomForestGraphs(object):
"""
sizes = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
sizes.append(self.trees[i].size())
- return tf.reduce_mean(tf.pack(sizes))
+ return math_ops.reduce_mean(array_ops.pack(sizes))
+
+ def training_loss(self):
+ return math_ops.neg(self.average_size())
+
+ # pylint: disable=unused-argument
+ def validation_loss(self, features, labels):
+ return math_ops.neg(self.average_size())
def average_impurity(self):
"""Constructs a TF graph for evaluating the leaf impurity of a forest.
@@ -409,14 +438,14 @@ class RandomForestGraphs(object):
"""
impurities = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
impurities.append(self.trees[i].average_impurity())
- return tf.reduce_mean(tf.pack(impurities))
+ return math_ops.reduce_mean(array_ops.pack(impurities))
def get_stats(self, session):
tree_stats = []
for i in range(self.params.num_trees):
- with tf.device(self.device_assigner.get_device(i)):
+ with ops.device(self.device_assigner.get_device(i)):
tree_stats.append(self.trees[i].get_stats(session))
return ForestStats(tree_stats, self.params)
@@ -431,6 +460,18 @@ class RandomTreeGraphs(object):
self.params = params
self.tree_num = tree_num
+ def tree_initialization(self):
+ def _init_tree():
+ return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op
+
+ def _nothing():
+ return control_flow_ops.no_op()
+
+ return control_flow_ops.cond(
+ math_ops.equal(array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [1, 1])), -2),
+ _init_tree, _nothing)
+
def _gini(self, class_counts):
"""Calculate the Gini impurity.
@@ -444,9 +485,9 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the Gini impurities for each row in the input.
"""
- smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
- sums = tf.reduce_sum(smoothed, 1)
- sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+ smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+ sums = math_ops.reduce_sum(smoothed, 1)
+ sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
return 1.0 - sum_squares / (sums * sums)
@@ -463,9 +504,9 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the Gini impurities for each row in the input.
"""
- smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
- sums = tf.reduce_sum(smoothed, 1)
- sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+ smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+ sums = math_ops.reduce_sum(smoothed, 1)
+ sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
return sums - sum_squares / sums
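
Both Gini helpers slice off column 0 of the count matrix (the totals column
of node_sums) and add-one smooth the remaining per-class counts; a NumPy
sketch with a made-up count matrix:

    import numpy as np

    # Rows per node: [total, count_class_0, count_class_1].
    class_counts = np.array([[10., 8., 2.],
                             [10., 5., 5.]])
    smoothed = 1.0 + class_counts[:, 1:]      # drop totals, add-one smooth
    sums = smoothed.sum(axis=1)
    sum_squares = (smoothed ** 2).sum(axis=1)
    gini = 1.0 - sum_squares / (sums * sums)  # _gini -> [0.375, 0.5]
    weighted = sums - sum_squares / sums      # _weighted_gini == sums * gini
    # _variance below is the regression analog: E[x^2] - E[x]^2 per node.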
@@ -483,40 +524,58 @@ class RandomTreeGraphs(object):
Returns:
A 1-D tensor of the variances for each row in the input.
"""
- total_count = tf.slice(sums, [0, 0], [-1, 1])
+ total_count = array_ops.slice(sums, [0, 0], [-1, 1])
e_x = sums / total_count
e_x2 = squares / total_count
- return tf.reduce_sum(e_x2 - tf.square(e_x), 1)
+ return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1)
- def training_graph(self, input_data, input_labels, random_seed):
+
+ def training_graph(self, input_data, input_labels, random_seed,
+ data_spec, epoch=None):
"""Constructs a TF graph for training a random tree.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
input_labels: A tensor or placeholder for labels associated with
input_data.
random_seed: The random number generator seed to use for this tree. 0
means use the current time as the seed.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
+ epoch: A tensor or placeholder for the epoch the training data comes from.
Returns:
The last op in the random tree training graph.
"""
+ epoch = [0] if epoch is None else epoch
+
+ sparse_indices = []
+ sparse_values = []
+ sparse_shape = []
+ if isinstance(input_data, ops.SparseTensor):
+ sparse_indices = input_data.indices
+ sparse_values = input_data.values
+ sparse_shape = input_data.shape
+ input_data = []
+
# Count extremely random stats.
(node_sums, node_squares, splits_indices, splits_sums,
splits_squares, totals_indices, totals_sums,
totals_squares, input_leaves) = (
self.training_ops.count_extremely_random_stats(
- input_data, input_labels, self.variables.tree,
+ input_data, sparse_indices, sparse_values, sparse_shape,
+ data_spec, input_labels, self.variables.tree,
self.variables.tree_thresholds,
self.variables.node_to_accumulator_map,
self.variables.candidate_split_features,
self.variables.candidate_split_thresholds,
+ self.variables.start_epoch, epoch,
num_classes=self.params.num_output_columns,
regression=self.params.regression))
node_update_ops = []
node_update_ops.append(
- tf.assign_add(self.variables.node_sums, node_sums))
+ state_ops.assign_add(self.variables.node_sums, node_sums))
splits_update_ops = []
splits_update_ops.append(self.training_ops.scatter_add_ndim(
@@ -527,8 +586,8 @@ class RandomTreeGraphs(object):
totals_sums))
if self.params.regression:
- node_update_ops.append(tf.assign_add(self.variables.node_squares,
- node_squares))
+ node_update_ops.append(state_ops.assign_add(self.variables.node_squares,
+ node_squares))
splits_update_ops.append(self.training_ops.scatter_add_ndim(
self.variables.candidate_split_squares,
splits_indices, splits_squares))
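
Since training_graph now unpacks a SparseTensor into indices/values/shape
before handing them to the C++ ops, sparse data can be fed directly; a sketch
reusing graphs and y from the first sketch above (feature positions and
values are illustrative; shape= is the SparseTensor argument name matching
the input_data.shape read above):

    # Three rows over 1000 mostly-empty feature columns.
    sparse_input = tf.SparseTensor(indices=[[0, 3], [1, 17], [2, 400]],
                                   values=[1.0, 2.5, 0.5],
                                   shape=[3, 1000])
    train_op = graphs.training_graph(sparse_input, y)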
@@ -539,63 +598,56 @@ class RandomTreeGraphs(object):
# Sample inputs.
update_indices, feature_updates, threshold_updates = (
self.training_ops.sample_inputs(
- input_data, self.variables.node_to_accumulator_map,
+ input_data, sparse_indices, sparse_values, sparse_shape,
+ self.variables.node_to_accumulator_map,
input_leaves, self.variables.candidate_split_features,
self.variables.candidate_split_thresholds,
split_initializations_per_input=(
self.params.split_initializations_per_input),
split_sampling_random_seed=random_seed))
- update_features_op = tf.scatter_update(
+ update_features_op = state_ops.scatter_update(
self.variables.candidate_split_features, update_indices,
feature_updates)
- update_thresholds_op = tf.scatter_update(
+ update_thresholds_op = state_ops.scatter_update(
self.variables.candidate_split_thresholds, update_indices,
threshold_updates)
# Calculate finished nodes.
- with tf.control_dependencies(splits_update_ops):
- children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
- squeeze_dims=[1])
- is_leaf = tf.equal(LEAF_NODE, children)
- leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
- finished = self.training_ops.finished_nodes(
+ with ops.control_dependencies(splits_update_ops):
+ children = array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+ is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+ leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+ squeeze_dims=[1]))
+ finished, stale = self.training_ops.finished_nodes(
leaves, self.variables.node_to_accumulator_map,
+ self.variables.candidate_split_sums,
+ self.variables.candidate_split_squares,
self.variables.accumulator_sums,
- num_split_after_samples=self.params.split_after_samples)
+ self.variables.accumulator_squares,
+ self.variables.start_epoch, epoch,
+ num_split_after_samples=self.params.split_after_samples,
+ min_split_samples=self.params.min_split_samples)
# Update leaf scores.
- # TODO(gilberth): Optimize this. It currently calculates counts for
- # every non-fertile leaf.
- with tf.control_dependencies(node_update_ops):
- def dont_update_leaf_scores():
- return self.variables.non_fertile_leaf_scores
-
- def update_leaf_scores_regression():
- sums = tf.gather(self.variables.node_sums,
- self.variables.non_fertile_leaves)
- squares = tf.gather(self.variables.node_squares,
- self.variables.non_fertile_leaves)
- new_scores = self._variance(sums, squares)
- return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
- def update_leaf_scores_classification():
- counts = tf.gather(self.variables.node_sums,
- self.variables.non_fertile_leaves)
- new_scores = self._weighted_gini(counts)
- return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
- # Because we can't have tf.self.variables of size 0, we have to put in a
- # garbage value of -1 in there. Here we check for that so we don't
- # try to index into node_per_class_weights in a tf.gather with a negative
- # number.
- update_nonfertile_leaves_scores_op = tf.cond(
- tf.less(self.variables.non_fertile_leaves[0], 0),
- dont_update_leaf_scores,
- update_leaf_scores_regression if self.params.regression else
- update_leaf_scores_classification)
+ non_fertile_leaves = array_ops.boolean_mask(
+ leaves, math_ops.less(array_ops.gather(
+ self.variables.node_to_accumulator_map, leaves), 0))
+
+ # TODO(gilberth): It should be possible to limit the number of non
+ # fertile leaves we calculate scores for, especially since we can only take
+ # at most array_ops.shape(finished)[0] of them.
+ with ops.control_dependencies(node_update_ops):
+ sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
+ if self.params.regression:
+ squares = array_ops.gather(self.variables.node_squares,
+ non_fertile_leaves)
+ non_fertile_leaf_scores = self._variance(sums, squares)
+ else:
+ non_fertile_leaf_scores = self._weighted_gini(sums)
# Calculate best splits.
- with tf.control_dependencies(splits_update_ops):
+ with ops.control_dependencies(splits_update_ops):
split_indices = self.training_ops.best_splits(
finished, self.variables.node_to_accumulator_map,
self.variables.candidate_split_sums,
@@ -605,7 +657,7 @@ class RandomTreeGraphs(object):
regression=self.params.regression)
# Grow tree.
- with tf.control_dependencies([update_features_op, update_thresholds_op]):
+ with ops.control_dependencies([update_features_op, update_thresholds_op]):
(tree_update_indices, tree_children_updates,
tree_threshold_updates, tree_depth_updates, new_eot) = (
self.training_ops.grow_tree(
@@ -613,110 +665,138 @@ class RandomTreeGraphs(object):
self.variables.node_to_accumulator_map, finished, split_indices,
self.variables.candidate_split_features,
self.variables.candidate_split_thresholds))
- tree_update_op = tf.scatter_update(
+ tree_update_op = state_ops.scatter_update(
self.variables.tree, tree_update_indices, tree_children_updates)
- threhsolds_update_op = tf.scatter_update(
+ thresholds_update_op = state_ops.scatter_update(
self.variables.tree_thresholds, tree_update_indices,
tree_threshold_updates)
- depth_update_op = tf.scatter_update(
+ depth_update_op = state_ops.scatter_update(
self.variables.tree_depths, tree_update_indices, tree_depth_updates)
+ # TODO(thomaswc): Only update the epoch on the new leaves.
+ new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates)
+ epoch_update_op = state_ops.scatter_update(
+ self.variables.start_epoch, tree_update_indices,
+ new_epoch_updates)
# Update fertile slots.
- with tf.control_dependencies([update_nonfertile_leaves_scores_op,
- depth_update_op]):
- (node_map_updates, accumulators_cleared, accumulators_allocated,
- new_nonfertile_leaves, new_nonfertile_leaves_scores) = (
- self.training_ops.update_fertile_slots(
- finished, self.variables.non_fertile_leaves,
- self.variables.non_fertile_leaf_scores,
- self.variables.end_of_tree, self.variables.tree_depths,
- self.variables.accumulator_sums,
- self.variables.node_to_accumulator_map,
- max_depth=self.params.max_depth,
- regression=self.params.regression))
+ with ops.control_dependencies([depth_update_op]):
+ (node_map_updates, accumulators_cleared, accumulators_allocated) = (
+ self.training_ops.update_fertile_slots(
+ finished, non_fertile_leaves,
+ non_fertile_leaf_scores,
+ self.variables.end_of_tree, self.variables.tree_depths,
+ self.variables.accumulator_sums,
+ self.variables.node_to_accumulator_map,
+ stale,
+ max_depth=self.params.max_depth,
+ regression=self.params.regression))
# Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
# used it to calculate new leaves.
- gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
- eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+ gated_new_eot, = control_flow_ops.tuple([new_eot],
+ control_inputs=[node_map_updates])
+ eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)
updates = []
updates.append(eot_update_op)
updates.append(tree_update_op)
- updates.append(threhsolds_update_op)
- updates.append(tf.assign(
- self.variables.non_fertile_leaves, new_nonfertile_leaves,
- validate_shape=False))
- updates.append(tf.assign(
- self.variables.non_fertile_leaf_scores,
- new_nonfertile_leaves_scores, validate_shape=False))
-
- updates.append(tf.scatter_update(
+ updates.append(thresholds_update_op)
+ updates.append(epoch_update_op)
+
+ updates.append(state_ops.scatter_update(
self.variables.node_to_accumulator_map,
- tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
- squeeze_dims=[0]),
- tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
- squeeze_dims=[0])))
+ array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]),
+ squeeze_dims=[0]),
+ array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]),
+ squeeze_dims=[0])))
- cleared_and_allocated_accumulators = tf.concat(
+ cleared_and_allocated_accumulators = array_ops.concat(
0, [accumulators_cleared, accumulators_allocated])
# Calculate values to put into scatter update for candidate counts.
# Candidate split counts are always reset back to 0 for both cleared
# and allocated accumulators. This means some accumulators might be doubly
# reset to 0 if they were released and not allocated, then later allocated.
- split_values = tf.tile(
- tf.expand_dims(tf.expand_dims(
- tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
- 1), 2),
+ split_values = array_ops.tile(
+ array_ops.expand_dims(array_ops.expand_dims(
+ array_ops.zeros_like(cleared_and_allocated_accumulators,
+ dtype=dtypes.float32), 1), 2),
[1, self.params.num_splits_to_consider, self.params.num_output_columns])
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_sums,
cleared_and_allocated_accumulators, split_values))
if self.params.regression:
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_squares,
cleared_and_allocated_accumulators, split_values))
# Calculate values to put into scatter update for total counts.
- total_cleared = tf.tile(
- tf.expand_dims(
- tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1),
+ total_cleared = array_ops.tile(
+ array_ops.expand_dims(
+ math_ops.neg(array_ops.ones_like(accumulators_cleared,
+ dtype=dtypes.float32)), 1),
[1, self.params.num_output_columns])
- total_reset = tf.tile(
- tf.expand_dims(
- tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1),
+ total_reset = array_ops.tile(
+ array_ops.expand_dims(
+ array_ops.zeros_like(accumulators_allocated,
+ dtype=dtypes.float32), 1),
[1, self.params.num_output_columns])
- accumulator_updates = tf.concat(0, [total_cleared, total_reset])
- updates.append(tf.scatter_update(
+ accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
+ updates.append(state_ops.scatter_update(
self.variables.accumulator_sums,
cleared_and_allocated_accumulators, accumulator_updates))
if self.params.regression:
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.accumulator_squares,
cleared_and_allocated_accumulators, accumulator_updates))
# Calculate values to put into scatter update for candidate splits.
- split_features_updates = tf.tile(
- tf.expand_dims(
- tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1),
+ split_features_updates = array_ops.tile(
+ array_ops.expand_dims(
+ math_ops.neg(array_ops.ones_like(
+ cleared_and_allocated_accumulators)), 1),
[1, self.params.num_splits_to_consider])
- updates.append(tf.scatter_update(
+ updates.append(state_ops.scatter_update(
self.variables.candidate_split_features,
cleared_and_allocated_accumulators, split_features_updates))
- return tf.group(*updates)
+ updates += self.finish_iteration()
+
+ return control_flow_ops.group(*updates)
+
+ def finish_iteration(self):
+ """Perform any operations that should be done at the end of an iteration.
+
+ This is mostly useful for subclasses that need to reset variables after
+ an iteration, such as ones that are used to finish nodes.
+
+ Returns:
+ A list of operations.
+ """
+ return []
- def inference_graph(self, input_data):
+ def inference_graph(self, input_data, data_spec):
"""Constructs a TF graph for evaluating a random tree.
Args:
- input_data: A tensor or placeholder for input data.
+ input_data: A tensor or SparseTensor or placeholder for input data.
+ data_spec: A list of tf.dtype values specifying the original types of
+ each column.
Returns:
The last op in the random tree inference graph.
"""
+ sparse_indices = []
+ sparse_values = []
+ sparse_shape = []
+ if isinstance(input_data, ops.SparseTensor):
+ sparse_indices = input_data.indices
+ sparse_values = input_data.values
+ sparse_shape = input_data.shape
+ input_data = []
return self.inference_ops.tree_predictions(
- input_data, self.variables.tree, self.variables.tree_thresholds,
+ input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
+ self.variables.tree,
+ self.variables.tree_thresholds,
self.variables.node_sums,
valid_leaf_threshold=self.params.valid_leaf_threshold)
@@ -729,13 +809,22 @@ class RandomTreeGraphs(object):
Returns:
The last op in the graph.
"""
- children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
- squeeze_dims=[1])
- is_leaf = tf.equal(LEAF_NODE, children)
- leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
- counts = tf.gather(self.variables.node_sums, leaves)
- impurity = self._weighted_gini(counts)
- return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
+ children = array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+ is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+ leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+ squeeze_dims=[1]))
+ counts = array_ops.gather(self.variables.node_sums, leaves)
+ gini = self._weighted_gini(counts)
+ # Guard against step 1, when there often are no leaves yet.
+ def impurity():
+ return gini
+ # Since average impurity can be used for loss, when there's no data just
+ # return a big number so that loss always decreases.
+ def big():
+ return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
+ return control_flow_ops.cond(math_ops.greater(
+ array_ops.shape(leaves)[0], 0), impurity, big)
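
The cond above keeps average_impurity usable as a loss from the very first
step, when no leaves exist yet; the guard pattern, distilled into a
self-contained sketch:

    import tensorflow as tf

    def guarded_mean(values):
      # Mean of `values`, or a huge constant when `values` is empty, so a
      # monitor treating this as a loss still sees it decrease.
      return tf.cond(tf.greater(tf.shape(values)[0], 0),
                     lambda: tf.reduce_mean(values),
                     lambda: tf.constant(1e7))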
def size(self):
"""Constructs a TF graph for evaluating the current number of nodes.
@@ -747,7 +836,8 @@ class RandomTreeGraphs(object):
def get_stats(self, session):
num_nodes = self.variables.end_of_tree.eval(session=session) - 1
- num_leaves = tf.where(
- tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
- LEAF_NODE)).eval(session=session).shape[0]
+ num_leaves = array_ops.where(
+ math_ops.equal(array_ops.squeeze(array_ops.slice(
+ self.variables.tree, [0, 0], [-1, 1])), constants.LEAF_NODE)
+ ).eval(session=session).shape[0]
return TreeStats(num_nodes, num_leaves)