Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/model.py')
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/model.py  94
1 file changed, 70 insertions(+), 24 deletions(-)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 2fbe72951a..04b46c3483 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -58,7 +58,13 @@ def model_builder(features,
* weight_column_name: The name of the weight column.
* center_bias: Whether a separate tree should be created for first fitting
the bias.
+ * override_global_step_value: If not None, the global step is reset to this
+ value once training is done. This is particularly useful for hyperparameter
+ tuning, which otherwise cannot recognize that training stopped early because
+ the requested number of trees was reached. If None, the global step is not
+ overridden.
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
Returns:
A `ModelFnOps` object.
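
The new key is read with params.get, so callers that never set it keep the
old behavior. Below is a minimal sketch, not part of the patch, of wiring the
key through `params`; the elided keys and all values shown are hypothetical.

    params = {
        # ... the other keys this model_fn reads (head, num_trees,
        # center_bias, weight_column_name, logits_modifier_function, ...)
        "use_core_libs": False,
        "output_leaf_index": False,
        # Reset global_step to 0 after training stops early on num_trees,
        # so hyperparameter tuners do not misread the final step count.
        "override_global_step_value": 0,
    }
    model_fn_ops = model_builder(features, labels, mode, params, config)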
@@ -74,6 +80,7 @@ def model_builder(features,
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -126,14 +133,16 @@ def model_builder(features,
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
+
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- training_hooks = [
+ training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
- ]
+ finalized_trees,
+ override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
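
StopAfterNTrees now receives the override value and applies it when training
ends; its implementation is not shown in this patch. The sketch below is a
hypothetical hook illustrating the intended semantics with the stock
tf.train.SessionRunHook API, not the actual trainer_hooks code.

    import tensorflow as tf

    class _OverrideGlobalStepHook(tf.train.SessionRunHook):
      """Hypothetical sketch; the real logic lives in StopAfterNTrees."""

      def __init__(self, override_global_step_value):
        self._override_value = override_global_step_value
        self._assign_op = None

      def begin(self):
        # Build the assign op while the graph is still mutable.
        if self._override_value is not None:
          global_step = tf.train.get_global_step()
          self._assign_op = tf.assign(global_step, self._override_value)

      def end(self, session):
        # Runs once training stops (possibly early, once enough trees are
        # finalized): reset the step so tuners see a stable final value.
        if self._assign_op is not None:
          session.run(self._assign_op)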
@@ -175,7 +184,12 @@ def model_builder(features,
return model_fn_ops
-def ranking_model_builder(features, labels, mode, params, config):
+def ranking_model_builder(features,
+ labels,
+ mode,
+ params,
+ config,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model for ranking.
Args:
@@ -198,7 +212,14 @@ def ranking_model_builder(features, labels, mode, params, config):
for left and right part of the training pairs for ranking. For example,
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
+ * override_global_step_value: If not None, the global step is reset to this
+ value once training is done. This is particularly useful for hyperparameter
+ tuning, which otherwise cannot recognize that training stopped early because
+ the requested number of trees was reached. If None, the global step is not
+ overridden.
config: `RunConfig` of the estimator.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+
Returns:
A `ModelFnOps` object.
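
For reference, a hedged illustration of the pair-key convention described in
the docstring above; the feature names and values are hypothetical stand-ins.

    import tensorflow as tf

    features = {
        "a.f1": tf.constant([[0.3], [0.7]]),  # f1, left item of each pair
        "b.f1": tf.constant([[0.5], [0.1]]),  # f1, right item of each pair
    }
    pair_keys = ("a", "b")  # passed as params["ranking_model_pair_keys"]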
@@ -215,6 +236,7 @@ def ranking_model_builder(features, labels, mode, params, config):
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -326,31 +348,55 @@ def ranking_model_builder(features, labels, mode, params, config):
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
- if use_core_libs and callable(create_estimator_spec_op):
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
- model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
- gbdt_batch.LEAF_INDEX]
+ training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
+
finalized_trees, attempted_trees = (
gbdt_model_main.get_number_of_trees_tensor())
- model_fn_ops.training_hooks.append(
+ training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees,
+ override_global_step_value))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+
+ model_fn_ops.training_hooks.extend(training_hooks)
+ return model_fn_ops
+
+ elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
+ assert callable(create_estimator_spec_op)
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ return estimator_spec
+
return model_fn_ops
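
EstimatorSpec is a namedtuple and therefore immutable, which is why the new
branch uses _replace to build a copy whose training_hooks start with the stop
hook. A toy, runnable sketch of that mechanic (Spec stands in for the real
class):

    import collections

    # Stand-in for tf.estimator.EstimatorSpec; hooks are plain strings only
    # to keep the sketch self-contained.
    Spec = collections.namedtuple("Spec", ["train_op", "training_hooks"])
    spec = Spec(train_op=None, training_hooks=("existing_hook",))

    # _replace returns a new tuple with the named field swapped; the
    # original spec is untouched.
    new_spec = spec._replace(
        training_hooks=["stop_after_n_trees"] + list(spec.training_hooks))
    assert new_spec.training_hooks == ["stop_after_n_trees", "existing_hook"]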