aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client/random_forest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client/random_forest.py')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py328
1 files changed, 271 insertions, 57 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 35e8c92aba..8fa0b3ada9 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
-
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
-
+from tensorflow.python.estimator import estimator as core_estimator
+from tensorflow.python.estimator.canned import head as core_head_lib
+from tensorflow.python.estimator.export.export_output import PredictOutput
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
@@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001
+class ModelBuilderOutputType(object):
+ MODEL_FN_OPS = 0
+ ESTIMATOR_SPEC = 1
+
+
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
def __init__(self, op_dict):
@@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
run_context.request_stop()
-def get_default_head(params, weights_name, name=None):
- if params.regression:
- return head_lib.regression_head(
- weight_column_name=weights_name,
- label_dimension=params.num_outputs,
- enable_centered_bias=False,
- head_name=name)
+def _get_default_head(params, weights_name, output_type, name=None):
+ """Creates a default head based on a type of a problem."""
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if params.regression:
+ return head_lib.regression_head(
+ weight_column_name=weights_name,
+ label_dimension=params.num_outputs,
+ enable_centered_bias=False,
+ head_name=name)
+ else:
+ return head_lib.multi_class_head(
+ params.num_classes,
+ weight_column_name=weights_name,
+ enable_centered_bias=False,
+ head_name=name)
else:
- return head_lib.multi_class_head(
- params.num_classes,
- weight_column_name=weights_name,
- enable_centered_bias=False,
- head_name=name)
-
+ if params.regression:
+ return core_head_lib._regression_head( # pylint:disable=protected-access
+ weight_column=weights_name,
+ label_dimension=params.num_outputs,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ else:
+ return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,
@@ -135,19 +156,27 @@ def get_model_fn(params,
report_feature_importances=False,
local_eval=False,
head_scope=None,
- include_all_in_serving=False):
+ include_all_in_serving=False,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Return a model function given a way to construct a graph builder."""
if model_head is None:
- model_head = get_default_head(params, weights_name)
+ model_head = _get_default_head(params, weights_name, output_type)
def _model_fn(features, labels, mode):
"""Function that returns predictions, training loss, and training op."""
+
if (isinstance(features, ops.Tensor) or
isinstance(features, sparse_tensor.SparseTensor)):
features = {'features': features}
if feature_columns:
features = features.copy()
- features.update(layers.transform_features(features, feature_columns))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ features.update(layers.transform_features(features, feature_columns))
+ else:
+ for fc in feature_columns:
+ tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
+ features[fc.name] = tensor
weights = None
if weights_name and weights_name in features:
@@ -201,52 +230,95 @@ def get_model_fn(params,
def _train_fn(unused_loss):
return training_graph
- model_ops = model_head.create_model_fn_ops(
- features=features,
- labels=labels,
- mode=mode,
- train_op_fn=_train_fn,
- logits=logits,
- scope=head_scope)
# Ops are run in lexigraphical order of their keys. Run the resource
# clean-up op last.
all_handles = graph_builder.get_all_resource_handles()
ops_at_end = {
- '9: clean up resources': control_flow_ops.group(
- *[resource_variable_ops.destroy_resource_op(handle)
- for handle in all_handles])}
+ '9: clean up resources':
+ control_flow_ops.group(*[
+ resource_variable_ops.destroy_resource_op(handle)
+ for handle in all_handles
+ ])
+ }
if report_feature_importances:
ops_at_end['1: feature_importances'] = (
graph_builder.feature_importances())
- training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
-
- if early_stopping_rounds:
- training_hooks.append(
- TensorForestLossHook(
- early_stopping_rounds,
- early_stopping_loss_threshold=early_stopping_loss_threshold,
- loss_op=model_ops.loss))
-
- model_ops.training_hooks.extend(training_hooks)
-
- if keys is not None:
- model_ops.predictions[keys_name] = keys
-
- if params.inference_tree_paths:
- model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
-
- model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
- if include_all_in_serving:
- # In order to serve the variance we need to add the prediction dict
- # to output_alternatives dict.
- if not model_ops.output_alternatives:
- model_ops.output_alternatives = {}
- model_ops.output_alternatives[ALL_SERVING_KEY] = (
- constants.ProblemType.UNSPECIFIED, model_ops.predictions)
- return model_ops
+ training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ model_ops = model_head.create_model_fn_ops(
+ features=features,
+ labels=labels,
+ mode=mode,
+ train_op_fn=_train_fn,
+ logits=logits,
+ scope=head_scope)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=model_ops.loss))
+
+ model_ops.training_hooks.extend(training_hooks)
+
+ if keys is not None:
+ model_ops.predictions[keys_name] = keys
+
+ if params.inference_tree_paths:
+ model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+
+ model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+ if include_all_in_serving:
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ if not model_ops.output_alternatives:
+ model_ops.output_alternatives = {}
+ model_ops.output_alternatives[ALL_SERVING_KEY] = (
+ constants.ProblemType.UNSPECIFIED, model_ops.predictions)
+
+ return model_ops
+
+ else:
+ # Estimator spec
+ estimator_spec = model_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_fn,
+ logits=logits)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=estimator_spec.loss))
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ if keys is not None:
+ estimator_spec.predictions[keys_name] = keys
+ if params.inference_tree_paths:
+ estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+ estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+ if include_all_in_serving:
+ outputs = estimator_spec.export_outputs
+ if not outputs:
+ outputs = {}
+ outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)}
+ print(estimator_spec.export_outputs)
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ estimator_spec = estimator_spec._replace(export_outputs=outputs)
+
+ return estimator_spec
return _model_fn
@@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
params,
graph_builder_class,
device_assigner,
- model_head=get_default_head(
- params, weight_column, name='head{0}'.format(i)),
+ model_head=_get_default_head(
+ params,
+ weight_column,
+ name='head{0}'.format(i),
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS),
weights_name=weight_column,
keys_name=keys_column,
early_stopping_rounds=early_stopping_rounds,
@@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreTensorForestEstimator(core_estimator.Estimator):
+ """A CORE estimator that can train and evaluate a random forest.
+
+ Example:
+
+ ```python
+ params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
+
+ # Estimator using the default graph builder.
+ estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
+
+ # Or estimator using TrainingLossForest as the graph builder.
+ estimator = CoreTensorForestEstimator(
+ params, graph_builder_class=tensor_forest.TrainingLossForest,
+ model_dir=model_dir)
+
+ # Input builders
+ def input_fn_train: # returns x, y
+ ...
+ def input_fn_eval: # returns x, y
+ ...
+ estimator.train(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+
+ # Predict returns an iterable of dicts.
+ results = list(estimator.predict(x=x))
+ prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
+ prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
+ ```
+ """
+
+ def __init__(self,
+ params,
+ device_assigner=None,
+ model_dir=None,
+ feature_columns=None,
+ graph_builder_class=tensor_forest.RandomForestGraphs,
+ config=None,
+ weight_column=None,
+ keys_column=None,
+ feature_engineering_fn=None,
+ early_stopping_rounds=100,
+ early_stopping_loss_threshold=0.001,
+ num_trainers=1,
+ trainer_id=0,
+ report_feature_importances=False,
+ local_eval=False,
+ version=None,
+ head=None,
+ include_all_in_serving=False):
+ """Initializes a TensorForestEstimator instance.
+
+ Args:
+ params: ForestHParams object that holds random forest hyperparameters.
+ These parameters will be passed into `model_fn`.
+ device_assigner: An `object` instance that controls how trees get
+ assigned to devices. If `None`, will use
+ `tensor_forest.RandomForestDeviceAssigner`.
+ model_dir: Directory to save model parameters, graph, etc. To continue
+ training a previously saved model, load checkpoints saved to this
+ directory into an estimator.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `_FeatureColumn`.
+ graph_builder_class: An `object` instance that defines how TF graphs for
+ random forest training and inference are built. By default will use
+ `tensor_forest.RandomForestGraphs`. Can be overridden by version
+ kwarg.
+ config: `RunConfig` object to configure the runtime settings.
+ weight_column: A string defining feature column name representing
+ weights. Will be multiplied by the loss of the example. Used to
+ downweight or boost examples during training.
+ keys_column: A string naming one of the features to strip out and
+ pass through into the inference/eval results dict. Useful for
+ associating specific examples with their prediction.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ early_stopping_rounds: Allows training to terminate early if the forest is
+ no longer growing. 100 by default. Set to a Falsy value to disable
+ the default training hook.
+ early_stopping_loss_threshold: Percentage (as fraction) that loss must
+ improve by within early_stopping_rounds steps, otherwise training will
+ terminate.
+ num_trainers: Number of training jobs, which will partition trees
+ among them.
+ trainer_id: Which trainer this instance is.
+ report_feature_importances: If True, print out feature importances
+ during evaluation.
+ local_eval: If True, don't use a device assigner for eval. This is to
+ support some common setups where eval is done on a single machine, even
+ though training might be distributed.
+ version: Unused.
+ head: A heads_lib.Head object that calculates losses and such. If None,
+ one will be automatically created based on params.
+ include_all_in_serving: if True, allow preparation of the complete
+ prediction dict including the variance to be exported for serving with
+ the Servo lib; and it also requires calling export_savedmodel with
+ default_output_alternative_key=ALL_SERVING_KEY, i.e.
+ estimator.export_savedmodel(export_dir_base=your_export_dir,
+ serving_input_fn=your_export_input_fn,
+ default_output_alternative_key=ALL_SERVING_KEY)
+ if False, resort to default behavior, i.e. export scores and
+ probabilities but no variances. In this case
+ default_output_alternative_key should be None while calling
+ export_savedmodel().
+ Note, that due to backward compatibility we cannot always set
+ include_all_in_serving to True because in this case calling
+ export_saved_model() without
+ default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
+ saved_model_export_utils.get_output_alternatives() would raise
+ ValueError.
+
+ Returns:
+ A `TensorForestEstimator` instance.
+ """
+
+ super(CoreTensorForestEstimator, self).__init__(
+ model_fn=get_model_fn(
+ params.fill(),
+ graph_builder_class,
+ device_assigner,
+ feature_columns=feature_columns,
+ model_head=head,
+ weights_name=weight_column,
+ keys_name=keys_column,
+ early_stopping_rounds=early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ num_trainers=num_trainers,
+ trainer_id=trainer_id,
+ report_feature_importances=report_feature_importances,
+ local_eval=local_eval,
+ include_all_in_serving=include_all_in_serving,
+ output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
+ model_dir=model_dir,
+ config=config)