aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py48
1 files changed, 38 insertions, 10 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index dbfa69edcb..194a5c8754 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -86,7 +86,8 @@ def _dnn_tree_combined_model_fn(
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
- output_type=model.ModelBuilderOutputType.MODEL_FN_OPS):
+ output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
+ override_global_step_value=None):
"""DNN and GBDT combined model_fn.
Args:
@@ -135,6 +136,12 @@ def _dnn_tree_combined_model_fn(
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
Returns:
A `ModelFnOps` object.
@@ -350,7 +357,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
tree_train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
])
return model_fn_ops
@@ -378,7 +386,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
tree_spec.train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
]
fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
list(fusion_spec.training_hooks))
@@ -411,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
Args:
@@ -467,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
@@ -497,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedClassifier, self).__init__(
model_fn=_model_fn,
@@ -531,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
Args:
@@ -587,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -622,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
model_fn=_model_fn,
@@ -657,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
Args:
@@ -708,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
def _model_fn(features, labels, mode, config):
@@ -732,7 +758,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn,
@@ -832,7 +859,8 @@ class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
- use_core_versions=True)
+ use_core_versions=True,
+ override_global_step_value=None)
super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)