diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 12:08:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 12:14:07 -0700 |
commit | 9276b19b468b82b7457cf256352e7eac9d90d68e (patch) | |
tree | 692049efd316069a0fc1fabdec2ead49792fa9a0 /tensorflow/contrib/tensor_forest | |
parent | e8c18aa0947d253d861f56c99788a8ab94f28164 (diff) |
Account for old run config, more robust num trainers
PiperOrigin-RevId: 214646114
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index d78d12d997..6e3bfbb9bd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -448,9 +448,7 @@ class TensorForestEstimator(estimator.Estimator): """ # Override default number of trainers if config is provided. if num_trainers == 1 and config is not None: - num_trainers = config.num_worker_replicas - if trainer_id == 0 and config is not None: - trainer_id = config.global_id_in_cluster + num_trainers = max(1, config.num_worker_replicas) super(TensorForestEstimator, self).__init__( model_fn=get_model_fn( @@ -572,9 +570,7 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): model_fns = [] # Override default number of trainers if config is provided. if num_trainers == 1 and config is not None: - num_trainers = config.num_worker_replicas - if trainer_id == 0 and config is not None: - trainer_id = config.global_id_in_cluster + num_trainers = max(1, config.num_worker_replicas) for i in range(len(params_list)): params = params_list[i].fill() @@ -723,7 +719,7 @@ class CoreTensorForestEstimator(core_estimator.Estimator): """ # Override default number of trainers if config is provided. if num_trainers == 1 and config is not None: - num_trainers = config.num_worker_replicas + num_trainers = max(1, config.num_worker_replicas) if trainer_id == 0 and config is not None: trainer_id = config.global_id_in_cluster |