aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 09:22:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 09:29:51 -0700
commitfa1ecc082519922827bad10f07df438c9453fedb (patch)
tree4be2ee195a66f3145843ce5433fa1bd9188661f1 /tensorflow/contrib/tensor_forest
parentd7de49e456fc84416fbf3a6de7ad1ed6c12d7a20 (diff)
Derive the number of trainers in tensorforest if run config is provided.
PiperOrigin-RevId: 214616123
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0042d37acd..d78d12d997 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -446,6 +446,12 @@ class TensorForestEstimator(estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = config.num_worker_replicas
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
+
super(TensorForestEstimator, self).__init__(
model_fn=get_model_fn(
params.fill(),
@@ -564,6 +570,12 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
local_eval=False):
"""See TensorForestEstimator.__init__."""
model_fns = []
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = config.num_worker_replicas
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
+
for i in range(len(params_list)):
params = params_list[i].fill()
model_fns.append(
@@ -709,6 +721,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = config.num_worker_replicas
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
super(CoreTensorForestEstimator, self).__init__(
model_fn=get_model_fn(