diff options
author | 2018-09-26 09:22:46 -0700 | |
---|---|---|
committer | 2018-09-26 09:29:51 -0700 | |
commit | fa1ecc082519922827bad10f07df438c9453fedb (patch) | |
tree | 4be2ee195a66f3145843ce5433fa1bd9188661f1 /tensorflow/contrib/tensor_forest | |
parent | d7de49e456fc84416fbf3a6de7ad1ed6c12d7a20 (diff) |
Derive the number of trainers in tensorforest if run config is provided.
PiperOrigin-RevId: 214616123
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 0042d37acd..d78d12d997 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -446,6 +446,12 @@ class TensorForestEstimator(estimator.Estimator): Returns: A `TensorForestEstimator` instance. """ + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = config.num_worker_replicas + if trainer_id == 0 and config is not None: + trainer_id = config.global_id_in_cluster + super(TensorForestEstimator, self).__init__( model_fn=get_model_fn( params.fill(), @@ -564,6 +570,12 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): local_eval=False): """See TensorForestEstimator.__init__.""" model_fns = [] + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = config.num_worker_replicas + if trainer_id == 0 and config is not None: + trainer_id = config.global_id_in_cluster + for i in range(len(params_list)): params = params_list[i].fill() model_fns.append( @@ -709,6 +721,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator): Returns: A `TensorForestEstimator` instance. """ + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = config.num_worker_replicas + if trainer_id == 0 and config is not None: + trainer_id = config.global_id_in_cluster super(CoreTensorForestEstimator, self).__init__( model_fn=get_model_fn( |