aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-28 07:21:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-28 07:25:44 -0700
commit94934d94407962f992dfbc22007bbaaadbaf63c2 (patch)
treed88ccd2f8685108cb21f922249d1f8025cb71b79 /tensorflow
parent42fb4382b9ef7b7b64e97ed85c51a1dfdb4071ec (diff)
Optionally output a new TreePath proto during TensorForest inference for ultimate interpretability.
PiperOrigin-RevId: 163466324
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py6
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest_test.py3
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops.cc34
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc5
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h2
-rw-r--r--tensorflow/contrib/tensor_forest/ops/model_ops.cc3
-rw-r--r--tensorflow/contrib/tensor_forest/proto/fertile_stats.proto7
-rw-r--r--tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto1
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py67
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py10
11 files changed, 98 insertions, 46 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 21fdff23b0..f7f92a15a8 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -43,6 +43,7 @@ from tensorflow.python.training import session_run_hook
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
+TREE_PATHS_PREDICTION_KEY = 'tree_paths'
EPSILON = 0.000001
@@ -194,7 +195,7 @@ def get_model_fn(params,
graph_builder = graph_builder_class(params,
device_assigner=dev_assn)
- logits = graph_builder.inference_graph(features)
+ logits, tree_paths = graph_builder.inference_graph(features)
summary.scalar('average_tree_size', graph_builder.average_size())
# For binary classification problems, convert probabilities to logits.
@@ -261,6 +262,9 @@ def get_model_fn(params,
if keys is not None:
model_ops.predictions[keys_name] = keys
+ if params.inference_tree_paths:
+ model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+
return model_ops
return _model_fn
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
index 52e41a6fe8..ac42364d25 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
@@ -35,7 +35,8 @@ class TensorForestTrainerTests(test.TestCase):
max_nodes=1000,
num_classes=3,
num_features=4,
- split_after_samples=20)
+ split_after_samples=20,
+ inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(hparams.fill())
iris = base.load_iris()
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 221f8d969b..3d9de006b4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -147,9 +147,12 @@ class TreeSizeOp : public OpKernel {
void TraverseTree(const DecisionTreeResource* tree_resource,
const std::unique_ptr<TensorDataSet>& data, int32 start,
int32 end,
- const std::function<void(int32, int32)>& set_leaf_id) {
+ const std::function<void(int32, int32)>& set_leaf_id,
+ std::vector<TreePath>* tree_paths) {
for (int i = start; i < end; ++i) {
- const int32 id = tree_resource->TraverseTree(data, i, nullptr);
+ const int32 id = tree_resource->TraverseTree(
+ data, i, nullptr,
+ (tree_paths == nullptr) ? nullptr : &(*tree_paths)[i]);
set_leaf_id(i, id);
}
}
@@ -199,21 +202,40 @@ class TreePredictionsV4Op : public OpKernel {
&output_predictions));
TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();
+ std::vector<TreePath> tree_paths(
+ param_proto_.inference_tree_paths() ? num_data : 0);
+
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
- auto traverse = [this, &out, decision_tree_resource, num_data](int64 start,
- int64 end) {
+ auto traverse = [this, &out, decision_tree_resource, num_data, &tree_paths](
+ int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
static_cast<int32>(end),
std::bind(&TreePredictionsV4Op::set_output_value, this,
std::placeholders::_1, std::placeholders::_2,
- decision_tree_resource, &out));
+ decision_tree_resource, &out),
+ param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
traverse);
+
+ Tensor* output_tree_paths = nullptr;
+ TensorShape output_paths_shape;
+ output_paths_shape.AddDim(param_proto_.inference_tree_paths() ? num_data
+ : 0);
+ OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape,
+ &output_tree_paths));
+ auto out_paths = output_tree_paths->unaligned_flat<string>();
+
+ // TODO(gilberth): If this slows down inference too much, consider having
+ // a filter that only serializes paths for the predicted label that we're
+ // interested in.
+ for (int i = 0; i < tree_paths.size(); ++i) {
+ out_paths(i) = tree_paths[i].SerializeAsString();
+ }
}
void set_output_value(int32 i, int32 id,
@@ -293,7 +315,7 @@ class TraverseTreeV4Op : public OpKernel {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
- static_cast<int32>(end), set_leaf_ids);
+ static_cast<int32>(end), set_leaf_ids, nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
traverse);
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
index 0fdab8e6e0..1ac1d27761 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
@@ -55,13 +55,13 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) {
.Finalize(&op.node_def));
// num_points = 2, sparse shape not known
- INFER_OK(op, "?;[2,3];?;?;?", "[d1_0,?]");
+ INFER_OK(op, "?;[2,3];?;?;?", "[d1_0,?];[?]");
// num_points = 2, sparse and dense shape rank known and > 1
- INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0,?]");
+ INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0,?];[?]");
// num_points = 2, sparse shape rank known and > 1
- INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]");
+ INFER_OK(op, "?;?;?;?;[10,11]", "[?,?];[?]");
}
TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
index 881e4339a7..952b34b353 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
@@ -23,12 +23,15 @@ using decision_trees::TreeNode;
int32 DecisionTreeResource::TraverseTree(
const std::unique_ptr<TensorDataSet>& input_data, int example,
- int32* leaf_depth) const {
+ int32* leaf_depth, TreePath* path) const {
const DecisionTree& tree = decision_tree_->decision_tree();
int32 current_id = 0;
int32 depth = 0;
while (true) {
const TreeNode& current = tree.nodes(current_id);
+ if (path != nullptr) {
+ *path->add_nodes_visited() = current;
+ }
if (current.has_leaf()) {
if (leaf_depth != nullptr) {
*leaf_depth = depth;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
index 438d3d817c..87c5290604 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
@@ -71,7 +71,7 @@ class DecisionTreeResource : public ResourceBase {
// Return the TreeNode for the leaf that the example ends up at according
// to decsion_tree_. Also fill in that leaf's depth if it isn't nullptr.
int32 TraverseTree(const std::unique_ptr<TensorDataSet>& input_data,
- int example, int32* depth) const;
+ int example, int32* depth, TreePath* path) const;
// Split the given node_id, turning it from a Leaf to a BinaryNode and
// setting it's split to the given best. Add new children ids to
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
index 3dca6913f6..3099cccdf8 100644
--- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
@@ -90,6 +90,7 @@ REGISTER_OP("TreePredictionsV4")
.Input("sparse_input_values: float")
.Input("sparse_input_shape: int64")
.Output("predictions: float")
+ .Output("tree_paths: string")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle num_points = c->UnknownDim();
@@ -99,6 +100,7 @@ REGISTER_OP("TreePredictionsV4")
}
c->set_output(0, c->Matrix(num_points, c->UnknownDim()));
+ c->set_output(1, c->Vector(c->UnknownDim()));
return Status::OK();
})
.Doc(R"doc(
@@ -112,6 +114,7 @@ sparse_input_indices: The indices tensor from the SparseTensor input.
sparse_input_values: The values tensor from the SparseTensor input.
sparse_input_shape: The shape tensor from the SparseTensor input.
predictions: `predictions[i][j]` is the probability that input i is class j.
+tree_paths: `tree_paths[i]` is a serialized TreePath proto for example i.
)doc");
REGISTER_OP("TraverseTreeV4")
diff --git a/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto
index 0ded04ad75..d568fa3081 100644
--- a/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto
+++ b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto
@@ -90,3 +90,10 @@ message SplitCandidate {
// Fields used when training with a graph runner.
string unique_id = 6;
}
+
+// Proto used for tracking tree paths during inference time.
+message TreePath {
+ // Nodes are listed in order that they were traversed. i.e. nodes_visited[0]
+ // is the tree's root node.
+ repeated decision_trees.TreeNode nodes_visited = 1;
+}
diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
index 58c5b9bbe7..29d115ab69 100644
--- a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
+++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
@@ -131,6 +131,7 @@ message TensorForestParams {
bool checkpoint_stats = 11;
bool use_running_stats_method = 20;
bool initialize_average_splits = 22;
+ bool inference_tree_paths = 23;
// Number of classes (classification) or targets (regression)
int32 num_outputs = 12;
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index b1a8357048..d606ea5770 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -105,6 +105,7 @@ def build_params_proto(params):
proto.checkpoint_stats = params.checkpoint_stats
proto.use_running_stats_method = params.use_running_stats_method
proto.initialize_average_splits = params.initialize_average_splits
+ proto.inference_tree_paths = params.inference_tree_paths
parse_number_or_string_to_proto(proto.pruning_type.prune_every_samples,
params.prune_every_samples)
@@ -139,29 +140,31 @@ def build_params_proto(params):
class ForestHParams(object):
"""A base class for holding hyperparameters and calculating good defaults."""
- def __init__(self,
- num_trees=100,
- max_nodes=10000,
- bagging_fraction=1.0,
- num_splits_to_consider=0,
- feature_bagging_fraction=1.0,
- max_fertile_nodes=0, # deprecated, unused.
- split_after_samples=250,
- valid_leaf_threshold=1,
- dominate_method='bootstrap',
- dominate_fraction=0.99,
- model_name='all_dense',
- split_finish_name='basic',
- split_pruning_name='none',
- prune_every_samples=0,
- early_finish_check_every_samples=0,
- collate_examples=False,
- checkpoint_stats=False,
- use_running_stats_method=False,
- initialize_average_splits=False,
- param_file=None,
- split_name='less_or_equal',
- **kwargs):
+ def __init__(
+ self,
+ num_trees=100,
+ max_nodes=10000,
+ bagging_fraction=1.0,
+ num_splits_to_consider=0,
+ feature_bagging_fraction=1.0,
+ max_fertile_nodes=0, # deprecated, unused.
+ split_after_samples=250,
+ valid_leaf_threshold=1,
+ dominate_method='bootstrap',
+ dominate_fraction=0.99,
+ model_name='all_dense',
+ split_finish_name='basic',
+ split_pruning_name='none',
+ prune_every_samples=0,
+ early_finish_check_every_samples=0,
+ collate_examples=False,
+ checkpoint_stats=False,
+ use_running_stats_method=False,
+ initialize_average_splits=False,
+ inference_tree_paths=False,
+ param_file=None,
+ split_name='less_or_equal',
+ **kwargs):
self.num_trees = num_trees
self.max_nodes = max_nodes
self.bagging_fraction = bagging_fraction
@@ -179,6 +182,7 @@ class ForestHParams(object):
self.checkpoint_stats = checkpoint_stats
self.use_running_stats_method = use_running_stats_method
self.initialize_average_splits = initialize_average_splits
+ self.inference_tree_paths = inference_tree_paths
self.param_file = param_file
self.split_name = split_name
self.early_finish_check_every_samples = early_finish_check_every_samples
@@ -470,7 +474,7 @@ class RandomForestGraphs(object):
**inference_args: Keyword arguments to pass through to each tree.
Returns:
- The last op in the random forest inference graph.
+ A tuple of (probabilities, tree_paths).
Raises:
NotImplementedError: If trying to use feature bagging with sparse
@@ -480,6 +484,7 @@ class RandomForestGraphs(object):
data_ops.ParseDataTensorOrDict(input_data))
probabilities = []
+ paths = []
for i in range(self.params.num_trees):
with ops.device(self.variables.device_dummies[i].device):
tree_data = processed_dense_features
@@ -488,16 +493,20 @@ class RandomForestGraphs(object):
raise NotImplementedError(
'Feature bagging not supported with sparse features.')
tree_data = self._bag_features(i, tree_data)
- probabilities.append(self.trees[i].inference_graph(
+ probs, path = self.trees[i].inference_graph(
tree_data,
data_spec,
sparse_features=processed_sparse_features,
- **inference_args))
+ **inference_args)
+ probabilities.append(probs)
+ paths.append(path)
with ops.device(self.variables.device_dummies[0].device):
all_predict = array_ops.stack(probabilities)
return math_ops.div(
- math_ops.reduce_sum(all_predict, 0), self.params.num_trees,
- name='probabilities')
+ math_ops.reduce_sum(all_predict, 0),
+ self.params.num_trees,
+ name='probabilities'), array_ops.stack(
+ paths, axis=1)
def average_size(self):
"""Constructs a TF graph for evaluating the average size of a forest.
@@ -635,7 +644,7 @@ class RandomTreeGraphs(object):
sparse_features: A tf.SparseTensor for sparse input data.
Returns:
- The last op in the random tree inference graph.
+ A tuple of (probabilities, tree_paths).
"""
sparse_indices = []
sparse_values = []
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index ddbe30426d..025ad9132d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -105,8 +105,9 @@ class TensorForestTest(test_util.TensorFlowTestCase):
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
- graph = graph_builder.inference_graph(input_data)
- self.assertTrue(isinstance(graph, ops.Tensor))
+ probs, paths = graph_builder.inference_graph(input_data)
+ self.assertTrue(isinstance(probs, ops.Tensor))
+ self.assertTrue(isinstance(paths, ops.Tensor))
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
@@ -146,8 +147,9 @@ class TensorForestTest(test_util.TensorFlowTestCase):
split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
- graph = graph_builder.inference_graph(input_data)
- self.assertTrue(isinstance(graph, ops.Tensor))
+ probs, paths = graph_builder.inference_graph(input_data)
+ self.assertTrue(isinstance(probs, ops.Tensor))
+ self.assertTrue(isinstance(paths, ops.Tensor))
if __name__ == "__main__":