diff options
author | 2018-09-06 10:02:24 -0700 | |
---|---|---|
committer | 2018-09-06 10:10:16 -0700 | |
commit | 43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 (patch) | |
tree | 2a2710f3e09f5c9c037965ccfade1614bdcab40e /tensorflow/python/estimator | |
parent | d17016a8dfd9b9bd92a55fc1fddee4fd1c29bdbe (diff) |
Update docstring for BoostedTrees n_batches_per_layer.
PiperOrigin-RevId: 211824645
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index d104c961d3..19f18015e4 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE classifier = estimator.BoostedTreesClassifier( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... <some other params> ) @@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. @@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE regressor = estimator.BoostedTreesRegressor( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... <some other params> ) @@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. |