about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <nobody@tensorflow.org> 2016-04-12 11:25:18 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-04-12 12:31:42 -0700
commit: cc3267fefaef23683fa85057fc27c21f2b31c8bd (patch)
tree: 46998a7afeca29bde1b07004d0dda2705b60f2d5
parent: cbef061ec6393e7018b03b38b05d186875bda512 (diff)
More tests and minor fixes for tf/contrib/tensor_forest.
Change: 119666432
-rw-r--r-- tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc | 6
-rw-r--r-- tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc | 6
-rw-r--r-- tensorflow/contrib/tensor_forest/python/ops/inference_ops.py | 9
-rw-r--r-- tensorflow/contrib/tensor_forest/python/ops/training_ops.py | 9
-rw-r--r-- tensorflow/contrib/tensor_forest/python/tensor_forest.py | 142
-rw-r--r-- tensorflow/contrib/tensor_forest/python/tensor_forest_test.py | 39
6 files changed, 168 insertions, 43 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
index ab5ac9c899..bd2cd59eea 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -79,9 +79,9 @@ REGISTER_OP("CountExtremelyRandomStats")
gives the j-th feature of the i-th input.
input_labels: The training batch's labels; `input_labels[i]` is the class
of the i-th input.
- tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
- of the i-th node, `tree[0][i] + 1` gives the index of the right child of
- the i-th node, and `tree[1][i]` gives the index of the feature used to
+ tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+ of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+ the i-th node, and `tree[i][1]` gives the index of the feature used to
split the i-th node.
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
node.
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
index 3e84534795..37640c31b6 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
@@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions")
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
gives the j-th feature of the i-th input.
- tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
- of the i-th node, `tree[0][i] + 1` gives the index of the right child of
- the i-th node, and `tree[1][i]` gives the index of the feature used to
+ tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+ of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+ the i-th node, and `tree[i][1]` gives the index of the feature used to
split the i-th node.
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
node.
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
index 7cad6a8d38..bcf6ca6b6e 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -25,6 +25,12 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('inference_library_base_dir', '',
+ 'Directory to look for inference library file.')
+
INFERENCE_OPS_FILE = '_inference_ops.so'
_inference_ops = None
@@ -54,7 +60,8 @@ def Load():
with _ops_lock:
global _inference_ops
if not _inference_ops:
- data_files_path = tf.resource_loader.get_data_files_path()
+ data_files_path = os.path.join(FLAGS.inference_library_base_dir,
+ tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_inference_ops = tf.load_op_library(os.path.join(
data_files_path, INFERENCE_OPS_FILE))
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 8ca2491d60..5cf5e4af90 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -25,6 +25,12 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('training_library_base_dir', '',
+ 'Directory to look for training library file.')
+
TRAINING_OPS_FILE = '_training_ops.so'
_training_ops = None
@@ -101,7 +107,8 @@ def Load():
with _ops_lock:
global _training_ops
if not _training_ops:
- data_files_path = tf.resource_loader.get_data_files_path()
+ data_files_path = os.path.join(FLAGS.training_library_base_dir,
+ tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_training_ops = tf.load_op_library(os.path.join(
data_files_path, TRAINING_OPS_FILE))
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 6257d6481d..45e8cab485 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -37,6 +37,13 @@ flags.DEFINE_float(
'samples_to_decide', 25.0,
'Only decide on a split, or only fully use a leaf, after this many '
'training samples have been seen.')
+flags.DEFINE_float('bagging_fraction', 1.0,
+ 'Use this fraction of the input, randomly chosen, to train '
+ 'each tree in the forest.')
+flags.DEFINE_integer(
+ 'num_splits_to_consider', 0,
+ 'If non-zero, consider this many candidates for a splitting '
+ 'rule at a fertile node.')
# If tree[i][0] equals this value, then i is a leaf node.
LEAF_NODE = -1
@@ -69,6 +76,9 @@ class ForestHParams(object):
# Fail fast if num_classes isn't set.
_ = getattr(self, 'num_classes')
+ self.bagging_fraction = getattr(self, 'bagging_fraction',
+ FLAGS.bagging_fraction)
+
self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
@@ -79,7 +89,9 @@ class ForestHParams(object):
# The Random Forest literature recommends sqrt(# features) for
# classification problems, and p/3 for regression problems.
# TODO(thomaswc): Consider capping this for large number of features.
- if not getattr(self, 'num_splits_to_consider', None):
+ self.num_splits_to_consider = getattr(self, 'num_splits_to_consider',
+ FLAGS.num_splits_to_consider)
+ if not self.num_splits_to_consider:
self.num_splits_to_consider = max(10, int(
math.ceil(math.sqrt(self.num_features))))
@@ -94,8 +106,8 @@ class ForestHParams(object):
self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
# But it also never needs to be larger than the number of leaves,
# which is max_nodes / 2.
- self.max_fertile_nodes = min(self.max_nodes,
- int(math.ceil(self.max_fertile_nodes / 2.0)))
+ self.max_fertile_nodes = min(self.max_fertile_nodes,
+ int(math.ceil(self.max_nodes / 2.0)))
# split_after_samples and valid_leaf_threshold should be about the same.
# Therefore, if either is set, use it to set the other. Otherwise, fall
@@ -184,23 +196,6 @@ class TreeStats(object):
self.num_leaves = num_leaves
-def get_tree_stats(variables, unused_params, session):
- num_nodes = variables.end_of_tree.eval(session=session) - 1
- num_leaves = tf.where(
- tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])),
- LEAF_NODE)).eval(session=session).shape[0]
- return TreeStats(num_nodes, num_leaves)
-
-
-def get_forest_stats(variables, params, session):
-
- tree_stats = []
- for i in range(params.num_trees):
- tree_stats.append(get_tree_stats(variables[i], params, session))
-
- return ForestStats(tree_stats, params)
-
-
class ForestTrainingVariables(object):
"""A container for a forest's training data, consisting of multiple trees.
@@ -212,9 +207,11 @@ class ForestTrainingVariables(object):
... forest_variables.tree ...
"""
- def __init__(self, params):
- self.variables = [TreeTrainingVariables(params)
- for _ in range(params.num_trees)]
+ def __init__(self, params, device_assigner):
+ self.variables = []
+ for i in range(params.num_trees):
+ with tf.device(device_assigner.get_device(i)):
+ self.variables.append(TreeTrainingVariables(params))
def __setitem__(self, t, val):
self.variables[t] = val
@@ -223,12 +220,35 @@ class ForestTrainingVariables(object):
return self.variables[t]
+class RandomForestDeviceAssigner(object):
+ """A device assigner that uses the default device.
+
+ Write subclasses that implement get_device for control over how trees
+ get assigned to devices. This assumes that whole trees are assigned
+ to a device.
+ """
+
+ def __init__(self):
+ self.cached = None
+
+ def get_device(self, unused_tree_num):
+ if not self.cached:
+ dummy = tf.constant(0)
+ self.cached = dummy.device
+
+ return self.cached
+
+
class RandomForestGraphs(object):
"""Builds TF graphs for random forest training and inference."""
- def __init__(self, params):
+ def __init__(self, params, device_assigner=None, variables=None):
self.params = params
- self.variables = ForestTrainingVariables(self.params)
+ self.device_assigner = device_assigner or RandomForestDeviceAssigner()
+ tf.logging.info('Constructing forest with params = ')
+ tf.logging.info(self.params.__dict__)
+ self.variables = variables or ForestTrainingVariables(
+ self.params, device_assigner=self.device_assigner)
self.trees = [RandomTreeGraphs(self.variables[i], self.params,
training_ops.Load(), inference_ops.Load())
for i in range(self.params.num_trees)]
@@ -246,12 +266,26 @@ class RandomForestGraphs(object):
"""
tree_graphs = []
for i in range(self.params.num_trees):
- tf.logging.info('Constructing tree %d', i)
- seed = self.params.base_random_seed
- if seed != 0:
- seed += i
- tree_graphs.append(self.trees[i].training_graph(
- input_data, input_labels, seed))
+ with tf.device(self.device_assigner.get_device(i)):
+ seed = self.params.base_random_seed
+ if seed != 0:
+ seed += i
+ # If using bagging, randomly select some of the input.
+ tree_data = input_data
+ tree_labels = input_labels
+ if self.params.bagging_fraction < 1.0:
+ # TODO(thomaswc): This does sampling without replacement. Consider
+ # also allowing sampling with replacement as an option.
+ batch_size = tf.slice(tf.shape(input_data), [0], [1])
+ r = tf.random_uniform(batch_size, seed=seed)
+ mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
+ gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
+ # TODO(thomaswc): Calculate out-of-bag data and labels, and store
+ # them for use in calculating statistics later.
+ tree_data = tf.gather(input_data, gather_indices)
+ tree_labels = tf.gather(input_labels, gather_indices)
+ tree_graphs.append(
+ self.trees[i].training_graph(tree_data, tree_labels, seed))
return tf.group(*tree_graphs)
def inference_graph(self, input_data):
@@ -265,9 +299,23 @@ class RandomForestGraphs(object):
"""
probabilities = []
for i in range(self.params.num_trees):
- probabilities.append(self.trees[i].inference_graph(input_data))
- all_predict = tf.pack(probabilities)
- return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+ with tf.device(self.device_assigner.get_device(i)):
+ probabilities.append(self.trees[i].inference_graph(input_data))
+ with tf.device(self.device_assigner.get_device(0)):
+ all_predict = tf.pack(probabilities)
+ return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+
+ def average_size(self):
+ """Constructs a TF graph for evaluating the average size of a forest.
+
+ Returns:
+ The average number of nodes over the trees.
+ """
+ sizes = []
+ for i in range(self.params.num_trees):
+ with tf.device(self.device_assigner.get_device(i)):
+ sizes.append(self.trees[i].size())
+ return tf.reduce_mean(tf.pack(sizes))
def average_impurity(self):
"""Constructs a TF graph for evaluating the leaf impurity of a forest.
@@ -277,9 +325,17 @@ class RandomForestGraphs(object):
"""
impurities = []
for i in range(self.params.num_trees):
- impurities.append(self.trees[i].average_impurity(self.variables[i]))
+ with tf.device(self.device_assigner.get_device(i)):
+ impurities.append(self.trees[i].average_impurity())
return tf.reduce_mean(tf.pack(impurities))
+ def get_stats(self, session):
+ tree_stats = []
+ for i in range(self.params.num_trees):
+ with tf.device(self.device_assigner.get_device(i)):
+ tree_stats.append(self.trees[i].get_stats(session))
+ return ForestStats(tree_stats, self.params)
+
class RandomTreeGraphs(object):
"""Builds TF graphs for random tree training and inference."""
@@ -394,6 +450,7 @@ class RandomTreeGraphs(object):
with tf.control_dependencies([node_update_op]):
def f1():
return self.variables.non_fertile_leaf_scores
+
def f2():
counts = tf.gather(self.variables.node_per_class_weights,
self.variables.non_fertile_leaves)
@@ -535,3 +592,18 @@ class RandomTreeGraphs(object):
counts = tf.gather(self.variables.node_per_class_weights, leaves)
impurity = self._weighted_gini(counts)
return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
+
+ def size(self):
+ """Constructs a TF graph for evaluating the current number of nodes.
+
+ Returns:
+ The current number of nodes in the tree.
+ """
+ return self.variables.end_of_tree - 1
+
+ def get_stats(self, session):
+ num_nodes = self.variables.end_of_tree.eval(session=session) - 1
+ num_leaves = tf.where(
+ tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
+ LEAF_NODE)).eval(session=session).shape[0]
+ return TreeStats(num_nodes, num_leaves)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index e4846cb047..a2cf187bdc 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -27,6 +27,37 @@ from tensorflow.python.platform import googletest
class TensorForestTest(test_util.TensorFlowTestCase):
+ def testForestHParams(self):
+ hparams = tensor_forest.ForestHParams(
+ num_classes=2, num_trees=100, max_nodes=1000,
+ num_features=60).fill()
+ self.assertEquals(2, hparams.num_classes)
+ # 2 * ceil(log_2(1000)) = 20
+ self.assertEquals(20, hparams.max_depth)
+ # sqrt(num_features) < 10, so num_splits_to_consider should be 10.
+ self.assertEquals(10, hparams.num_splits_to_consider)
+ # Don't have more fertile nodes than max # leaves, which is 500.
+ self.assertEquals(500, hparams.max_fertile_nodes)
+ # We didn't set either of these, so they should be equal
+ self.assertEquals(hparams.split_after_samples,
+ hparams.valid_leaf_threshold)
+ # split_after_samples is larger than 10
+ self.assertEquals(1, hparams.split_initializations_per_input)
+ self.assertEquals(0, hparams.base_random_seed)
+
+ def testForestHParamsBigTree(self):
+ hparams = tensor_forest.ForestHParams(
+ num_classes=2, num_trees=100, max_nodes=1000000,
+ split_after_samples=25,
+ num_features=1000).fill()
+ self.assertEquals(40, hparams.max_depth)
+ # sqrt(1000) = 31.63...
+ self.assertEquals(32, hparams.num_splits_to_consider)
+ # 1000000 / 32 = 31250
+ self.assertEquals(31250, hparams.max_fertile_nodes)
+ # floor(31.63 / 25) = 1
+ self.assertEquals(1, hparams.split_initializations_per_input)
+
def testTrainingConstruction(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
@@ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase):
graph = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(graph, tf.Tensor))
+ def testImpurityConstruction(self):
+ params = tensor_forest.ForestHParams(
+ num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
+
+ graph_builder = tensor_forest.RandomForestGraphs(params)
+ graph = graph_builder.average_impurity()
+ self.assertTrue(isinstance(graph, tf.Tensor))
+
if __name__ == '__main__':
googletest.main()