author    Brennan Saeta <saeta@google.com>    2017-10-16 10:13:58 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-10-16 10:17:43 -0700
commit    3b595a805bbcf4be24a2e01abe1b8031d82dc57b (patch)
tree      3e75b5862126c2d5d7afec44cd690d40de31b386
parent    19fd294eae4e8e22f6ab46b21cf41323750a1c69 (diff)
Support a configurable TPU job name
PiperOrigin-RevId: 172340173
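
As an illustration (editorial sketch, not part of this change), here is how the new field might be set, assuming the contrib RunConfig constructor of this era accepts `tpu_config` and `master` keyword arguments; the master address and job name below are hypothetical:

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    # Hypothetical cluster whose TPU worker job is not named 'tpu_worker', so
    # automatic inference may not apply; the name is given explicitly via the
    # new tpu_job_name field.
    run_config = tpu_config.RunConfig(
        master='grpc://10.240.1.2:8470',  # hypothetical TPU master address
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=100,
            num_shards=8,
            tpu_job_name='my_tpu_job'))   # hypothetical job name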
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config.py     16
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  45
2 files changed, 56 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 0a3be8503a..79fd8b839b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -27,7 +27,10 @@ from tensorflow.python.estimator import run_config as run_config_lib
class TPUConfig(
collections.namedtuple('TPUConfig', [
- 'iterations_per_loop', 'num_shards', 'per_host_input_for_training'
+ 'iterations_per_loop',
+ 'num_shards',
+ 'per_host_input_for_training',
+ 'tpu_job_name',
])):
"""TPU related configuration required by `TPUEstimator`.
@@ -46,12 +49,17 @@ class TPUConfig(
that this only works for single-host TPU training now (tracked in
b/67051042). For multi-host, please use Per-Core, i.e., `False` for
`per_host_input_for_training`.
+ tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
+ within TPUEstimator; however, when using ClusterSpec propagation in more
+ esoteric cluster configurations, you may need to specify the job name as a
+ string.
"""
def __new__(cls,
iterations_per_loop=2,
num_shards=2,
- per_host_input_for_training=True):
+ per_host_input_for_training=True,
+ tpu_job_name=None):
# Check iterations_per_loop.
util_lib.check_positive_integer(iterations_per_loop,
@@ -59,12 +67,12 @@ class TPUConfig(
# Check num_shards.
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
-
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
- per_host_input_for_training=per_host_input_for_training)
+ per_host_input_for_training=per_host_input_for_training,
+ tpu_job_name=tpu_job_name)
class RunConfig(run_config_lib.RunConfig):
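
For context (an editorial sketch, not part of the diff), this is the kind of propagated ClusterSpec the inference heuristic added to tpu_estimator.py below is meant to handle; the job names 'coordinator' and 'worker' and the addresses are hypothetical:

    from tensorflow.core.protobuf import cluster_pb2
    from tensorflow.core.protobuf import config_pb2

    # Hypothetical propagated cluster: a coordinator plus one other job. With
    # exactly two job names, one of them 'coordinator', the heuristic below
    # drops 'coordinator' and infers 'worker' as the TPU job name.
    cluster_def = cluster_pb2.ClusterDef()
    coordinator = cluster_def.job.add()
    coordinator.name = 'coordinator'
    coordinator.tasks[0] = 'coordinator.example.com:2222'
    worker = cluster_def.job.add()
    worker.name = 'worker'
    worker.tasks[0] = 'worker.example.com:8470'
    session_config = config_pb2.ConfigProto(cluster_def=cluster_def)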
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 43f9defd54..de6c8140c6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -122,12 +122,55 @@ def _increase_eval_step_op(iterations_per_loop):
use_locking=True)
+_DEFAULT_JOB_NAME = 'tpu_worker'
+_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
+_LOCAL_MASTERS = ('', 'local')
+
+
def _tpu_job(run_config, mode):
+ """Returns the job name to use to place TPU computations on.
+
+ Args:
+ run_config: The tpu_config.RunConfig used for this custom estimator.
+ mode: A model_fn_lib.ModeKeys value.
+
+ Returns:
+ A string containing the job name, or None if no job should be specified.
+
+ Raises:
+ ValueError: If the user needs to specify a tpu_job_name, because we are
+ unable to infer the job name automatically, or if the user-specified job
+ names are inappropriate.
+ """
+ # If the user specifies the tpu_job_name, use that.
+ if run_config.tpu_config.tpu_job_name:
+ return run_config.tpu_config.tpu_job_name
+
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL
else run_config.master)
- return None if master in ['', 'local'] else 'tpu_worker'
+ if master in _LOCAL_MASTERS:
+ return None
+
+ if (not run_config.session_config or
+ not run_config.session_config.cluster_def.job):
+ return _DEFAULT_JOB_NAME
+ cluster_def = run_config.session_config.cluster_def
+ job_names = set([job.name for job in cluster_def.job])
+ if _DEFAULT_JOB_NAME in job_names:
+ # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
+ raise ValueError('Currently, tpu_worker is not an allowed job name.')
+ if len(job_names) == 1:
+ return cluster_def.job[0].name
+ if len(job_names) == 2:
+ if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
+ job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
+ return job_names.pop()
+ # TODO(b/67716447): Include more sophisticated heuristics.
+ raise ValueError(
+ 'Could not infer TPU job name. Please specify a tpu_job_name as part of '
+ 'your TPUConfig.')
def _is_running_on_cpu(use_tpu, mode, eval_batch_size):