author A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-22 06:45:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-22 06:49:09 -0700
commit 77f27c95fb565a9feafe8211d3de3b40e86c4402 (patch)
tree 6cf0ef8bd2bc1105794311ad90a685b053cddbf0
parent d19f8b1e35409ef89463c8b84bef97cf53f43859 (diff)
Add a multi-head TensorForest estimator.
PiperOrigin-RevId: 159820487
-rw-r--r-- tensorflow/contrib/tensor_forest/client/random_forest.py | 171
1 file changed, 155 insertions(+), 16 deletions(-)
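For context, a minimal usage sketch of the estimator this commit adds. This is not part of the commit: the two-head setup, feature count, and hyperparameter values are invented for illustration, and it assumes the TF 1.x contrib API.

    # Sketch only: two classification heads (3 and 2 classes); all
    # hyperparameter values below are illustrative assumptions.
    from tensorflow.contrib.tensor_forest.client import random_forest
    from tensorflow.contrib.tensor_forest.python import tensor_forest

    # One ForestHParams per output, in the same column order as the labels.
    params_list = [
        tensor_forest.ForestHParams(
            num_classes=3, num_features=10, num_trees=10, max_nodes=1000),
        tensor_forest.ForestHParams(
            num_classes=2, num_features=10, num_trees=10, max_nodes=1000),
    ]

    estimator = random_forest.MultiForestMultiHeadEstimator(
        params_list, model_dir='/tmp/multi_forest')
    # Labels must be 2-D: labels[i][j] is the label for example i, output j.
    # estimator.fit(input_fn=train_input_fn, steps=100)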
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index ef2f0337ac..af9f56ab24 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -29,12 +29,13 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
@@ -74,8 +75,9 @@ class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
class TensorForestLossHook(session_run_hook.SessionRunHook):
"""Monitor to request stop when loss stops decreasing."""
- def __init__(self, early_stopping_rounds):
+ def __init__(self, early_stopping_rounds, loss_op=None):
self.early_stopping_rounds = early_stopping_rounds
+ self.loss_op = loss_op
self.min_loss = None
self.last_step = -1
# self.steps records the number of steps for which the loss has been
@@ -83,10 +85,12 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
self.steps = 0
def before_run(self, run_context):
+ loss = (self.loss_op if self.loss_op is not None else
+ run_context.session.graph.get_operation_by_name(
+ LOSS_NAME).outputs[0])
return session_run_hook.SessionRunArgs(
{'global_step': contrib_framework.get_global_step(),
- 'current_loss': run_context.session.graph.get_operation_by_name(
- LOSS_NAME).outputs[0]})
+ 'current_loss': loss})
def after_run(self, run_context, run_values):
current_loss = run_values.results['current_loss']
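A short sketch of the hook's two construction paths after this change; it is not part of the diff, and training_loss here is a stand-in for whatever loss tensor the model function builds (assumes the TF 1.x contrib API).

    import tensorflow as tf
    from tensorflow.contrib.tensor_forest.client.random_forest import (
        TensorForestLossHook)

    training_loss = tf.constant(0.5)  # stand-in for the model fn's loss

    # New path: hand the hook a loss tensor directly.
    hook = TensorForestLossHook(early_stopping_rounds=100,
                                loss_op=training_loss)
    # Old path, still supported: omit loss_op; before_run() then looks up
    # LOSS_NAME in the session's graph, as before.
    legacy_hook = TensorForestLossHook(early_stopping_rounds=100)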
@@ -135,7 +139,6 @@ def get_model_fn(params,
num_trainers=1,
trainer_id=0,
report_feature_importances=False,
- model_dir=None,
local_eval=False):
"""Return a model function given a way to construct a graph builder."""
def _model_fn(features, labels, mode):
@@ -198,15 +201,6 @@ def get_model_fn(params,
trainer_id=trainer_id),
state_ops.assign_add(contrib_framework.get_global_step(), 1))
loss_deps.append(training_graph)
- if hasattr(graph_builder, 'finalize_training'):
- finalize_listener = EveryCheckpointPreSaveListener(
- graph_builder.finalize_training())
- scaffold = monitored_session.Scaffold()
- training_hooks.append(
- basic_session_run_hooks.CheckpointSaverHook(
- model_dir, save_secs=600, save_steps=None,
- scaffold=scaffold,
- listeners=[finalize_listener]))
training_loss = None
if (mode == model_fn_lib.ModeKeys.EVAL or
@@ -220,7 +214,8 @@ def get_model_fn(params,
features[weights_name] = weights
if early_stopping_rounds:
- training_hooks.append(TensorForestLossHook(early_stopping_rounds))
+ training_hooks.append(TensorForestLossHook(early_stopping_rounds,
+ loss_op=training_loss))
if report_feature_importances:
training_hooks.append(TensorForestRunOpAtEndHook(
@@ -328,9 +323,153 @@ class TensorForestEstimator(estimator.Estimator):
num_trainers=num_trainers,
trainer_id=trainer_id,
report_feature_importances=report_feature_importances,
- model_dir=model_dir,
local_eval=local_eval),
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+def get_combined_model_fn(model_fns):
+ """Get a combined model function given a list of other model fns.
+
+ The returned model function calls each individual model function and
+ combines their outputs as follows:
+
+ training ops: tf.group them.
+ loss: average them.
+ predictions: concat probabilities such that predictions[*][0-C1] are the
+ probabilities for output 1 (where C1 is the number of classes in output 1),
+ predictions[*][C1-(C1+C2)] are the probabilities for output 2 (where C2
+ is the number of classes in output 2), etc. Also stack predictions such
+ that predictions[i][j] is the class prediction for example i and output j.
+
+ This assumes that labels are 2-dimensional, with labels[i][j] being the
+ label for example i and output j, where forest j is trained using only
+ output j.
+
+ Args:
+ model_fns: A list of model functions obtained from get_model_fn.
+
+ Returns:
+ A model function that builds the combined ModelFnOps.
+ """
+ def _model_fn(features, labels, mode):
+ """Function that returns predictions, training loss, and training op."""
+ model_fn_ops = []
+ for i in range(len(model_fns)):
+ with variable_scope.variable_scope('label_{0}'.format(i)):
+ sliced_labels = array_ops.slice(labels, [0, i], [-1, 1])
+ model_fn_ops.append(
+ model_fns[i](features, sliced_labels, mode))
+ training_hooks = []
+ for mops in model_fn_ops:
+ training_hooks += mops.training_hooks
+ predictions = {}
+ if (mode == model_fn_lib.ModeKeys.EVAL or
+ mode == model_fn_lib.ModeKeys.INFER):
+ # Flatten the probabilities into one dimension.
+ predictions[eval_metrics.INFERENCE_PROB_NAME] = array_ops.concat(
+ [mops.predictions[eval_metrics.INFERENCE_PROB_NAME]
+ for mops in model_fn_ops], axis=1)
+ predictions[eval_metrics.INFERENCE_PRED_NAME] = array_ops.stack(
+ [mops.predictions[eval_metrics.INFERENCE_PRED_NAME]
+ for mops in model_fn_ops], axis=1)
+ loss = None
+ if (mode == model_fn_lib.ModeKeys.EVAL or
+ mode == model_fn_lib.ModeKeys.TRAIN):
+ loss = math_ops.reduce_sum(
+ array_ops.stack(
+ [mops.loss for mops in model_fn_ops])) / len(model_fn_ops)
+
+ train_op = None
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_op = control_flow_ops.group(
+ *[mops.train_op for mops in model_fn_ops])
+ return model_fn_lib.ModelFnOps(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ training_hooks=training_hooks,
+ scaffold=None,
+ output_alternatives=None)
+
+ return _model_fn
+
+
+class MultiForestMultiHeadEstimator(estimator.Estimator):
+ """An estimator that can train a forest for a multi-headed problems.
+
+ This class essentially trains separate forests (each with its own
+ ForestHParams) for each output.
+
+ For multi-headed regression, a single-headed TensorForestEstimator can
+ be used to train a single model that predicts all outputs. This class can
+ be used to train separate forests for each output.
+ """
+
+ def __init__(self, params_list, device_assigner=None, model_dir=None,
+ graph_builder_class=tensor_forest.RandomForestGraphs,
+ config=None, weights_name=None, keys_name=None,
+ feature_engineering_fn=None,
+ early_stopping_rounds=100,
+ num_trainers=1, trainer_id=0,
+ report_feature_importances=False,
+ local_eval=False):
+ """Initializes a TensorForestEstimator instance.
+
+ Args:
+ params_list: A list of ForestHParams objects, one per head, given in the
+ order of the corresponding outputs in the label tensor.
+ device_assigner: An `object` instance that controls how trees get
+ assigned to devices. If `None`, will use
+ `tensor_forest.RandomForestDeviceAssigner`.
+ model_dir: Directory to save model parameters, graph, etc. To continue
+ training a previously saved model, load checkpoints saved to this
+ directory into an estimator.
+ graph_builder_class: An `object` instance that defines how TF graphs for
+ random forest training and inference are built. By default will use
+ `tensor_forest.RandomForestGraphs`.
+ config: `RunConfig` object to configure the runtime settings.
+ weights_name: A string naming the feature column that holds per-example
+ weights. Each weight is multiplied by the example's loss; used to
+ downweight or boost examples during training.
+ keys_name: A string naming one of the features to strip out and
+ pass through into the inference/eval results dict. Useful for
+ associating specific examples with their prediction.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ early_stopping_rounds: Allows training to terminate early if the forest is
+ no longer growing. 100 by default. Set to a Falsy value to disable
+ the default training hook.
+ num_trainers: Number of training jobs, which will partition trees
+ among them.
+ trainer_id: Which trainer this instance is.
+ report_feature_importances: If True, print out feature importances
+ during evaluation.
+ local_eval: If True, don't use a device assigner for eval. This is to
+ support some common setups where eval is done on a single machine, even
+ though training might be distributed.
+
+ Returns:
+ A `MultiForestMultiHeadEstimator` instance.
+ """
+ model_fns = [get_model_fn(
+ params.fill(),
+ graph_builder_class,
+ device_assigner,
+ weights_name=weights_name,
+ keys_name=keys_name,
+ early_stopping_rounds=early_stopping_rounds,
+ num_trainers=num_trainers,
+ trainer_id=trainer_id,
+ report_feature_importances=report_feature_importances,
+ local_eval=local_eval) for params in params_list]
+
+ super(MultiForestMultiHeadEstimator, self).__init__(
+ model_fn=get_combined_model_fn(model_fns),
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
+
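To make the combined prediction layout in get_combined_model_fn concrete, a small sketch under assumed shapes: two heads with C1=3 and C2=2 classes and a batch of two examples (the array values are invented for illustration).

    import numpy as np

    # Assumed per-head outputs for a batch of 2 examples.
    probs_head1 = np.array([[0.7, 0.2, 0.1],   # C1 = 3 classes
                            [0.1, 0.8, 0.1]])
    probs_head2 = np.array([[0.4, 0.6],        # C2 = 2 classes
                            [0.9, 0.1]])
    preds_head1 = np.array([0, 1])             # per-example class choices
    preds_head2 = np.array([1, 0])

    # INFERENCE_PROB_NAME: concat on axis 1 -> shape [2, C1 + C2] = [2, 5];
    # columns 0..2 belong to output 1, columns 3..4 to output 2.
    probabilities = np.concatenate([probs_head1, probs_head2], axis=1)

    # INFERENCE_PRED_NAME: stack on axis 1 -> shape [2, 2];
    # predictions[i][j] is the class prediction for example i, output j.
    predictions = np.stack([preds_head1, preds_head2], axis=1)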