diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/model.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/model.py | 94 |
1 file changed, 70 insertions, 24 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 2fbe72951a..04b46c3483 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -58,7 +58,13 @@ def model_builder(features, * weight_column_name: The name of weight column. * center_bias: Whether a separate tree should be created for first fitting the bias. + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). Returns: A `ModelFnOps` object. @@ -74,6 +80,7 @@ def model_builder(features, use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -126,14 +133,16 @@ def model_builder(features, create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - training_hooks = [ + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees) - ] + finalized_trees, + override_global_step_value)) if output_type == ModelBuilderOutputType.MODEL_FN_OPS: if use_core_libs and callable(create_estimator_spec_op): @@ -175,7 +184,12 @@ def model_builder(features, return model_fn_ops -def ranking_model_builder(features, labels, mode, params, config): +def ranking_model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model for ranking. Args: @@ -198,7 +212,14 @@ def ranking_model_builder(features, labels, mode, params, config): for left and right part of the training pairs for ranking. For example, for an Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + Returns: A `ModelFnOps` object. @@ -215,6 +236,7 @@ def ranking_model_builder(features, labels, mode, params, config): logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] ranking_model_pair_keys = params["ranking_model_pair_keys"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -326,31 +348,55 @@ def ranking_model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] + training_hooks = [] if num_trees: if center_bias: num_trees += 1 + finalized_trees, attempted_trees = ( gbdt_model_main.get_number_of_trees_tensor()) - model_fn_ops.training_hooks.append( + training_hooks.append( trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops