Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_config.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_config.py | 14
1 file changed, 4 insertions, 10 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 916b9b3082..3965c087a1 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -45,7 +45,10 @@ class TPUConfig(
       is invoked once on each host. To be precise, with a global batch size
       `train_batch_size` in `TPUEstimator` constructor, the batch size for each
       shard is `train_batch_size` // #hosts. With Per-Core input pipeline
-      deployment, the shard batch size is `train_batch_size` // #cores.
+      deployment, the shard batch size is `train_batch_size` // #cores. Note
+      that this only works for single-host TPU training now (tracked in
+      b/67051042). For multi-host, please use Per-Core, i.e., `False` for
+      `per_host_input_for_training`.
     tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
       within TPUEstimator, however when using ClusterSpec propagation in more
       esoteric cluster configurations, you may need to specify the job name as a
@@ -106,12 +109,3 @@ class RunConfig(run_config_lib.RunConfig):
   @property
   def tpu_config(self):
     return self._tpu_config
-
-  def replace(self, **kwargs):
-    if 'tpu_config' not in kwargs:
-      return super(RunConfig, self).replace(**kwargs)
-
-    tpu_config = kwargs.pop('tpu_config')
-    new_instance = super(RunConfig, self).replace(**kwargs)
-    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
-    return new_instance
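
For context (not part of the commit): a minimal sketch of the shard batch-size arithmetic described in the docstring hunk, plus the multi-host setup the new note calls for. The cluster shape (4 hosts, 8 cores per host) and the batch size are hypothetical numbers, and the RunConfig(tpu_config=...) keyword is assumed from this file's constructor at this commit; treat both as assumptions rather than part of the change.

    # Hypothetical cluster shape; only the // arithmetic comes from the docstring.
    train_batch_size = 1024
    num_hosts = 4
    num_cores = num_hosts * 8  # 8 TPU cores per host (assumed)

    # Per-Host deployment: input_fn runs once per host.
    per_host_shard_batch = train_batch_size // num_hosts  # 256
    # Per-Core deployment: input_fn runs once per core.
    per_core_shard_batch = train_batch_size // num_cores  # 32

    # Per the added note, multi-host training currently needs Per-Core input
    # (tracked in b/67051042), i.e. per_host_input_for_training=False:
    from tensorflow.contrib.tpu.python.tpu import tpu_config

    config = tpu_config.RunConfig(
        tpu_config=tpu_config.TPUConfig(
            num_shards=num_cores,
            per_host_input_for_training=False))

With the `replace` override removed by this commit, swapping in a new `tpu_config` via `RunConfig.replace(...)` is no longer special-cased; a new `RunConfig` would be constructed as above instead.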