From cc3267fefaef23683fa85057fc27c21f2b31c8bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Apr 2016 11:25:18 -0800 Subject: More tests and minor fixes for tf/contrib/tensor_forest. Change: 119666432 --- .../core/ops/count_extremely_random_stats_op.cc | 6 +- .../tensor_forest/core/ops/tree_predictions_op.cc | 6 +- .../tensor_forest/python/ops/inference_ops.py | 9 +- .../tensor_forest/python/ops/training_ops.py | 9 +- .../contrib/tensor_forest/python/tensor_forest.py | 142 ++++++++++++++++----- .../tensor_forest/python/tensor_forest_test.py | 39 ++++++ 6 files changed, 168 insertions(+), 43 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc index ab5ac9c899..bd2cd59eea 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc @@ -79,9 +79,9 @@ REGISTER_OP("CountExtremelyRandomStats") gives the j-th feature of the i-th input. input_labels: The training batch's labels; `input_labels[i]` is the class of the i-th input. - tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child - of the i-th node, `tree[0][i] + 1` gives the index of the right child of - the i-th node, and `tree[1][i]` gives the index of the feature used to + tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child + of the i-th node, `tree[i][0] + 1` gives the index of the right child of + the i-th node, and `tree[i][1]` gives the index of the feature used to split the i-th node. tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th node. diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc index 3e84534795..37640c31b6 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc @@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions") input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` gives the j-th feature of the i-th input. - tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child - of the i-th node, `tree[0][i] + 1` gives the index of the right child of - the i-th node, and `tree[1][i]` gives the index of the feature used to + tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child + of the i-th node, `tree[i][0] + 1` gives the index of the right child of + the i-th node, and `tree[i][1]` gives the index of the feature used to split the i-th node. tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th node. diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py index 7cad6a8d38..bcf6ca6b6e 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py @@ -25,6 +25,12 @@ import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('inference_library_base_dir', '', + 'Directory to look for inference library file.') + INFERENCE_OPS_FILE = '_inference_ops.so' _inference_ops = None @@ -54,7 +60,8 @@ def Load(): with _ops_lock: global _inference_ops if not _inference_ops: - data_files_path = tf.resource_loader.get_data_files_path() + data_files_path = os.path.join(FLAGS.inference_library_base_dir, + tf.resource_loader.get_data_files_path()) tf.logging.info('data path: %s', data_files_path) _inference_ops = tf.load_op_library(os.path.join( data_files_path, INFERENCE_OPS_FILE)) diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py index 8ca2491d60..5cf5e4af90 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py @@ -25,6 +25,12 @@ import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('training_library_base_dir', '', + 'Directory to look for inference library file.') + TRAINING_OPS_FILE = '_training_ops.so' _training_ops = None @@ -101,7 +107,8 @@ def Load(): with _ops_lock: global _training_ops if not _training_ops: - data_files_path = tf.resource_loader.get_data_files_path() + data_files_path = os.path.join(FLAGS.training_library_base_dir, + tf.resource_loader.get_data_files_path()) tf.logging.info('data path: %s', data_files_path) _training_ops = tf.load_op_library(os.path.join( data_files_path, TRAINING_OPS_FILE)) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 6257d6481d..45e8cab485 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -37,6 +37,13 @@ flags.DEFINE_float( 'samples_to_decide', 25.0, 'Only decide on a split, or only fully use a leaf, after this many ' 'training samples have been seen.') +flags.DEFINE_float('bagging_fraction', 1.0, + 'Use this fraction of the input, randomly chosen, to train ' + 'each tree in the forest.') +flags.DEFINE_integer( + 'num_splits_to_consider', 0, + 'If non-zero, consider this many candidates for a splitting ' + 'rule at a fertile node.') # If tree[i][0] equals this value, then i is a leaf node. LEAF_NODE = -1 @@ -69,6 +76,9 @@ class ForestHParams(object): # Fail fast if num_classes isn't set. _ = getattr(self, 'num_classes') + self.bagging_fraction = getattr(self, 'bagging_fraction', + FLAGS.bagging_fraction) + self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees) self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes) @@ -79,7 +89,9 @@ class ForestHParams(object): # The Random Forest literature recommends sqrt(# features) for # classification problems, and p/3 for regression problems. # TODO(thomaswc): Consider capping this for large number of features. - if not getattr(self, 'num_splits_to_consider', None): + self.num_splits_to_consider = getattr(self, 'num_splits_to_consider', + FLAGS.num_splits_to_consider) + if not self.num_splits_to_consider: self.num_splits_to_consider = max(10, int( math.ceil(math.sqrt(self.num_features)))) @@ -94,8 +106,8 @@ class ForestHParams(object): self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile) # But it also never needs to be larger than the number of leaves, # which is max_nodes / 2. - self.max_fertile_nodes = min(self.max_nodes, - int(math.ceil(self.max_fertile_nodes / 2.0))) + self.max_fertile_nodes = min(self.max_fertile_nodes, + int(math.ceil(self.max_nodes / 2.0))) # split_after_samples and valid_leaf_threshold should be about the same. # Therefore, if either is set, use it to set the other. Otherwise, fall @@ -184,23 +196,6 @@ class TreeStats(object): self.num_leaves = num_leaves -def get_tree_stats(variables, unused_params, session): - num_nodes = variables.end_of_tree.eval(session=session) - 1 - num_leaves = tf.where( - tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])), - LEAF_NODE)).eval(session=session).shape[0] - return TreeStats(num_nodes, num_leaves) - - -def get_forest_stats(variables, params, session): - - tree_stats = [] - for i in range(params.num_trees): - tree_stats.append(get_tree_stats(variables[i], params, session)) - - return ForestStats(tree_stats, params) - - class ForestTrainingVariables(object): """A container for a forests training data, consisting of multiple trees. @@ -212,9 +207,11 @@ class ForestTrainingVariables(object): ... forest_variables.tree ... """ - def __init__(self, params): - self.variables = [TreeTrainingVariables(params) - for _ in range(params.num_trees)] + def __init__(self, params, device_assigner): + self.variables = [] + for i in range(params.num_trees): + with tf.device(device_assigner.get_device(i)): + self.variables.append(TreeTrainingVariables(params)) def __setitem__(self, t, val): self.variables[t] = val @@ -223,12 +220,35 @@ class ForestTrainingVariables(object): return self.variables[t] +class RandomForestDeviceAssigner(object): + """A device assigner that uses the default device. + + Write subclasses that implement get_device for control over how trees + get assigned to devices. This assumes that whole trees are assigned + to a device. + """ + + def __init__(self): + self.cached = None + + def get_device(self, unused_tree_num): + if not self.cached: + dummy = tf.constant(0) + self.cached = dummy.device + + return self.cached + + class RandomForestGraphs(object): """Builds TF graphs for random forest training and inference.""" - def __init__(self, params): + def __init__(self, params, device_assigner=None, variables=None): self.params = params - self.variables = ForestTrainingVariables(self.params) + self.device_assigner = device_assigner or RandomForestDeviceAssigner() + tf.logging.info('Constructing forest with params = ') + tf.logging.info(self.params.__dict__) + self.variables = variables or ForestTrainingVariables( + self.params, device_assigner=self.device_assigner) self.trees = [RandomTreeGraphs(self.variables[i], self.params, training_ops.Load(), inference_ops.Load()) for i in range(self.params.num_trees)] @@ -246,12 +266,26 @@ class RandomForestGraphs(object): """ tree_graphs = [] for i in range(self.params.num_trees): - tf.logging.info('Constructing tree %d', i) - seed = self.params.base_random_seed - if seed != 0: - seed += i - tree_graphs.append(self.trees[i].training_graph( - input_data, input_labels, seed)) + with tf.device(self.device_assigner.get_device(i)): + seed = self.params.base_random_seed + if seed != 0: + seed += i + # If using bagging, randomly select some of the input. + tree_data = input_data + tree_labels = input_labels + if self.params.bagging_fraction < 1.0: + # TODO(thomaswc): This does sampling without replacment. Consider + # also allowing sampling with replacement as an option. + batch_size = tf.slice(tf.shape(input_data), [0], [1]) + r = tf.random_uniform(batch_size, seed=seed) + mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction) + gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1]) + # TODO(thomaswc): Calculate out-of-bag data and labels, and store + # them for use in calculating statistics later. + tree_data = tf.gather(input_data, gather_indices) + tree_labels = tf.gather(input_labels, gather_indices) + tree_graphs.append( + self.trees[i].training_graph(tree_data, tree_labels, seed)) return tf.group(*tree_graphs) def inference_graph(self, input_data): @@ -265,9 +299,23 @@ class RandomForestGraphs(object): """ probabilities = [] for i in range(self.params.num_trees): - probabilities.append(self.trees[i].inference_graph(input_data)) - all_predict = tf.pack(probabilities) - return tf.reduce_sum(all_predict, 0) / self.params.num_trees + with tf.device(self.device_assigner.get_device(i)): + probabilities.append(self.trees[i].inference_graph(input_data)) + with tf.device(self.device_assigner.get_device(0)): + all_predict = tf.pack(probabilities) + return tf.reduce_sum(all_predict, 0) / self.params.num_trees + + def average_size(self): + """Constructs a TF graph for evaluating the average size of a forest. + + Returns: + The average number of nodes over the trees. + """ + sizes = [] + for i in range(self.params.num_trees): + with tf.device(self.device_assigner.get_device(i)): + sizes.append(self.trees[i].size()) + return tf.reduce_mean(tf.pack(sizes)) def average_impurity(self): """Constructs a TF graph for evaluating the leaf impurity of a forest. @@ -277,9 +325,17 @@ class RandomForestGraphs(object): """ impurities = [] for i in range(self.params.num_trees): - impurities.append(self.trees[i].average_impurity(self.variables[i])) + with tf.device(self.device_assigner.get_device(i)): + impurities.append(self.trees[i].average_impurity()) return tf.reduce_mean(tf.pack(impurities)) + def get_stats(self, session): + tree_stats = [] + for i in range(self.params.num_trees): + with tf.device(self.device_assigner.get_device(i)): + tree_stats.append(self.trees[i].get_stats(session)) + return ForestStats(tree_stats, self.params) + class RandomTreeGraphs(object): """Builds TF graphs for random tree training and inference.""" @@ -394,6 +450,7 @@ class RandomTreeGraphs(object): with tf.control_dependencies([node_update_op]): def f1(): return self.variables.non_fertile_leaf_scores + def f2(): counts = tf.gather(self.variables.node_per_class_weights, self.variables.non_fertile_leaves) @@ -535,3 +592,18 @@ class RandomTreeGraphs(object): counts = tf.gather(self.variables.node_per_class_weights, leaves) impurity = self._weighted_gini(counts) return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) + + def size(self): + """Constructs a TF graph for evaluating the current number of nodes. + + Returns: + The current number of nodes in the tree. + """ + return self.variables.end_of_tree - 1 + + def get_stats(self, session): + num_nodes = self.variables.end_of_tree.eval(session=session) - 1 + num_leaves = tf.where( + tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])), + LEAF_NODE)).eval(session=session).shape[0] + return TreeStats(num_nodes, num_leaves) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index e4846cb047..a2cf187bdc 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -27,6 +27,37 @@ from tensorflow.python.platform import googletest class TensorForestTest(test_util.TensorFlowTestCase): + def testForestHParams(self): + hparams = tensor_forest.ForestHParams( + num_classes=2, num_trees=100, max_nodes=1000, + num_features=60).fill() + self.assertEquals(2, hparams.num_classes) + # 2 * ceil(log_2(1000)) = 20 + self.assertEquals(20, hparams.max_depth) + # sqrt(num_features) < 10, so num_splits_to_consider should be 10. + self.assertEquals(10, hparams.num_splits_to_consider) + # Don't have more fertile nodes than max # leaves, which is 500. + self.assertEquals(500, hparams.max_fertile_nodes) + # We didn't set either of these, so they should be equal + self.assertEquals(hparams.split_after_samples, + hparams.valid_leaf_threshold) + # split_after_samples is larger than 10 + self.assertEquals(1, hparams.split_initializations_per_input) + self.assertEquals(0, hparams.base_random_seed) + + def testForestHParamsBigTree(self): + hparams = tensor_forest.ForestHParams( + num_classes=2, num_trees=100, max_nodes=1000000, + split_after_samples=25, + num_features=1000).fill() + self.assertEquals(40, hparams.max_depth) + # sqrt(1000) = 31.63... + self.assertEquals(32, hparams.num_splits_to_consider) + # 1000000 / 32 = 31250 + self.assertEquals(31250, hparams.max_fertile_nodes) + # floor(31.63 / 25) = 1 + self.assertEquals(1, hparams.split_initializations_per_input) + def testTrainingConstruction(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 @@ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase): graph = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(graph, tf.Tensor)) + def testImpurityConstruction(self): + params = tensor_forest.ForestHParams( + num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill() + + graph_builder = tensor_forest.RandomForestGraphs(params) + graph = graph_builder.average_impurity() + self.assertTrue(isinstance(graph, tf.Tensor)) + if __name__ == '__main__': googletest.main() -- cgit v1.2.3