author     Brennan Saeta <saeta@google.com>                 2017-10-16 10:13:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-10-16 10:17:43 -0700
commit     3b595a805bbcf4be24a2e01abe1b8031d82dc57b (patch)
tree       3e75b5862126c2d5d7afec44cd690d40de31b386
parent     19fd294eae4e8e22f6ab46b21cf41323750a1c69 (diff)
Support a configurable TPU job name
PiperOrigin-RevId: 172340173
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config.py     | 16
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  | 45
2 files changed, 56 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 0a3be8503a..79fd8b839b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -27,7 +27,10 @@ from tensorflow.python.estimator import run_config as run_config_lib
 
 class TPUConfig(
     collections.namedtuple('TPUConfig', [
-        'iterations_per_loop', 'num_shards', 'per_host_input_for_training'
+        'iterations_per_loop',
+        'num_shards',
+        'per_host_input_for_training',
+        'tpu_job_name',
     ])):
   """TPU related configuration required by `TPUEstimator`.
 
@@ -46,12 +49,17 @@ class TPUConfig(
       that this only works for single-host TPU training now (tracked in
       b/67051042). For multi-host, please use Per-Core, i.e., `False` for
       `per_host_input_for_training`.
+    tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
+      within TPUEstimator, however when using ClusterSpec propagation in more
+      esoteric cluster configurations, you may need to specify the job name as a
+      string.
   """
 
   def __new__(cls,
               iterations_per_loop=2,
               num_shards=2,
-              per_host_input_for_training=True):
+              per_host_input_for_training=True,
+              tpu_job_name=None):
 
     # Check iterations_per_loop.
     util_lib.check_positive_integer(iterations_per_loop,
@@ -59,12 +67,12 @@ class TPUConfig(
 
     # Check num_shards.
     util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
-
     return super(TPUConfig, cls).__new__(
         cls,
         iterations_per_loop=iterations_per_loop,
         num_shards=num_shards,
-        per_host_input_for_training=per_host_input_for_training)
+        per_host_input_for_training=per_host_input_for_training,
+        tpu_job_name=tpu_job_name)
 
 
 class RunConfig(run_config_lib.RunConfig):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 43f9defd54..de6c8140c6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -122,12 +122,55 @@ def _increase_eval_step_op(iterations_per_loop):
       use_locking=True)
 
 
+_DEFAULT_JOB_NAME = 'tpu_worker'
+_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
+_LOCAL_MASTERS = ('', 'local')
+
+
 def _tpu_job(run_config, mode):
+  """Returns the job name to use to place TPU computations on.
+
+  Args:
+    run_config: The tpu_config.RunConfig used for this custom estimator.
+    mode: A model_fn_lib.ModeKeys value.
+
+  Returns:
+    A string containing the job name, or None if no job should be specified.
+
+  Raises:
+    ValueError: If the user needs to specify a tpu_job_name, because we are
+      unable to infer the job name automatically, or if the user-specified job
+      names are inappropriate.
+  """
+  # If the user specifies the tpu_job_name, use that.
+  if run_config.tpu_config.tpu_job_name:
+    return run_config.tpu_config.tpu_job_name
+
   # The tpu job is determined by the run_config. Right now, this method is
   # required as tpu_config is not part of the RunConfig.
   master = (run_config.evaluation_master
             if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
-  return None if master in ['', 'local'] else 'tpu_worker'
+  if master in _LOCAL_MASTERS:
+    return None
+
+  if (not run_config.session_config or
+      not run_config.session_config.cluster_def.job):
+    return _DEFAULT_JOB_NAME
+  cluster_def = run_config.session_config.cluster_def
+  job_names = set([job.name for job in cluster_def.job])
+  if _DEFAULT_JOB_NAME in job_names:
+    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
+    raise ValueError('Currently, tpu_worker is not an allowed job name.')
+  if len(job_names) == 1:
+    return cluster_def.job[0].name
+  if len(job_names) == 2:
+    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
+      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
+      return job_names.pop()
+  # TODO(b/67716447): Include more sophisticated heuristics.
+  raise ValueError(
+      'Could not infer TPU job name. Please specify a tpu_job_name as part of '
+      'your TPUConfig.')
 
 
 def _is_running_on_cpu(use_tpu, mode, eval_batch_size):
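For reference, a minimal usage sketch of the new field, not part of this commit: the master address and the job name 'my_tpu_worker' are illustrative placeholders, and it assumes the tpu_config.RunConfig wrapper defined in tpu_config.py accepts master and tpu_config keyword arguments.

from tensorflow.contrib.tpu.python.tpu import tpu_config

# Hypothetical usage sketch: naming the TPU job explicitly so that
# _tpu_job() returns this value instead of running its inference heuristic.
run_config = tpu_config.RunConfig(
    master='grpc://10.240.1.2:8470',  # placeholder cluster master address
    tpu_config=tpu_config.TPUConfig(
        iterations_per_loop=100,
        num_shards=8,
        tpu_job_name='my_tpu_worker'))  # overrides job-name inference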
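The inference heuristic added to _tpu_job() can also be read in isolation. The standalone sketch below restates it over a plain list of job names; infer_tpu_job_name is a hypothetical helper written for illustration, not a function in this commit.

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'

def infer_tpu_job_name(job_names):
  # Restates the heuristic: reject a reused 'tpu_worker' name, accept a
  # single job outright, and with exactly two jobs assume the
  # non-coordinator job runs the TPU computations.
  names = set(job_names)
  if _DEFAULT_JOB_NAME in names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(names) == 1:
    return names.pop()
  if len(names) == 2 and _DEFAULT_COORDINATOR_JOB_NAME in names:
    names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
    return names.pop()
  raise ValueError('Could not infer TPU job name. Please specify a '
                   'tpu_job_name as part of your TPUConfig.')

assert infer_tpu_job_name(['worker']) == 'worker'
assert infer_tpu_job_name(['coordinator', 'worker']) == 'worker'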