aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-08 16:46:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 17:01:31 -0800
commit265e61d0d6ad5a003d3ed13ab15fe7a8155a5e45 (patch)
tree3188a359086fdf70a3e9c9e338b6ff1329f9197c
parentfb60813c2f14249f316c3e535dcb994e75a6c73c (diff)
Add documentation for TensorForestEstimator.
Change: 138582357
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/random_forest.py57
1 files changed, 56 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
index c6d6a666eb..c2c41255c9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -146,12 +146,67 @@ def get_model_fn(params, graph_builder_class, device_assigner,
class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
- """An estimator that can train and evaluate a random forest."""
+ """An estimator that can train and evaluate a random forest.
+
+ Example:
+
+ ```python
+ params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
+
+ # Estimator using the default graph builder.
+ estimator = TensorForestEstimator(params, model_dir=model_dir)
+
+ # Or estimator using TrainingLossForest as the graph builder.
+ estimator = TensorForestEstimator(
+ params, graph_builder_class=tensor_forest.TrainingLossForest,
+ model_dir=model_dir)
+
+ # Input builders
+ def input_fn_train: # returns x, y
+ ...
+ def input_fn_eval: # returns x, y
+ ...
+ estimator.fit(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+ estimator.predict(x=x)
+ ```
+ """
def __init__(self, params, device_assigner=None, model_dir=None,
graph_builder_class=tensor_forest.RandomForestGraphs,
config=None, weights_name=None, keys_name=None,
feature_engineering_fn=None, early_stopping_rounds=100):
+
+ """Initializes a TensorForestEstimator instance.
+
+ Args:
+ params: ForestHParams object that holds random forest hyperparameters.
+ These parameters will be passed into `model_fn`.
+ device_assigner: An `object` instance that controls how trees get
+ assigned to devices. If `None`, will use
+ `tensor_forest.RandomForestDeviceAssigner`.
+ model_dir: Directory to save model parameters, graph, etc. To continue
+ training a previously saved model, load checkpoints saved to this
+ directory into an estimator.
+ graph_builder_class: An `object` instance that defines how TF graphs for
+ random forest training and inference are built. By default will use
+ `tensor_forest.RandomForestGraphs`.
+ config: `RunConfig` object to configure the runtime settings.
+ weights_name: A string defining feature column name representing
+ weights. Will be multiplied by the loss of the example. Used to
+ downweight or boost examples during training.
+ keys_name: A string defining feature column name representing example
+ keys. Used by `predict_with_keys` method.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ early_stopping_rounds: Allows training to terminate early if the forest is
+ no longer growing. 100 by default.
+
+ Returns:
+ A `TensorForestEstimator` instance.
+ """
self.params = params.fill()
self.graph_builder_class = graph_builder_class
self.early_stopping_rounds = early_stopping_rounds