diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/client/random_forest.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 0042d37acd..6e3bfbb9bd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -446,6 +446,10 @@ class TensorForestEstimator(estimator.Estimator): Returns: A `TensorForestEstimator` instance. """ + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = max(1, config.num_worker_replicas) + super(TensorForestEstimator, self).__init__( model_fn=get_model_fn( params.fill(), @@ -564,6 +568,10 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): local_eval=False): """See TensorForestEstimator.__init__.""" model_fns = [] + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = max(1, config.num_worker_replicas) + for i in range(len(params_list)): params = params_list[i].fill() model_fns.append( @@ -709,6 +717,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator): Returns: A `TensorForestEstimator` instance. """ + # Override default number of trainers if config is provided. + if num_trainers == 1 and config is not None: + num_trainers = max(1, config.num_worker_replicas) + if trainer_id == 0 and config is not None: + trainer_id = config.global_id_in_cluster super(CoreTensorForestEstimator, self).__init__( model_fn=get_model_fn( |