aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/client/random_forest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client/random_forest.py')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0042d37acd..6e3bfbb9bd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -446,6 +446,10 @@ class TensorForestEstimator(estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
super(TensorForestEstimator, self).__init__(
model_fn=get_model_fn(
params.fill(),
@@ -564,6 +568,10 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
local_eval=False):
"""See TensorForestEstimator.__init__."""
model_fns = []
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
for i in range(len(params_list)):
params = params_list[i].fill()
model_fns.append(
@@ -709,6 +717,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
super(CoreTensorForestEstimator, self).__init__(
model_fn=get_model_fn(