diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-20 09:11:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 10:33:26 -0700 |
commit | 51224864c05ac503ee927a09ffe42b4f1a879771 (patch) | |
tree | 037914491cf78432d043cad51a7be8db406a8183 | |
parent | 796e0aa4832a990b8b36a026b6a45b6c5320005e (diff) |
Include option to run TensorForest eval locally even if training is distributed, which is a common setup.
Change: 150640056
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index e4995c60aa..f55602d8b8 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -103,7 +103,8 @@ def get_model_fn(params, early_stopping_rounds=100, num_trainers=1, trainer_id=0, - report_feature_importances=False): + report_feature_importances=False, + local_eval=False): """Return a model function given a way to construct a graph builder.""" def _model_fn(features, labels, mode): """Function that returns predictions, training loss, and training op.""" @@ -111,7 +112,14 @@ def get_model_fn(params, if weights_name and weights_name in features: weights = features.pop(weights_name) - graph_builder = graph_builder_class(params, device_assigner=device_assigner) + # If we're doing eval, optionally ignore device_assigner. + dev_assn = device_assigner + if (local_eval and (mode == model_fn_lib.ModeKeys.EVAL or + mode == model_fn_lib.ModeKeys.INFER)): + dev_assn = None + + graph_builder = graph_builder_class(params, + device_assigner=dev_assn) inference = {} if (mode == model_fn_lib.ModeKeys.EVAL or mode == model_fn_lib.ModeKeys.INFER): @@ -200,7 +208,8 @@ class TensorForestEstimator(estimator.Estimator): feature_engineering_fn=None, early_stopping_rounds=100, num_trainers=1, trainer_id=0, - report_feature_importances=False): + report_feature_importances=False, + local_eval=False): """Initializes a TensorForestEstimator instance. Args: @@ -230,6 +239,9 @@ class TensorForestEstimator(estimator.Estimator): trainer_id: Which trainer this instance is. report_feature_importances: If True, print out feature importances during evaluation. + local_eval: If True, don't use a device assigner for eval. This is to + support some common setups where eval is done on a single machine, even + though training might be distributed. Returns: A `TensorForestEstimator` instance. @@ -243,7 +255,8 @@ class TensorForestEstimator(estimator.Estimator): early_stopping_rounds=early_stopping_rounds, num_trainers=num_trainers, trainer_id=trainer_id, - report_feature_importances=report_feature_importances), + report_feature_importances=report_feature_importances, + local_eval=local_eval), model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) |