aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-14 09:21:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 09:25:27 -0700
commitb03f732b3f8bd5a038d378c7f2044e8e61d5fdb0 (patch)
treeaf6f06b32f0632bc83393f44fce5d4accdf7ba85 /tensorflow/contrib/tensor_forest
parentd7f93284c826451dfd0c28108674a453cb629c09 (diff)
Providing a core estimator interface over a contrib tensorforest.
PiperOrigin-RevId: 208658097
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py328
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest_test.py303
2 files changed, 565 insertions, 66 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 35e8c92aba..8fa0b3ada9 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
-
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
-
+from tensorflow.python.estimator import estimator as core_estimator
+from tensorflow.python.estimator.canned import head as core_head_lib
+from tensorflow.python.estimator.export.export_output import PredictOutput
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-
KEYS_NAME = 'keys'
LOSS_NAME = 'rf_training_loss'
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
@@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all'
EPSILON = 0.000001
+class ModelBuilderOutputType(object):
+ MODEL_FN_OPS = 0
+ ESTIMATOR_SPEC = 1
+
+
class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook):
def __init__(self, op_dict):
@@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
run_context.request_stop()
-def get_default_head(params, weights_name, name=None):
- if params.regression:
- return head_lib.regression_head(
- weight_column_name=weights_name,
- label_dimension=params.num_outputs,
- enable_centered_bias=False,
- head_name=name)
+def _get_default_head(params, weights_name, output_type, name=None):
+ """Creates a default head based on a type of a problem."""
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if params.regression:
+ return head_lib.regression_head(
+ weight_column_name=weights_name,
+ label_dimension=params.num_outputs,
+ enable_centered_bias=False,
+ head_name=name)
+ else:
+ return head_lib.multi_class_head(
+ params.num_classes,
+ weight_column_name=weights_name,
+ enable_centered_bias=False,
+ head_name=name)
else:
- return head_lib.multi_class_head(
- params.num_classes,
- weight_column_name=weights_name,
- enable_centered_bias=False,
- head_name=name)
-
+ if params.regression:
+ return core_head_lib._regression_head( # pylint:disable=protected-access
+ weight_column=weights_name,
+ label_dimension=params.num_outputs,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ else:
+ return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,
@@ -135,19 +156,27 @@ def get_model_fn(params,
report_feature_importances=False,
local_eval=False,
head_scope=None,
- include_all_in_serving=False):
+ include_all_in_serving=False,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Return a model function given a way to construct a graph builder."""
if model_head is None:
- model_head = get_default_head(params, weights_name)
+ model_head = _get_default_head(params, weights_name, output_type)
def _model_fn(features, labels, mode):
"""Function that returns predictions, training loss, and training op."""
+
if (isinstance(features, ops.Tensor) or
isinstance(features, sparse_tensor.SparseTensor)):
features = {'features': features}
if feature_columns:
features = features.copy()
- features.update(layers.transform_features(features, feature_columns))
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ features.update(layers.transform_features(features, feature_columns))
+ else:
+ for fc in feature_columns:
+ tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access
+ features[fc.name] = tensor
weights = None
if weights_name and weights_name in features:
@@ -201,52 +230,95 @@ def get_model_fn(params,
def _train_fn(unused_loss):
return training_graph
- model_ops = model_head.create_model_fn_ops(
- features=features,
- labels=labels,
- mode=mode,
- train_op_fn=_train_fn,
- logits=logits,
- scope=head_scope)
# Ops are run in lexigraphical order of their keys. Run the resource
# clean-up op last.
all_handles = graph_builder.get_all_resource_handles()
ops_at_end = {
- '9: clean up resources': control_flow_ops.group(
- *[resource_variable_ops.destroy_resource_op(handle)
- for handle in all_handles])}
+ '9: clean up resources':
+ control_flow_ops.group(*[
+ resource_variable_ops.destroy_resource_op(handle)
+ for handle in all_handles
+ ])
+ }
if report_feature_importances:
ops_at_end['1: feature_importances'] = (
graph_builder.feature_importances())
- training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end))
-
- if early_stopping_rounds:
- training_hooks.append(
- TensorForestLossHook(
- early_stopping_rounds,
- early_stopping_loss_threshold=early_stopping_loss_threshold,
- loss_op=model_ops.loss))
-
- model_ops.training_hooks.extend(training_hooks)
-
- if keys is not None:
- model_ops.predictions[keys_name] = keys
-
- if params.inference_tree_paths:
- model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
-
- model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
- if include_all_in_serving:
- # In order to serve the variance we need to add the prediction dict
- # to output_alternatives dict.
- if not model_ops.output_alternatives:
- model_ops.output_alternatives = {}
- model_ops.output_alternatives[ALL_SERVING_KEY] = (
- constants.ProblemType.UNSPECIFIED, model_ops.predictions)
- return model_ops
+ training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)]
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ model_ops = model_head.create_model_fn_ops(
+ features=features,
+ labels=labels,
+ mode=mode,
+ train_op_fn=_train_fn,
+ logits=logits,
+ scope=head_scope)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=model_ops.loss))
+
+ model_ops.training_hooks.extend(training_hooks)
+
+ if keys is not None:
+ model_ops.predictions[keys_name] = keys
+
+ if params.inference_tree_paths:
+ model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+
+ model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+ if include_all_in_serving:
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ if not model_ops.output_alternatives:
+ model_ops.output_alternatives = {}
+ model_ops.output_alternatives[ALL_SERVING_KEY] = (
+ constants.ProblemType.UNSPECIFIED, model_ops.predictions)
+
+ return model_ops
+
+ else:
+ # Estimator spec
+ estimator_spec = model_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_fn,
+ logits=logits)
+
+ if early_stopping_rounds:
+ training_hooks.append(
+ TensorForestLossHook(
+ early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ loss_op=estimator_spec.loss))
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ if keys is not None:
+ estimator_spec.predictions[keys_name] = keys
+ if params.inference_tree_paths:
+ estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
+ estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+
+ if include_all_in_serving:
+ outputs = estimator_spec.export_outputs
+ if not outputs:
+ outputs = {}
+ outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)}
+ print(estimator_spec.export_outputs)
+ # In order to serve the variance we need to add the prediction dict
+ # to output_alternatives dict.
+ estimator_spec = estimator_spec._replace(export_outputs=outputs)
+
+ return estimator_spec
return _model_fn
@@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
params,
graph_builder_class,
device_assigner,
- model_head=get_default_head(
- params, weight_column, name='head{0}'.format(i)),
+ model_head=_get_default_head(
+ params,
+ weight_column,
+ name='head{0}'.format(i),
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS),
weights_name=weight_column,
keys_name=keys_column,
early_stopping_rounds=early_stopping_rounds,
@@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreTensorForestEstimator(core_estimator.Estimator):
+ """A CORE estimator that can train and evaluate a random forest.
+
+ Example:
+
+ ```python
+ params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
+
+ # Estimator using the default graph builder.
+ estimator = CoreTensorForestEstimator(params, model_dir=model_dir)
+
+ # Or estimator using TrainingLossForest as the graph builder.
+ estimator = CoreTensorForestEstimator(
+ params, graph_builder_class=tensor_forest.TrainingLossForest,
+ model_dir=model_dir)
+
+ # Input builders
+ def input_fn_train: # returns x, y
+ ...
+ def input_fn_eval: # returns x, y
+ ...
+ estimator.train(input_fn=input_fn_train)
+ estimator.evaluate(input_fn=input_fn_eval)
+
+ # Predict returns an iterable of dicts.
+ results = list(estimator.predict(x=x))
+ prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
+ prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
+ ```
+ """
+
+ def __init__(self,
+ params,
+ device_assigner=None,
+ model_dir=None,
+ feature_columns=None,
+ graph_builder_class=tensor_forest.RandomForestGraphs,
+ config=None,
+ weight_column=None,
+ keys_column=None,
+ feature_engineering_fn=None,
+ early_stopping_rounds=100,
+ early_stopping_loss_threshold=0.001,
+ num_trainers=1,
+ trainer_id=0,
+ report_feature_importances=False,
+ local_eval=False,
+ version=None,
+ head=None,
+ include_all_in_serving=False):
+ """Initializes a TensorForestEstimator instance.
+
+ Args:
+ params: ForestHParams object that holds random forest hyperparameters.
+ These parameters will be passed into `model_fn`.
+ device_assigner: An `object` instance that controls how trees get
+ assigned to devices. If `None`, will use
+ `tensor_forest.RandomForestDeviceAssigner`.
+ model_dir: Directory to save model parameters, graph, etc. To continue
+ training a previously saved model, load checkpoints saved to this
+ directory into an estimator.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `_FeatureColumn`.
+ graph_builder_class: An `object` instance that defines how TF graphs for
+ random forest training and inference are built. By default will use
+ `tensor_forest.RandomForestGraphs`. Can be overridden by version
+ kwarg.
+ config: `RunConfig` object to configure the runtime settings.
+ weight_column: A string defining feature column name representing
+ weights. Will be multiplied by the loss of the example. Used to
+ downweight or boost examples during training.
+ keys_column: A string naming one of the features to strip out and
+ pass through into the inference/eval results dict. Useful for
+ associating specific examples with their prediction.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ early_stopping_rounds: Allows training to terminate early if the forest is
+ no longer growing. 100 by default. Set to a Falsy value to disable
+ the default training hook.
+ early_stopping_loss_threshold: Percentage (as fraction) that loss must
+ improve by within early_stopping_rounds steps, otherwise training will
+ terminate.
+ num_trainers: Number of training jobs, which will partition trees
+ among them.
+ trainer_id: Which trainer this instance is.
+ report_feature_importances: If True, print out feature importances
+ during evaluation.
+ local_eval: If True, don't use a device assigner for eval. This is to
+ support some common setups where eval is done on a single machine, even
+ though training might be distributed.
+ version: Unused.
+ head: A heads_lib.Head object that calculates losses and such. If None,
+ one will be automatically created based on params.
+ include_all_in_serving: if True, allow preparation of the complete
+ prediction dict including the variance to be exported for serving with
+ the Servo lib; and it also requires calling export_savedmodel with
+ default_output_alternative_key=ALL_SERVING_KEY, i.e.
+ estimator.export_savedmodel(export_dir_base=your_export_dir,
+ serving_input_fn=your_export_input_fn,
+ default_output_alternative_key=ALL_SERVING_KEY)
+ if False, resort to default behavior, i.e. export scores and
+ probabilities but no variances. In this case
+ default_output_alternative_key should be None while calling
+ export_savedmodel().
+ Note, that due to backward compatibility we cannot always set
+ include_all_in_serving to True because in this case calling
+ export_saved_model() without
+ default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
+ saved_model_export_utils.get_output_alternatives() would raise
+ ValueError.
+
+ Returns:
+ A `TensorForestEstimator` instance.
+ """
+
+ super(CoreTensorForestEstimator, self).__init__(
+ model_fn=get_model_fn(
+ params.fill(),
+ graph_builder_class,
+ device_assigner,
+ feature_columns=feature_columns,
+ model_head=head,
+ weights_name=weight_column,
+ keys_name=keys_column,
+ early_stopping_rounds=early_stopping_rounds,
+ early_stopping_loss_threshold=early_stopping_loss_threshold,
+ num_trainers=num_trainers,
+ trainer_id=trainer_id,
+ report_feature_importances=report_feature_importances,
+ local_eval=local_eval,
+ include_all_in_serving=include_all_in_serving,
+ output_type=ModelBuilderOutputType.ESTIMATOR_SPEC),
+ model_dir=model_dir,
+ config=config)
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
index ac42364d25..e951592f85 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
@@ -23,7 +23,39 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_utils
+
+
+def _get_classification_input_fns():
+ iris = base.load_iris()
+ data = iris.data.astype(np.float32)
+ labels = iris.target.astype(np.int32)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
+ return train_input_fn, predict_input_fn
+
+
+def _get_regression_input_fns():
+ boston = base.load_boston()
+ data = boston.data.astype(np.float32)
+ labels = boston.target.astype(np.int32)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False)
+ return train_input_fn, predict_input_fn
class TensorForestTrainerTests(test.TestCase):
@@ -39,32 +71,285 @@ class TensorForestTrainerTests(test.TestCase):
inference_tree_paths=True)
classifier = random_forest.TensorForestEstimator(hparams.fill())
+ input_fn, predict_input_fn = _get_classification_input_fns()
+ classifier.fit(input_fn=input_fn, steps=100)
+ res = classifier.evaluate(input_fn=input_fn, steps=10)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ predictions = list(classifier.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0.576117, 0.211942, 0.211942]],
+ [pred['probabilities'] for pred in predictions])
+
+ def testRegression(self):
+ """Tests regression using matrix data as input."""
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
+ num_classes=1,
+ num_features=13,
+ regression=True,
+ split_after_samples=20)
+
+ regressor = random_forest.TensorForestEstimator(hparams.fill())
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.fit(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1)
+
+ def testAdditionalOutputs(self):
+ """Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=1,
+ max_nodes=100,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.TensorForestEstimator(
+ hparams.fill(), keys_column='keys', include_all_in_serving=True)
+
iris = base.load_iris()
data = iris.data.astype(np.float32)
labels = iris.target.astype(np.int32)
- classifier.fit(x=data, y=labels, steps=100, batch_size=50)
- classifier.evaluate(x=data, y=labels, steps=10)
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'x': data,
+ 'keys': np.arange(len(iris.data)).reshape(150, 1)
+ },
+ y=labels,
+ batch_size=10,
+ num_epochs=1,
+ shuffle=False)
- def testRegression(self):
+ classifier.fit(input_fn=input_fn, steps=100)
+ predictions = list(classifier.predict(input_fn=input_fn))
+ # Check that there is a key column, tree paths and var.
+ for pred in predictions:
+ self.assertTrue('keys' in pred)
+ self.assertTrue('tree_paths' in pred)
+ self.assertTrue('prediction_variance' in pred)
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
+ def testEarlyStopping(self):
"""Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=100,
+ max_nodes=10000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.TensorForestEstimator(
+ hparams.fill(),
+ # Set a crazy threshold - 30% loss change.
+ early_stopping_loss_threshold=0.3,
+ early_stopping_rounds=2)
+
+ input_fn, _ = _get_classification_input_fns()
+ classifier.fit(input_fn=input_fn, steps=100)
+
+ # We stopped early.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
+
+
+class CoreTensorForestTests(test.TestCase):
+
+ def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
hparams = tensor_forest.ForestHParams(
num_trees=3,
max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn)
+
+ input_fn, predict_input_fn = _get_classification_input_fns()
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0.576117, 0.211942, 0.211942]],
+ [pred['probabilities'] for pred in predictions])
+
+ def testRegression(self):
+ """Tests regression using matrix data as input."""
+ head_fn = head_lib._regression_head(
+ label_dimension=1,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
num_classes=1,
num_features=13,
regression=True,
split_after_samples=20)
- regressor = random_forest.TensorForestEstimator(hparams.fill())
+ regressor = random_forest.CoreTensorForestEstimator(
+ hparams.fill(), head=head_fn)
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.train(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[24.]], [pred['predictions'] for pred in predictions], atol=1)
+
+ def testWithFeatureColumns(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(
+ hparams.fill(),
+ head=head_fn,
+ feature_columns=[core_feature_column.numeric_column('x')])
+
+ iris = base.load_iris()
+ data = {'x': iris.data.astype(np.float32)}
+ labels = iris.target.astype(np.int32)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False)
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ def testAutofillsClassificationHead(self):
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+
+ est = random_forest.CoreTensorForestEstimator(hparams.fill())
+
+ input_fn, _ = _get_classification_input_fns()
+
+ est.train(input_fn=input_fn, steps=100)
+ res = est.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertEqual(1.0, res['accuracy'])
+ self.assertAllClose(0.55144483, res['loss'])
+
+ def testAutofillsRegressionHead(self):
+ hparams = tensor_forest.ForestHParams(
+ num_trees=5,
+ max_nodes=1000,
+ num_classes=1,
+ num_features=13,
+ regression=True,
+ split_after_samples=20)
+
+ regressor = random_forest.CoreTensorForestEstimator(hparams.fill())
+
+ input_fn, predict_input_fn = _get_regression_input_fns()
+
+ regressor.train(input_fn=input_fn, steps=100)
+ res = regressor.evaluate(input_fn=input_fn, steps=10)
+ self.assertGreaterEqual(0.1, res['loss'])
+
+ predictions = list(regressor.predict(input_fn=predict_input_fn))
+ self.assertAllClose(
+ [[24.]], [pred['predictions'] for pred in predictions], atol=1)
+
+ def testAdditionalOutputs(self):
+ """Tests multi-class classification using matrix data as input."""
+ hparams = tensor_forest.ForestHParams(
+ num_trees=1,
+ max_nodes=100,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
+ classifier = random_forest.CoreTensorForestEstimator(
+ hparams.fill(), keys_column='keys', include_all_in_serving=True)
+
+ iris = base.load_iris()
+ data = iris.data.astype(np.float32)
+ labels = iris.target.astype(np.int32)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'x': data,
+ 'keys': np.arange(len(iris.data)).reshape(150, 1)
+ },
+ y=labels,
+ batch_size=10,
+ num_epochs=1,
+ shuffle=False)
+
+ classifier.train(input_fn=input_fn, steps=100)
+ predictions = list(classifier.predict(input_fn=input_fn))
+ # Check that there is a key column, tree paths and var.
+ for pred in predictions:
+ self.assertTrue('keys' in pred)
+ self.assertTrue('tree_paths' in pred)
+ self.assertTrue('prediction_variance' in pred)
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
+ def testEarlyStopping(self):
+ head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ hparams = tensor_forest.ForestHParams(
+ num_trees=3,
+ max_nodes=1000,
+ num_classes=3,
+ num_features=4,
+ split_after_samples=20,
+ inference_tree_paths=True)
- boston = base.load_boston()
- data = boston.data.astype(np.float32)
- labels = boston.target.astype(np.int32)
+ est = random_forest.CoreTensorForestEstimator(
+ hparams.fill(),
+ head=head_fn,
+ # Set a crazy threshold - 30% loss change.
+ early_stopping_loss_threshold=0.3,
+ early_stopping_rounds=2)
- regressor.fit(x=data, y=labels, steps=100, batch_size=50)
- regressor.evaluate(x=data, y=labels, steps=10)
+ input_fn, _ = _get_classification_input_fns()
+ est.train(input_fn=input_fn, steps=100)
+ # We stopped early.
+ self._assert_checkpoint(est.model_dir, global_step=5)
if __name__ == "__main__":