aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 12:08:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 12:14:07 -0700
commit9276b19b468b82b7457cf256352e7eac9d90d68e (patch)
tree692049efd316069a0fc1fabdec2ead49792fa9a0 /tensorflow/contrib/tensor_forest
parente8c18aa0947d253d861f56c99788a8ab94f28164 (diff)
Account for old run config, more robust num trainers
PiperOrigin-RevId: 214646114
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py10
1 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index d78d12d997..6e3bfbb9bd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -448,9 +448,7 @@ class TensorForestEstimator(estimator.Estimator):
"""
# Override default number of trainers if config is provided.
if num_trainers == 1 and config is not None:
- num_trainers = config.num_worker_replicas
- if trainer_id == 0 and config is not None:
- trainer_id = config.global_id_in_cluster
+ num_trainers = max(1, config.num_worker_replicas)
super(TensorForestEstimator, self).__init__(
model_fn=get_model_fn(
@@ -572,9 +570,7 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
model_fns = []
# Override default number of trainers if config is provided.
if num_trainers == 1 and config is not None:
- num_trainers = config.num_worker_replicas
- if trainer_id == 0 and config is not None:
- trainer_id = config.global_id_in_cluster
+ num_trainers = max(1, config.num_worker_replicas)
for i in range(len(params_list)):
params = params_list[i].fill()
@@ -723,7 +719,7 @@ class CoreTensorForestEstimator(core_estimator.Estimator):
"""
# Override default number of trainers if config is provided.
if num_trainers == 1 and config is not None:
- num_trainers = config.num_worker_replicas
+ num_trainers = max(1, config.num_worker_replicas)
if trainer_id == 0 and config is not None:
trainer_id = config.global_id_in_cluster