author    Jianwei Xie <xiejw@google.com>    2018-08-28 16:54:18 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-28 17:02:39 -0700
commit 4dcd00066fba2bd7c504c1bc35738f804de9df67 (patch)
tree   ed074b01f355df668f6e95fd9e8098aa4a95dbbf /tensorflow/contrib/tpu/python
parent bb0f1e9b415d5fd208b63cafb93636bdade2e985 (diff)
Step 1 toward automatically assigning # of TPU cores.
Also fix the API: num_cores was not a good argument, so it is replaced by a debugging-oriented argument, using_single_core. Ideally we should not need it in the future. The updated LSTM example no longer works with TF 1.10.

PiperOrigin-RevId: 210632592
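For orientation, here is a minimal usage sketch of the revised API. This is hedged: it is written against the TF 1.x `tf.contrib.tpu` surface as of this commit, the TPU address is a placeholder, and `keras_to_tpu_model` is the contrib export that wraps the `tpu_model` function patched below.

```python
import tensorflow as tf

# Placeholder TPU endpoint; substitute your own TPU name or grpc:// address.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.2:8470')

# New API: instead of passing num_cores, the strategy queries the TPU system
# for the available core count. using_single_core=True is the debugging
# escape hatch that disables replication entirely.
strategy = tf.contrib.tpu.TPUDistributionStrategy(
    tpu_cluster_resolver=resolver,
    using_single_core=False)

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(8,), activation='softmax')])
model.compile(optimizer='sgd', loss='categorical_crossentropy')

# Replicate the Keras model across all resolved TPU cores.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
```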
Diffstat (limited to 'tensorflow/contrib/tpu/python')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 76
1 file changed, 51 insertions(+), 25 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 87b900574c..dbf5c66c9e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -61,6 +61,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
+from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
@@ -80,7 +81,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import tf_inspect
_SESSIONS = {}
@@ -110,31 +110,52 @@ def reset_tpu_sessions():
_SESSIONS.clear()
-# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name
- """Construct a TPUDistributionStrategy."""
- from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
- # We are including this for (a) backwards compatibility for open sourced
- # releases of TensorFlow and (b) to work around a circular dependency
- # where keras_support and tpu_strategy depends on each other. Once we release
- # a final version and remove support for the old API, this will be deleted.
- # (See bug above for more details)
- if tpu_cluster_resolver is None:
- tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
-
- args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
- if len(args) == 4:
- logging.info('Detected new TPUStrategy API.')
- return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
- steps_per_run=1,
- num_cores=num_cores)
- else:
- logging.info('Detected old TPUStrategy API.')
- strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
- strategy._tpu_cluster_resolver = tpu_cluster_resolver
+def get_tpu_system_metadata(tpu_cluster_resolver):
+ """Retrieves TPU system metadata given a TPUClusterResolver."""
+ master = tpu_cluster_resolver.master()
+
+ # pylint: disable=protected-access
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
+ tpu_system_metadata = (
+ tpu_system_metadata_lib._query_tpu_system_metadata(
+ master,
+ cluster_def=cluster_def,
+ query_topology=False))
+
+ return tpu_system_metadata
+
+
+class TPUDistributionStrategy(object):
+  """The strategy to run a Keras model on TPUs."""
+
+ def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
+ """Construct a TPUDistributionStrategy.
+
+ Args:
+      tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If `None`,
+        one will be created with '' as the master address.
+      using_single_core: Bool. This is a debugging option, which might be
+        removed in the future once the model replication functionality is
+        mature enough. If `False` (default behavior), the system
+        automatically finds the best configuration, in terms of number of
+        TPU cores, for the model replication, typically using all available
+        TPU cores. If set to `True`, it forces the model replication to use
+        a single core, i.e., no replication.
+ """
- return strategy
+ if tpu_cluster_resolver is None:
+ tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+ num_cores = (1 if using_single_core else
+ get_tpu_system_metadata(tpu_cluster_resolver).num_cores)
+
+ self._tpu_cluster_resolver = tpu_cluster_resolver
+ self._num_cores = num_cores
+
+ @property
+ def num_towers(self):
+ return self._num_cores
class TPUEmbedding(embeddings.Embedding):
@@ -1212,5 +1233,10 @@ def tpu_model(model, strategy=None):
if strategy is None:
strategy = TPUDistributionStrategy()
+ else:
+ if not isinstance(strategy, TPUDistributionStrategy):
+ raise TypeError(
+ '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
+ 'Got: {}'.format(type(strategy)))
return KerasTPUModel(cpu_model=model, strategy=strategy)
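For completeness, a hedged sketch of the core-count query that the new `get_tpu_system_metadata` performs, runnable on its own against the same contrib tree. Note that `_query_tpu_system_metadata` is a private helper, which the patch itself flags with a pylint disable, so this is an illustration rather than a supported API.

```python
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib

resolver = TPUClusterResolver('')  # '' resolves the default TPU target
cluster_spec = resolver.cluster_spec()
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None

# pylint: disable=protected-access
# Private helper, mirrored from the patch above; it may change between
# releases.
metadata = tpu_system_metadata_lib._query_tpu_system_metadata(
    resolver.master(), cluster_def=cluster_def, query_topology=False)

# This is the value TPUDistributionStrategy.num_towers reports when
# using_single_core=False.
print('TPU cores available:', metadata.num_cores)
```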