diff options
author | Jianwei Xie <xiejw@google.com> | 2018-08-28 16:54:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 17:02:39 -0700 |
commit | 4dcd00066fba2bd7c504c1bc35738f804de9df67 (patch) | |
tree | ed074b01f355df668f6e95fd9e8098aa4a95dbbf /tensorflow/contrib/tpu/python | |
parent | bb0f1e9b415d5fd208b63cafb93636bdade2e985 (diff) |
Step 1 toward automatically assigning # of TPU cores.
Also fix the API as num_cores is not a good API. Changed to a debugging oriented argument using_single_core. Ideally we should not need that in future.
The updated lstm example cannot work with TF 1.10 anymore.
PiperOrigin-RevId: 210632592
Diffstat (limited to 'tensorflow/contrib/tpu/python')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 76 |
1 files changed, 51 insertions, 25 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 87b900574c..dbf5c66c9e 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -61,6 +61,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops @@ -80,7 +81,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import tf_inspect _SESSIONS = {} @@ -110,31 +110,52 @@ def reset_tpu_sessions(): _SESSIONS.clear() -# Work-around dependency cycle between DistributionStrategy and TPU lib. -def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name - """Construct a TPUDistributionStrategy.""" - from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top - # TODO(b/112705069): Remove this when TPUStrategy API is consistent. - # We are including this for (a) backwards compatibility for open sourced - # releases of TensorFlow and (b) to work around a circular dependency - # where keras_support and tpu_strategy depends on each other. Once we release - # a final version and remove support for the old API, this will be deleted. - # (See bug above for more details) - if tpu_cluster_resolver is None: - tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') - - args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__) - if len(args) == 4: - logging.info('Detected new TPUStrategy API.') - return tpu_strategy.TPUStrategy(tpu_cluster_resolver, - steps_per_run=1, - num_cores=num_cores) - else: - logging.info('Detected old TPUStrategy API.') - strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8) - strategy._tpu_cluster_resolver = tpu_cluster_resolver +def get_tpu_system_metadata(tpu_cluster_resolver): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + master = tpu_cluster_resolver.master() + + # pylint: disable=protected-access + cluster_spec = tpu_cluster_resolver.cluster_spec() + cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + cluster_def=cluster_def, + query_topology=False)) + + return tpu_system_metadata + + +class TPUDistributionStrategy(object): + """The strategy to run Keras model on TPU.""" + + def __init__(self, tpu_cluster_resolver=None, using_single_core=False): + """Construct a TPUDistributionStrategy. + + Args: + tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will + create one with '' as master address. + using_single_core: Bool. This is the debugging option, which might be + removed in future once the model replication functionality is mature + enough. If `False` (default behavior), the system automatically finds + the best configuration, in terms of number of TPU cores, for the model + replication, typically using all avaiable TPU cores. If overwrites as + `True`, force the model replication using single core, i.e., no + replication. + """ - return strategy + if tpu_cluster_resolver is None: + tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') + + num_cores = (1 if using_single_core else + get_tpu_system_metadata(tpu_cluster_resolver).num_cores) + + self._tpu_cluster_resolver = tpu_cluster_resolver + self._num_cores = num_cores + + @property + def num_towers(self): + return self._num_cores class TPUEmbedding(embeddings.Embedding): @@ -1212,5 +1233,10 @@ def tpu_model(model, strategy=None): if strategy is None: strategy = TPUDistributionStrategy() + else: + if not isinstance(strategy, TPUDistributionStrategy): + raise TypeError( + '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. ' + 'Got: {}'.format(type(strategy))) return KerasTPUModel(cpu_model=model, strategy=strategy) |