aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu_config.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_config.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py14
1 files changed, 4 insertions, 10 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 916b9b3082..3965c087a1 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -45,7 +45,10 @@ class TPUConfig(
is invoked once on each host. To be precise, with a global batch size
`train_batch_size` in `TPUEstimator` constructor, the batch size for each
shard is `train_batch_size` // #hosts. With Per-Core input pipeline
- deployment, the shard batch size is `train_batch_size` // #cores.
+ deployment, the shard batch size is `train_batch_size` // #cores. Note
+ that this only works for single-host TPU training now (tracked in
+ b/67051042). For multi-host, please use Per-Core, i.e., `False` for
+ `per_host_input_for_training`.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
@@ -106,12 +109,3 @@ class RunConfig(run_config_lib.RunConfig):
@property
def tpu_config(self):
return self._tpu_config
-
- def replace(self, **kwargs):
- if 'tpu_config' not in kwargs:
- return super(RunConfig, self).replace(**kwargs)
-
- tpu_config = kwargs.pop('tpu_config')
- new_instance = super(RunConfig, self).replace(**kwargs)
- new_instance._tpu_config = tpu_config # pylint: disable=protected-access
- return new_instance