Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees.py')
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py  118
1 file changed, 63 insertions(+), 55 deletions(-)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 8b423f76de..16928ca4b7 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -703,9 +703,30 @@ def _bt_model_fn(
global_step = training_util.get_or_create_global_step()
bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
sorted_feature_columns)
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ # Create logits.
+ if mode != model_fn.ModeKeys.TRAIN:
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ # ============== Training graph ==============
# Extract input features and set up cache for training.
training_state_cache = None
- if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+ if train_in_memory:
# cache transformed features as well for in-memory training.
batch_size = array_ops.shape(labels)[0]
input_feature_list, input_cache_op = (
@@ -717,63 +738,51 @@ def _bt_model_fn(
else:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
- if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
+ if example_id_column_name:
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
- # Create Ensemble resources.
- tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Variable that determines whether bias centering is needed.
center_bias_var = variable_scope.variable(
initial_value=center_bias, name='center_bias_needed', trainable=False)
- # Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
- logits = boosted_trees_ops.predict(
- # For non-TRAIN mode, ensemble doesn't change after initialization,
- # so no local copy is needed; using tree_ensemble directly.
- tree_ensemble_handle=tree_ensemble.resource_handle,
+ if is_single_machine:
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ # Have a local copy of ensemble for the distributed setting.
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
+
+ if training_state_cache:
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ training_state_cache.lookup())
+ else:
+ # Always start from the beginning when no cache is set up.
+ batch_size = array_ops.shape(labels)[0]
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
+ array_ops.zeros(
+ [batch_size, head.logits_dimension], dtype=dtypes.float32))
+
+ with ops.control_dependencies([ensemble_reload]):
+ (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
+ last_layer_nodes_range) = local_tree_ensemble.get_states()
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+ partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+ tree_ensemble_handle=local_tree_ensemble.resource_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
- else:
- if is_single_machine:
- local_tree_ensemble = tree_ensemble
- ensemble_reload = control_flow_ops.no_op()
- else:
- # Have a local copy of ensemble for the distributed setting.
- with ops.device(worker_device):
- local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
- name=name + '_local', is_local=True)
- # TODO(soroush): Do partial updates if this becomes a bottleneck.
- ensemble_reload = local_tree_ensemble.deserialize(
- *tree_ensemble.serialize())
-
- if training_state_cache:
- cached_tree_ids, cached_node_ids, cached_logits = (
- training_state_cache.lookup())
- else:
- # Always start from the beginning when no cache is set up.
- batch_size = array_ops.shape(labels)[0]
- cached_tree_ids, cached_node_ids, cached_logits = (
- array_ops.zeros([batch_size], dtype=dtypes.int32),
- _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
- array_ops.zeros(
- [batch_size, head.logits_dimension], dtype=dtypes.float32))
-
- with ops.control_dependencies([ensemble_reload]):
- (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
- last_layer_nodes_range) = local_tree_ensemble.get_states()
- summary.scalar('ensemble/num_trees', num_trees)
- summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
- summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
- partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
- tree_ensemble_handle=local_tree_ensemble.resource_handle,
- cached_tree_ids=cached_tree_ids,
- cached_node_ids=cached_node_ids,
- bucketized_features=input_feature_list,
- logits_dimension=head.logits_dimension)
-
logits = cached_logits + partial_logits
# Create training graph.
@@ -846,12 +855,11 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
- if mode == model_fn.ModeKeys.TRAIN:
- # Add an early stop hook.
- estimator_spec = estimator_spec._replace(
- training_hooks=estimator_spec.training_hooks +
- (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
- tree_hparams.n_trees, tree_hparams.max_depth),))
+ # Add an early stop hook.
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks +
+ (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
+ tree_hparams.n_trees, tree_hparams.max_depth),))
return estimator_spec
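
Note for reviewers: the net effect of the hunks above is a control-flow refactor rather than a behavior change. The non-TRAIN branch now builds its prediction logits and returns an EstimatorSpec early, so everything after it is unconditionally the training graph, and the repeated `mode == model_fn.ModeKeys.TRAIN` guards (around the cache setup and the early-stop hook) can be dropped. A minimal sketch of the new shape, where every helper is a hypothetical stand-in for the TensorFlow internals touched by this patch, not a real API:

    # Sketch only: make_ensemble, build_prediction_spec, build_training_graph
    # and add_early_stop_hook are hypothetical stand-ins.
    TRAIN, EVAL, PREDICT = 'train', 'eval', 'predict'  # stand-in ModeKeys

    def make_ensemble():
        return {'trees': []}

    def build_prediction_spec(ensemble):
        # Non-TRAIN: the ensemble never changes after initialization, so
        # predict directly from the shared resource (no local copy, no cache).
        return {'mode': PREDICT, 'ensemble': ensemble}

    def build_training_graph(ensemble):
        return {'mode': TRAIN, 'ensemble': ensemble, 'hooks': ()}

    def add_early_stop_hook(spec):
        # Unconditional in the new code: only the training graph reaches here.
        spec['hooks'] += ('stop_at_attempts',)
        return spec

    def model_fn_after(mode):
        ensemble = make_ensemble()
        if mode != TRAIN:
            return build_prediction_spec(ensemble)  # early return, as in the patch
        # ============== Training graph ==============
        # Everything below runs only in TRAIN mode, so the old
        # `if mode == ModeKeys.TRAIN` guards are gone.
        spec = build_training_graph(ensemble)
        return add_early_stop_hook(spec)

The early return is also why the TreeEnsemble construction and the prediction path move to the top of the function, appearing as paired deletions and insertions in the hunks above.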
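The training-state cache hunk is worth a gloss as well: when no cache is configured, every example starts from tree 0, at the dummy node id, with zero cached logits, so `training_predict` recomputes the full ensemble contribution; with a cache, only the layers grown since the last lookup are recomputed, and the final logits are always `cached_logits + partial_logits`. A small NumPy sketch of that contract, assuming `_DUMMY_NODE_ID` stands in for the sentinel constant defined elsewhere in boosted_trees.py:

    import numpy as np

    _DUMMY_NODE_ID = -1  # assumption: stand-in for the module's sentinel constant

    def default_training_state(batch_size, logits_dimension):
        # Mirrors the no-cache branch above: start every example at tree 0,
        # at the dummy node, with zero cached logits, so training_predict
        # must recompute the full ensemble contribution.
        cached_tree_ids = np.zeros([batch_size], dtype=np.int32)
        cached_node_ids = _DUMMY_NODE_ID * np.ones([batch_size], dtype=np.int32)
        cached_logits = np.zeros(
            [batch_size, logits_dimension], dtype=np.float32)
        return cached_tree_ids, cached_node_ids, cached_logits

    tree_ids, node_ids, logits = default_training_state(
        batch_size=4, logits_dimension=1)
    # Final logits combine the cached part with the freshly computed part,
    # exactly as in `logits = cached_logits + partial_logits` above.
    partial_logits = np.ones_like(logits)  # pretend output of training_predict
    final_logits = logits + partial_logits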