aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-20 09:11:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-20 10:33:26 -0700
commit51224864c05ac503ee927a09ffe42b4f1a879771 (patch)
tree037914491cf78432d043cad51a7be8db406a8183
parent796e0aa4832a990b8b36a026b6a45b6c5320005e (diff)
Include option to run TensorForest eval locally even if training is distributed, which is a common setup.
Change: 150640056
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index e4995c60aa..f55602d8b8 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -103,7 +103,8 @@ def get_model_fn(params,
early_stopping_rounds=100,
num_trainers=1,
trainer_id=0,
- report_feature_importances=False):
+ report_feature_importances=False,
+ local_eval=False):
"""Return a model function given a way to construct a graph builder."""
def _model_fn(features, labels, mode):
"""Function that returns predictions, training loss, and training op."""
@@ -111,7 +112,14 @@ def get_model_fn(params,
if weights_name and weights_name in features:
weights = features.pop(weights_name)
- graph_builder = graph_builder_class(params, device_assigner=device_assigner)
+ # If we're doing eval, optionally ignore device_assigner.
+ dev_assn = device_assigner
+ if (local_eval and (mode == model_fn_lib.ModeKeys.EVAL or
+ mode == model_fn_lib.ModeKeys.INFER)):
+ dev_assn = None
+
+ graph_builder = graph_builder_class(params,
+ device_assigner=dev_assn)
inference = {}
if (mode == model_fn_lib.ModeKeys.EVAL or
mode == model_fn_lib.ModeKeys.INFER):
@@ -200,7 +208,8 @@ class TensorForestEstimator(estimator.Estimator):
feature_engineering_fn=None,
early_stopping_rounds=100,
num_trainers=1, trainer_id=0,
- report_feature_importances=False):
+ report_feature_importances=False,
+ local_eval=False):
"""Initializes a TensorForestEstimator instance.
Args:
@@ -230,6 +239,9 @@ class TensorForestEstimator(estimator.Estimator):
trainer_id: Which trainer this instance is.
report_feature_importances: If True, print out feature importances
during evaluation.
+ local_eval: If True, don't use a device assigner for eval. This is to
+ support some common setups where eval is done on a single machine, even
+ though training might be distributed.
Returns:
A `TensorForestEstimator` instance.
@@ -243,7 +255,8 @@ class TensorForestEstimator(estimator.Estimator):
early_stopping_rounds=early_stopping_rounds,
num_trainers=num_trainers,
trainer_id=trainer_id,
- report_feature_importances=report_feature_importances),
+ report_feature_importances=report_feature_importances,
+ local_eval=local_eval),
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)