diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-28 07:21:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-28 07:25:44 -0700 |
commit | 94934d94407962f992dfbc22007bbaaadbaf63c2 (patch) | |
tree | d88ccd2f8685108cb21f922249d1f8025cb71b79 /tensorflow | |
parent | 42fb4382b9ef7b7b64e97ed85c51a1dfdb4071ec (diff) |
Optionally output a new TreePath proto during TensorForest inference for ultimate interpretability.
PiperOrigin-RevId: 163466324
Diffstat (limited to 'tensorflow')
11 files changed, 98 insertions, 46 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 21fdff23b0..f7f92a15a8 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -43,6 +43,7 @@ from tensorflow.python.training import session_run_hook KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' +TREE_PATHS_PREDICTION_KEY = 'tree_paths' EPSILON = 0.000001 @@ -194,7 +195,7 @@ def get_model_fn(params, graph_builder = graph_builder_class(params, device_assigner=dev_assn) - logits = graph_builder.inference_graph(features) + logits, tree_paths = graph_builder.inference_graph(features) summary.scalar('average_tree_size', graph_builder.average_size()) # For binary classification problems, convert probabilities to logits. @@ -261,6 +262,9 @@ def get_model_fn(params, if keys is not None: model_ops.predictions[keys_name] = keys + if params.inference_tree_paths: + model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + return model_ops return _model_fn diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py index 52e41a6fe8..ac42364d25 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py @@ -35,7 +35,8 @@ class TensorForestTrainerTests(test.TestCase): max_nodes=1000, num_classes=3, num_features=4, - split_after_samples=20) + split_after_samples=20, + inference_tree_paths=True) classifier = random_forest.TensorForestEstimator(hparams.fill()) iris = base.load_iris() diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 221f8d969b..3d9de006b4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -147,9 +147,12 @@ class TreeSizeOp : public OpKernel { void TraverseTree(const DecisionTreeResource* tree_resource, const std::unique_ptr<TensorDataSet>& data, int32 start, int32 end, - const std::function<void(int32, int32)>& set_leaf_id) { + const std::function<void(int32, int32)>& set_leaf_id, + std::vector<TreePath>* tree_paths) { for (int i = start; i < end; ++i) { - const int32 id = tree_resource->TraverseTree(data, i, nullptr); + const int32 id = tree_resource->TraverseTree( + data, i, nullptr, + (tree_paths == nullptr) ? nullptr : &(*tree_paths)[i]); set_leaf_id(i, id); } } @@ -199,21 +202,40 @@ class TreePredictionsV4Op : public OpKernel { &output_predictions)); TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>(); + std::vector<TreePath> tree_paths( + param_proto_.inference_tree_paths() ? num_data : 0); + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &out, decision_tree_resource, num_data](int64 start, - int64 end) { + auto traverse = [this, &out, decision_tree_resource, num_data, &tree_paths]( + int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start), static_cast<int32>(end), std::bind(&TreePredictionsV4Op::set_output_value, this, std::placeholders::_1, std::placeholders::_2, - decision_tree_resource, &out)); + decision_tree_resource, &out), + param_proto_.inference_tree_paths() ? &tree_paths : nullptr); }; Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, traverse); + + Tensor* output_tree_paths = nullptr; + TensorShape output_paths_shape; + output_paths_shape.AddDim(param_proto_.inference_tree_paths() ? num_data + : 0); + OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape, + &output_tree_paths)); + auto out_paths = output_tree_paths->unaligned_flat<string>(); + + // TODO(gilberth): If this slows down inference too much, consider having + // a filter that only serializes paths for the predicted label that we're + // interested in. + for (int i = 0; i < tree_paths.size(); ++i) { + out_paths(i) = tree_paths[i].SerializeAsString(); + } } void set_output_value(int32 i, int32 id, @@ -293,7 +315,7 @@ class TraverseTreeV4Op : public OpKernel { CHECK(start <= end); CHECK(end <= num_data); TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start), - static_cast<int32>(end), set_leaf_ids); + static_cast<int32>(end), set_leaf_ids, nullptr); }; Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, traverse); diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc index 0fdab8e6e0..1ac1d27761 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc @@ -55,13 +55,13 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) { .Finalize(&op.node_def)); // num_points = 2, sparse shape not known - INFER_OK(op, "?;[2,3];?;?;?", "[d1_0,?]"); + INFER_OK(op, "?;[2,3];?;?;?", "[d1_0,?];[?]"); // num_points = 2, sparse and dense shape rank known and > 1 - INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0,?]"); + INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0,?];[?]"); // num_points = 2, sparse shape rank known and > 1 - INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]"); + INFER_OK(op, "?;?;?;?;[10,11]", "[?,?];[?]"); } TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc index 881e4339a7..952b34b353 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc @@ -23,12 +23,15 @@ using decision_trees::TreeNode; int32 DecisionTreeResource::TraverseTree( const std::unique_ptr<TensorDataSet>& input_data, int example, - int32* leaf_depth) const { + int32* leaf_depth, TreePath* path) const { const DecisionTree& tree = decision_tree_->decision_tree(); int32 current_id = 0; int32 depth = 0; while (true) { const TreeNode& current = tree.nodes(current_id); + if (path != nullptr) { + *path->add_nodes_visited() = current; + } if (current.has_leaf()) { if (leaf_depth != nullptr) { *leaf_depth = depth; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index 438d3d817c..87c5290604 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -71,7 +71,7 @@ class DecisionTreeResource : public ResourceBase { // Return the TreeNode for the leaf that the example ends up at according // to decsion_tree_. Also fill in that leaf's depth if it isn't nullptr. int32 TraverseTree(const std::unique_ptr<TensorDataSet>& input_data, - int example, int32* depth) const; + int example, int32* depth, TreePath* path) const; // Split the given node_id, turning it from a Leaf to a BinaryNode and // setting it's split to the given best. Add new children ids to diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc index 3dca6913f6..3099cccdf8 100644 --- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -90,6 +90,7 @@ REGISTER_OP("TreePredictionsV4") .Input("sparse_input_values: float") .Input("sparse_input_shape: int64") .Output("predictions: float") + .Output("tree_paths: string") .SetShapeFn([](InferenceContext* c) { DimensionHandle num_points = c->UnknownDim(); @@ -99,6 +100,7 @@ REGISTER_OP("TreePredictionsV4") } c->set_output(0, c->Matrix(num_points, c->UnknownDim())); + c->set_output(1, c->Vector(c->UnknownDim())); return Status::OK(); }) .Doc(R"doc( @@ -112,6 +114,7 @@ sparse_input_indices: The indices tensor from the SparseTensor input. sparse_input_values: The values tensor from the SparseTensor input. sparse_input_shape: The shape tensor from the SparseTensor input. predictions: `predictions[i][j]` is the probability that input i is class j. +tree_paths: `tree_paths[i]` is a serialized TreePath proto for example i. )doc"); REGISTER_OP("TraverseTreeV4") diff --git a/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto index 0ded04ad75..d568fa3081 100644 --- a/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto +++ b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto @@ -90,3 +90,10 @@ message SplitCandidate { // Fields used when training with a graph runner. string unique_id = 6; } + +// Proto used for tracking tree paths during inference time. +message TreePath { + // Nodes are listed in order that they were traversed. i.e. nodes_visited[0] + // is the tree's root node. + repeated decision_trees.TreeNode nodes_visited = 1; +} diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto index 58c5b9bbe7..29d115ab69 100644 --- a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto +++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto @@ -131,6 +131,7 @@ message TensorForestParams { bool checkpoint_stats = 11; bool use_running_stats_method = 20; bool initialize_average_splits = 22; + bool inference_tree_paths = 23; // Number of classes (classification) or targets (regression) int32 num_outputs = 12; diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index b1a8357048..d606ea5770 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -105,6 +105,7 @@ def build_params_proto(params): proto.checkpoint_stats = params.checkpoint_stats proto.use_running_stats_method = params.use_running_stats_method proto.initialize_average_splits = params.initialize_average_splits + proto.inference_tree_paths = params.inference_tree_paths parse_number_or_string_to_proto(proto.pruning_type.prune_every_samples, params.prune_every_samples) @@ -139,29 +140,31 @@ def build_params_proto(params): class ForestHParams(object): """A base class for holding hyperparameters and calculating good defaults.""" - def __init__(self, - num_trees=100, - max_nodes=10000, - bagging_fraction=1.0, - num_splits_to_consider=0, - feature_bagging_fraction=1.0, - max_fertile_nodes=0, # deprecated, unused. - split_after_samples=250, - valid_leaf_threshold=1, - dominate_method='bootstrap', - dominate_fraction=0.99, - model_name='all_dense', - split_finish_name='basic', - split_pruning_name='none', - prune_every_samples=0, - early_finish_check_every_samples=0, - collate_examples=False, - checkpoint_stats=False, - use_running_stats_method=False, - initialize_average_splits=False, - param_file=None, - split_name='less_or_equal', - **kwargs): + def __init__( + self, + num_trees=100, + max_nodes=10000, + bagging_fraction=1.0, + num_splits_to_consider=0, + feature_bagging_fraction=1.0, + max_fertile_nodes=0, # deprecated, unused. + split_after_samples=250, + valid_leaf_threshold=1, + dominate_method='bootstrap', + dominate_fraction=0.99, + model_name='all_dense', + split_finish_name='basic', + split_pruning_name='none', + prune_every_samples=0, + early_finish_check_every_samples=0, + collate_examples=False, + checkpoint_stats=False, + use_running_stats_method=False, + initialize_average_splits=False, + inference_tree_paths=False, + param_file=None, + split_name='less_or_equal', + **kwargs): self.num_trees = num_trees self.max_nodes = max_nodes self.bagging_fraction = bagging_fraction @@ -179,6 +182,7 @@ class ForestHParams(object): self.checkpoint_stats = checkpoint_stats self.use_running_stats_method = use_running_stats_method self.initialize_average_splits = initialize_average_splits + self.inference_tree_paths = inference_tree_paths self.param_file = param_file self.split_name = split_name self.early_finish_check_every_samples = early_finish_check_every_samples @@ -470,7 +474,7 @@ class RandomForestGraphs(object): **inference_args: Keyword arguments to pass through to each tree. Returns: - The last op in the random forest inference graph. + A tuple of (probabilities, tree_paths). Raises: NotImplementedError: If trying to use feature bagging with sparse @@ -480,6 +484,7 @@ class RandomForestGraphs(object): data_ops.ParseDataTensorOrDict(input_data)) probabilities = [] + paths = [] for i in range(self.params.num_trees): with ops.device(self.variables.device_dummies[i].device): tree_data = processed_dense_features @@ -488,16 +493,20 @@ class RandomForestGraphs(object): raise NotImplementedError( 'Feature bagging not supported with sparse features.') tree_data = self._bag_features(i, tree_data) - probabilities.append(self.trees[i].inference_graph( + probs, path = self.trees[i].inference_graph( tree_data, data_spec, sparse_features=processed_sparse_features, - **inference_args)) + **inference_args) + probabilities.append(probs) + paths.append(path) with ops.device(self.variables.device_dummies[0].device): all_predict = array_ops.stack(probabilities) return math_ops.div( - math_ops.reduce_sum(all_predict, 0), self.params.num_trees, - name='probabilities') + math_ops.reduce_sum(all_predict, 0), + self.params.num_trees, + name='probabilities'), array_ops.stack( + paths, axis=1) def average_size(self): """Constructs a TF graph for evaluating the average size of a forest. @@ -635,7 +644,7 @@ class RandomTreeGraphs(object): sparse_features: A tf.SparseTensor for sparse input data. Returns: - The last op in the random tree inference graph. + A tuple of (probabilities, tree_paths). """ sparse_indices = [] sparse_values = [] diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index ddbe30426d..025ad9132d 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -105,8 +105,9 @@ class TensorForestTest(test_util.TensorFlowTestCase): split_after_samples=25).fill() graph_builder = tensor_forest.RandomForestGraphs(params) - graph = graph_builder.inference_graph(input_data) - self.assertTrue(isinstance(graph, ops.Tensor)) + probs, paths = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(probs, ops.Tensor)) + self.assertTrue(isinstance(paths, ops.Tensor)) def testTrainingConstructionClassificationSparse(self): input_data = sparse_tensor.SparseTensor( @@ -146,8 +147,9 @@ class TensorForestTest(test_util.TensorFlowTestCase): split_after_samples=25).fill() graph_builder = tensor_forest.RandomForestGraphs(params) - graph = graph_builder.inference_graph(input_data) - self.assertTrue(isinstance(graph, ops.Tensor)) + probs, paths = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(probs, ops.Tensor)) + self.assertTrue(isinstance(paths, ops.Tensor)) if __name__ == "__main__": |