diff options
author | 2017-06-22 06:45:39 -0700 | |
---|---|---|
committer | 2017-06-22 06:49:09 -0700 | |
commit | 77f27c95fb565a9feafe8211d3de3b40e86c4402 (patch) | |
tree | 6cf0ef8bd2bc1105794311ad90a685b053cddbf0 | |
parent | d19f8b1e35409ef89463c8b84bef97cf53f43859 (diff) |
Add a multi-head TensorForest estimator.
PiperOrigin-RevId: 159820487
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 171 |
1 files changed, 155 insertions, 16 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index ef2f0337ac..af9f56ab24 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -29,12 +29,13 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook @@ -74,8 +75,9 @@ class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook): class TensorForestLossHook(session_run_hook.SessionRunHook): """Monitor to request stop when loss stops decreasing.""" - def __init__(self, early_stopping_rounds): + def __init__(self, early_stopping_rounds, loss_op=None): self.early_stopping_rounds = early_stopping_rounds + self.loss_op = loss_op self.min_loss = None self.last_step = -1 # self.steps records the number of steps for which the loss has been @@ -83,10 +85,12 @@ class TensorForestLossHook(session_run_hook.SessionRunHook): self.steps = 0 def before_run(self, run_context): + loss = (self.loss_op if self.loss_op is not None else + run_context.session.graph.get_operation_by_name( + LOSS_NAME).outputs[0]) return session_run_hook.SessionRunArgs( {'global_step': contrib_framework.get_global_step(), - 'current_loss': run_context.session.graph.get_operation_by_name( - LOSS_NAME).outputs[0]}) + 'current_loss': loss}) def after_run(self, run_context, run_values): current_loss = 
run_values.results['current_loss'] @@ -135,7 +139,6 @@ def get_model_fn(params, num_trainers=1, trainer_id=0, report_feature_importances=False, - model_dir=None, local_eval=False): """Return a model function given a way to construct a graph builder.""" def _model_fn(features, labels, mode): @@ -198,15 +201,6 @@ def get_model_fn(params, trainer_id=trainer_id), state_ops.assign_add(contrib_framework.get_global_step(), 1)) loss_deps.append(training_graph) - if hasattr(graph_builder, 'finalize_training'): - finalize_listener = EveryCheckpointPreSaveListener( - graph_builder.finalize_training()) - scaffold = monitored_session.Scaffold() - training_hooks.append( - basic_session_run_hooks.CheckpointSaverHook( - model_dir, save_secs=600, save_steps=None, - scaffold=scaffold, - listeners=[finalize_listener])) training_loss = None if (mode == model_fn_lib.ModeKeys.EVAL or @@ -220,7 +214,8 @@ def get_model_fn(params, features[weights_name] = weights if early_stopping_rounds: - training_hooks.append(TensorForestLossHook(early_stopping_rounds)) + training_hooks.append(TensorForestLossHook(early_stopping_rounds, + loss_op=training_loss)) if report_feature_importances: training_hooks.append(TensorForestRunOpAtEndHook( @@ -328,9 +323,153 @@ class TensorForestEstimator(estimator.Estimator): num_trainers=num_trainers, trainer_id=trainer_id, report_feature_importances=report_feature_importances, - model_dir=model_dir, local_eval=local_eval), model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + +def get_combined_model_fn(model_fns): + """Get a combined model function given a list of other model fns. + + The model function returned will call the individual model functions and + combine them appropriately. For: + + training ops: tf.group them. + loss: average them. 
+ predictions: concat probabilities such that predictions[*][0-C1] are the + probabilities for output 1 (where C1 is the number of classes in output 1), + predictions[*][C1-(C1+C2)] are the probabilities for output 2 (where C2 + is the number of classes in output 2), etc. Also stack predictions such + that predictions[i][j] is the class prediction for example i and output j. + + This assumes that labels are 2-dimensional, with labels[i][j] being the + label for example i and output j, where forest j is trained using only + output j. + + Args: + model_fns: A list of model functions obtained from get_model_fn. + + Returns: + A model function that returns a ModelFnOps instance. + """ + def _model_fn(features, labels, mode): + """Function that returns predictions, training loss, and training op.""" + model_fn_ops = [] + for i in range(len(model_fns)): + with variable_scope.variable_scope('label_{0}'.format(i)): + sliced_labels = array_ops.slice(labels, [0, i], [-1, 1]) + model_fn_ops.append( + model_fns[i](features, sliced_labels, mode)) + training_hooks = [] + for mops in model_fn_ops: + training_hooks += mops.training_hooks + predictions = {} + if (mode == model_fn_lib.ModeKeys.EVAL or + mode == model_fn_lib.ModeKeys.INFER): + # Flatten the probabilities into one dimension. 
+ predictions[eval_metrics.INFERENCE_PROB_NAME] = array_ops.concat( + [mops.predictions[eval_metrics.INFERENCE_PROB_NAME] + for mops in model_fn_ops], axis=1) + predictions[eval_metrics.INFERENCE_PRED_NAME] = array_ops.stack( + [mops.predictions[eval_metrics.INFERENCE_PRED_NAME] + for mops in model_fn_ops], axis=1) + loss = None + if (mode == model_fn_lib.ModeKeys.EVAL or + mode == model_fn_lib.ModeKeys.TRAIN): + loss = math_ops.reduce_sum( + array_ops.stack( + [mops.loss for mops in model_fn_ops])) / len(model_fn_ops) + + train_op = None + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = control_flow_ops.group( + *[mops.train_op for mops in model_fn_ops]) + return model_fn_lib.ModelFnOps( + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + training_hooks=training_hooks, + scaffold=None, + output_alternatives=None) + + return _model_fn + + +class MultiForestMultiHeadEstimator(estimator.Estimator): + """An estimator that can train forests for multi-headed problems. + + This class essentially trains separate forests (each with their own + ForestHParams) for each output. + + For multi-headed regression, a single-headed TensorForestEstimator can + be used to train a single model that predicts all outputs. This class can + be used to train separate forests for each output. + """ + + def __init__(self, params_list, device_assigner=None, model_dir=None, + graph_builder_class=tensor_forest.RandomForestGraphs, + config=None, weights_name=None, keys_name=None, + feature_engineering_fn=None, + early_stopping_rounds=100, + num_trainers=1, trainer_id=0, + report_feature_importances=False, + local_eval=False): + """Initializes a MultiForestMultiHeadEstimator instance. + + Args: + params_list: A list of ForestHParams objects for each head, given in order + of outputs in the label tensor to be trained on. + device_assigner: An `object` instance that controls how trees get + assigned to devices. 
If `None`, will use + `tensor_forest.RandomForestDeviceAssigner`. + model_dir: Directory to save model parameters, graph, etc. To continue + training a previously saved model, load checkpoints saved to this + directory into an estimator. + graph_builder_class: An `object` instance that defines how TF graphs for + random forest training and inference are built. By default will use + `tensor_forest.RandomForestGraphs`. + config: `RunConfig` object to configure the runtime settings. + weights_name: A string defining feature column name representing + weights. Will be multiplied by the loss of the example. Used to + downweight or boost examples during training. + keys_name: A string naming one of the features to strip out and + pass through into the inference/eval results dict. Useful for + associating specific examples with their prediction. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + early_stopping_rounds: Allows training to terminate early if the forest is + no longer growing. 100 by default. Set to a Falsy value to disable + the default training hook. + num_trainers: Number of training jobs, which will partition trees + among them. + trainer_id: Which trainer this instance is. + report_feature_importances: If True, print out feature importances + during evaluation. + local_eval: If True, don't use a device assigner for eval. This is to + support some common setups where eval is done on a single machine, even + though training might be distributed. + + Returns: + A `MultiForestMultiHeadEstimator` instance. 
+ """ + model_fns = [get_model_fn( + params.fill(), + graph_builder_class, + device_assigner, + weights_name=weights_name, + keys_name=keys_name, + early_stopping_rounds=early_stopping_rounds, + num_trainers=num_trainers, + trainer_id=trainer_id, + report_feature_importances=report_feature_importances, + local_eval=local_eval) for params in params_list] + + super(MultiForestMultiHeadEstimator, self).__init__( + model_fn=get_combined_model_fn(model_fns), + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + |