diff options
author | 2016-11-08 16:46:50 -0800 | |
---|---|---|
committer | 2016-11-08 17:01:31 -0800 | |
commit | 265e61d0d6ad5a003d3ed13ab15fe7a8155a5e45 (patch) | |
tree | 3188a359086fdf70a3e9c9e338b6ff1329f9197c | |
parent | fb60813c2f14249f316c3e535dcb994e75a6c73c (diff) |
Add documentation for TensorForestEstimator.
Change: 138582357
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/random_forest.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py index c6d6a666eb..c2c41255c9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py +++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py @@ -146,12 +146,67 @@ def get_model_fn(params, graph_builder_class, device_assigner, class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable): - """An estimator that can train and evaluate a random forest.""" + """An estimator that can train and evaluate a random forest. + + Example: + + ```python + params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + num_classes=2, num_features=40, num_trees=10, max_nodes=1000) + + # Estimator using the default graph builder. + estimator = TensorForestEstimator(params, model_dir=model_dir) + + # Or estimator using TrainingLossForest as the graph builder. + estimator = TensorForestEstimator( + params, graph_builder_class=tensor_forest.TrainingLossForest, + model_dir=model_dir) + + # Input builders + def input_fn_train: # returns x, y + ... + def input_fn_eval: # returns x, y + ... + estimator.fit(input_fn=input_fn_train) + estimator.evaluate(input_fn=input_fn_eval) + estimator.predict(x=x) + ``` + """ def __init__(self, params, device_assigner=None, model_dir=None, graph_builder_class=tensor_forest.RandomForestGraphs, config=None, weights_name=None, keys_name=None, feature_engineering_fn=None, early_stopping_rounds=100): + + """Initializes a TensorForestEstimator instance. + + Args: + params: ForestHParams object that holds random forest hyperparameters. + These parameters will be passed into `model_fn`. + device_assigner: An `object` instance that controls how trees get + assigned to devices. If `None`, will use + `tensor_forest.RandomForestDeviceAssigner`. + model_dir: Directory to save model parameters, graph, etc. To continue + training a previously saved model, load checkpoints saved to this + directory into an estimator. + graph_builder_class: An `object` instance that defines how TF graphs for + random forest training and inference are built. By default will use + `tensor_forest.RandomForestGraphs`. + config: `RunConfig` object to configure the runtime settings. + weights_name: A string defining feature column name representing + weights. Will be multiplied by the loss of the example. Used to + downweight or boost examples during training. + keys_name: A string defining feature column name representing example + keys. Used by `predict_with_keys` method. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + early_stopping_rounds: Allows training to terminate early if the forest is + no longer growing. 100 by default. + + Returns: + A `TensorForestEstimator` instance. + """ self.params = params.fill() self.graph_builder_class = graph_builder_class self.early_stopping_rounds = early_stopping_rounds |