Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees.py')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 118
1 file changed, 63 insertions, 55 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 8b423f76de..16928ca4b7 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -703,9 +703,30 @@ def _bt_model_fn(
   global_step = training_util.get_or_create_global_step()
   bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
       sorted_feature_columns)
+  # Create Ensemble resources.
+  tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+  # Create logits.
+  if mode != model_fn.ModeKeys.TRAIN:
+    input_feature_list = _get_transformed_features(features,
+                                                   sorted_feature_columns)
+    logits = boosted_trees_ops.predict(
+        # For non-TRAIN mode, ensemble doesn't change after initialization,
+        # so no local copy is needed; using tree_ensemble directly.
+        tree_ensemble_handle=tree_ensemble.resource_handle,
+        bucketized_features=input_feature_list,
+        logits_dimension=head.logits_dimension)
+    return head.create_estimator_spec(
+        features=features,
+        mode=mode,
+        labels=labels,
+        train_op_fn=control_flow_ops.no_op,
+        logits=logits)
+
+  # ============== Training graph ==============
   # Extract input features and set up cache for training.
   training_state_cache = None
-  if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+  if train_in_memory:
     # cache transformed features as well for in-memory training.
     batch_size = array_ops.shape(labels)[0]
     input_feature_list, input_cache_op = (
@@ -717,63 +738,51 @@ def _bt_model_fn(
   else:
     input_feature_list = _get_transformed_features(features,
                                                    sorted_feature_columns)
-    if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
+    if example_id_column_name:
       example_ids = features[example_id_column_name]
       training_state_cache = _CacheTrainingStatesUsingHashTable(
           example_ids, head.logits_dimension)

-  # Create Ensemble resources.
-  tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
   # Variable that determines whether bias centering is needed.
   center_bias_var = variable_scope.variable(
       initial_value=center_bias, name='center_bias_needed', trainable=False)
-  # Create logits.
-  if mode != model_fn.ModeKeys.TRAIN:
-    logits = boosted_trees_ops.predict(
-        # For non-TRAIN mode, ensemble doesn't change after initialization,
-        # so no local copy is needed; using tree_ensemble directly.
-        tree_ensemble_handle=tree_ensemble.resource_handle,
+  if is_single_machine:
+    local_tree_ensemble = tree_ensemble
+    ensemble_reload = control_flow_ops.no_op()
+  else:
+    # Have a local copy of ensemble for the distributed setting.
+    with ops.device(worker_device):
+      local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          name=name + '_local', is_local=True)
+      # TODO(soroush): Do partial updates if this becomes a bottleneck.
+      ensemble_reload = local_tree_ensemble.deserialize(
+          *tree_ensemble.serialize())
+
+  if training_state_cache:
+    cached_tree_ids, cached_node_ids, cached_logits = (
+        training_state_cache.lookup())
+  else:
+    # Always start from the beginning when no cache is set up.
+    batch_size = array_ops.shape(labels)[0]
+    cached_tree_ids, cached_node_ids, cached_logits = (
+        array_ops.zeros([batch_size], dtype=dtypes.int32),
+        _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
+        array_ops.zeros(
+            [batch_size, head.logits_dimension], dtype=dtypes.float32))
+
+  with ops.control_dependencies([ensemble_reload]):
+    (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
+     last_layer_nodes_range) = local_tree_ensemble.get_states()
+    summary.scalar('ensemble/num_trees', num_trees)
+    summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+    summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+  partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+      tree_ensemble_handle=local_tree_ensemble.resource_handle,
+      cached_tree_ids=cached_tree_ids,
+      cached_node_ids=cached_node_ids,
       bucketized_features=input_feature_list,
       logits_dimension=head.logits_dimension)
-  else:
-    if is_single_machine:
-      local_tree_ensemble = tree_ensemble
-      ensemble_reload = control_flow_ops.no_op()
-    else:
-      # Have a local copy of ensemble for the distributed setting.
-      with ops.device(worker_device):
-        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
-            name=name + '_local', is_local=True)
-        # TODO(soroush): Do partial updates if this becomes a bottleneck.
-        ensemble_reload = local_tree_ensemble.deserialize(
-            *tree_ensemble.serialize())
-
-    if training_state_cache:
-      cached_tree_ids, cached_node_ids, cached_logits = (
-          training_state_cache.lookup())
-    else:
-      # Always start from the beginning when no cache is set up.
-      batch_size = array_ops.shape(labels)[0]
-      cached_tree_ids, cached_node_ids, cached_logits = (
-          array_ops.zeros([batch_size], dtype=dtypes.int32),
-          _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
-          array_ops.zeros(
-              [batch_size, head.logits_dimension], dtype=dtypes.float32))
-
-    with ops.control_dependencies([ensemble_reload]):
-      (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
-       last_layer_nodes_range) = local_tree_ensemble.get_states()
-      summary.scalar('ensemble/num_trees', num_trees)
-      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
-      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
-    partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
-        tree_ensemble_handle=local_tree_ensemble.resource_handle,
-        cached_tree_ids=cached_tree_ids,
-        cached_node_ids=cached_node_ids,
-        bucketized_features=input_feature_list,
-        logits_dimension=head.logits_dimension)
-    logits = cached_logits + partial_logits
+  logits = cached_logits + partial_logits

   # Create training graph.
@@ -846,12 +855,11 @@ def _bt_model_fn(
       labels=labels,
       train_op_fn=_train_op_fn,
       logits=logits)
-  if mode == model_fn.ModeKeys.TRAIN:
-    # Add an early stop hook.
-    estimator_spec = estimator_spec._replace(
-        training_hooks=estimator_spec.training_hooks +
-        (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
-                             tree_hparams.n_trees, tree_hparams.max_depth),))
+  # Add an early stop hook.
+  estimator_spec = estimator_spec._replace(
+      training_hooks=estimator_spec.training_hooks +
+      (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
+                           tree_hparams.n_trees, tree_hparams.max_depth),))
   return estimator_spec
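The main change in this commit is structural: non-TRAIN modes are handled up front with a plain boosted_trees_ops.predict and an early return, so everything after the guard is unconditionally the training graph and the repeated mode == model_fn.ModeKeys.TRAIN checks disappear. A minimal, framework-free sketch of that early-return pattern, assuming hypothetical stand-in names (Mode, ensemble.predict, ensemble.grow are not TensorFlow APIs):

from enum import Enum


class Mode(Enum):
    TRAIN = 'train'
    EVAL = 'eval'
    PREDICT = 'predict'


def model_fn(features, labels, mode, ensemble):
    # Inference path first: the ensemble is fixed after initialization,
    # so EVAL/PREDICT can read it directly and return early.
    if mode != Mode.TRAIN:
        return {'logits': ensemble.predict(features)}

    # From here on this is unconditionally the training graph; no further
    # mode checks are needed, which is what the diff achieves.
    logits = ensemble.predict(features)
    train_op = ensemble.grow(features, labels, logits)
    return {'logits': logits, 'train_op': train_op}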
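For distributed training, each worker keeps a local replica of the shared ensemble and refreshes it via deserialize(*serialize()) before reading, with ops.control_dependencies([ensemble_reload]) guaranteeing the reload happens before the ensemble states are read. A hedged sketch of that replica pattern using plain Python objects, where Ensemble and refresh_local_copy are illustrative stand-ins for the real TreeEnsemble resource:

import copy


class Ensemble:
    """Toy stand-in for the TreeEnsemble resource (illustrative only)."""

    def __init__(self, trees=None):
        self.trees = list(trees) if trees else []

    def serialize(self):
        # Snapshot of the full ensemble state (the real op returns a stamp
        # token plus a serialized proto).
        return copy.deepcopy(self.trees)

    def deserialize(self, trees):
        self.trees = trees


def refresh_local_copy(shared, local, is_single_machine):
    if is_single_machine:
        # Single process: read the shared ensemble directly; this mirrors
        # ensemble_reload = control_flow_ops.no_op() in the diff.
        return shared
    # Distributed: pull a consistent snapshot of the shared ensemble into
    # the worker-local replica before it is used for training predictions.
    local.deserialize(shared.serialize())
    return local


shared = Ensemble(trees=['tree0', 'tree1'])
ensemble = refresh_local_copy(shared, Ensemble(), is_single_machine=False)
assert ensemble.trees == shared.trees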
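When no training-state cache is configured, every example starts from scratch: tree id 0, a dummy node id, and zero logits, shaped [batch_size] and [batch_size, logits_dimension]. A NumPy sketch of that fallback; note that _DUMMY_NODE_ID = -1 is an assumption here, as the real constant is defined elsewhere in the module:

import numpy as np

_DUMMY_NODE_ID = -1  # assumed sentinel; the real constant lives in the module


def initial_training_state(batch_size, logits_dimension):
    """Fallback state when no cache exists: start every example from scratch."""
    cached_tree_ids = np.zeros([batch_size], dtype=np.int32)
    cached_node_ids = _DUMMY_NODE_ID * np.ones([batch_size], dtype=np.int32)
    cached_logits = np.zeros([batch_size, logits_dimension], dtype=np.float32)
    return cached_tree_ids, cached_node_ids, cached_logits


tree_ids, node_ids, logits = initial_training_state(batch_size=4,
                                                    logits_dimension=1)
assert logits.shape == (4, 1)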
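Finally, the early-stop hook is now attached unconditionally. EstimatorSpec is a namedtuple, which is why _replace can swap in an extended training_hooks tuple; a minimal reproduction of that idiom, where Spec and stop_hook are stand-ins:

from collections import namedtuple

Spec = namedtuple('Spec', ['mode', 'training_hooks'])

spec = Spec(mode='train', training_hooks=())
stop_hook = object()  # stand-in for _StopAtAttemptsHook(...)
spec = spec._replace(training_hooks=spec.training_hooks + (stop_hook,))
assert spec.training_hooks == (stop_hook,)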