diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index dbfa69edcb..194a5c8754 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -86,7 +86,8 @@ def _dnn_tree_combined_model_fn( tree_center_bias=False, dnn_to_tree_distillation_param=None, use_core_versions=False, - output_type=model.ModelBuilderOutputType.MODEL_FN_OPS): + output_type=model.ModelBuilderOutputType.MODEL_FN_OPS, + override_global_step_value=None): """DNN and GBDT combined model_fn. Args: @@ -135,6 +136,12 @@ def _dnn_tree_combined_model_fn( will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. Returns: A `ModelFnOps` object. @@ -350,7 +357,8 @@ def _dnn_tree_combined_model_fn( trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, tree_train_op), trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees) + finalized_trees, + override_global_step_value) ]) return model_fn_ops @@ -378,7 +386,8 @@ def _dnn_tree_combined_model_fn( trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train, tree_spec.train_op), trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees) + finalized_trees, + override_global_step_value) ] fusion_spec = fusion_spec._replace(training_hooks=training_hooks + list(fusion_spec.training_hooks)) @@ -411,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -467,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -497,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, @@ -531,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -587,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -622,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, @@ -657,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -708,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ def _model_fn(features, labels, mode, config): @@ -732,7 +758,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, @@ -832,7 +859,8 @@ class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator): tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC, - use_core_versions=True) + use_core_versions=True, + override_global_step_value=None) super(CoreDNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) |