aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Zhenyu Tan <tanzheny@google.com>2018-09-06 10:02:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 10:10:16 -0700
commit43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 (patch)
tree2a2710f3e09f5c9c037965ccfade1614bdcab40e /tensorflow/python/estimator
parentd17016a8dfd9b9bd92a55fc1fddee4fd1c29bdbe (diff)
Update docstring for BoostedTrees n_batches_per_layer.
PiperOrigin-RevId: 211824645
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index d104c961d3..19f18015e4 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+ # Need to see a large portion of the data before we can build a layer, for
+ # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
classifier = estimator.BoostedTreesClassifier(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+ layer. The total number of batches is total number of data divided by
+ batch size.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
@@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+ # Need to see a large portion of the data before we can build a layer, for
+ # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
regressor = estimator.BoostedTreesRegressor(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+ layer. The total number of batches is total number of data divided by
+ batch size.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.