author		2018-08-31 12:52:42 -0700
committer	2018-08-31 12:57:11 -0700
commit		cdb0821239dda8bbf4eb98fc9d93244e2548e7a5
tree		5c1ece2c9cf152adae7bfc08f30724dfc042c76c	/tensorflow/contrib/tpu
parent		8c2939c76ff3c440b931c83e08ef9946c088c720
Add experimental replicated variable support for Keras/TPU.
PiperOrigin-RevId: 211129794
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--	tensorflow/contrib/tpu/BUILD	1
-rw-r--r--	tensorflow/contrib/tpu/python/tpu/keras_support.py	36
-rw-r--r--	tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py	289
3 files changed, 311 insertions, 15 deletions
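
The heart of this change is the new `replicated_scope` custom getter in `keras_tpu_variables.py`: every Keras variable created inside it is materialized once per TPU core and wrapped in a single `ReplicatedVariable`. A minimal sketch of the pattern, assuming the imports already used by `keras_support.py` (the helper function name here is illustrative; the real call site is `TPUFunction._specialize_model` in the diff below):

```python
# Minimal sketch of the replicated-variable pattern introduced by this change.
# `clone_keras_model_replicated` is an illustrative helper, not part of the commit.
from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.python.keras import models


def clone_keras_model_replicated(cpu_model, num_towers):
  # Each variable created while cloning is built once per TPU core
  # (device:TPU:0 .. device:TPU:{num_towers-1}) and handed back to the graph
  # as a single ReplicatedVariable.
  with keras_tpu_variables.replicated_scope(num_towers):
    return models.clone_model(cpu_model)
```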
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index a9e338ee59..298ffc1ded 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -167,6 +167,7 @@ py_library(
     name = "keras_support",
     srcs = [
         "python/tpu/keras_support.py",
+        "python/tpu/keras_tpu_variables.py",
     ],
     srcs_version = "PY2AND3",
     visibility = [
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index e22aeb2ac0..b8d29d1789 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -58,6 +58,7 @@ from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_reso
 from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
 from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
 from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
@@ -96,9 +97,9 @@ def tpu_session(cluster_resolver):
   if cluster_spec:
     config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
 
+  logging.info('Connecting to: %s', master)
   graph = ops.Graph()
   session = tf_session.Session(graph=graph, target=master, config=config)
-
   with graph.as_default():
     session.run(tpu.initialize_system())
 
@@ -147,11 +148,17 @@ class TPUDistributionStrategy(object):
     if tpu_cluster_resolver is None:
       tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
 
-    num_cores = (1 if using_single_core else
-                 get_tpu_system_metadata(tpu_cluster_resolver).num_cores)
-
+    metadata = get_tpu_system_metadata(tpu_cluster_resolver)
+    self._tpu_metadata = metadata
     self._tpu_cluster_resolver = tpu_cluster_resolver
-    self._num_cores = num_cores
+    self._num_cores = 1 if using_single_core else metadata.num_cores
+
+    # Walk device list to identify TPU worker for enqueue/dequeue operations.
+    worker_re = re.compile('/job:([^/]+)')
+    for device in metadata.devices:
+      if 'TPU:0' in device.name:
+        self.worker_name = worker_re.search(device.name).group(1)
+        break
 
   @property
   def num_towers(self):
@@ -514,7 +521,7 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
 
     shard_infeed_tensors = []
     for shard_id in range(self._strategy.num_towers):
-      with ops.device('/device:CPU:0'):
+      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
        infeed_tensors = []
        with ops.device('/device:TPU:%d' % shard_id):
          for spec in input_specs:
@@ -659,7 +666,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
     assert len(shard_infeed_tensors) == self._strategy.num_towers
     infeed_ops = []
     for shard_id in range(self._strategy.num_towers):
-      with ops.device('/device:CPU:0'):
+      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
         infeed_ops.append(
             tpu_ops.infeed_enqueue_tuple(
                 shard_infeed_tensors[shard_id],
@@ -737,8 +744,7 @@ class TPUFunction(object):
       # Clone our CPU model, running within the TPU device context.
       with TPURewriteContext(tpu_input_map):
         with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
-          # TODO(power): Replicate variables.
-          with ops.device('/device:TPU:0'):
+          with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
             self._cloned_model = models.clone_model(self.model)
 
       # Create a copy of the optimizer for this graph.
@@ -817,7 +823,7 @@ class TPUFunction(object):
     # Build output ops.
     outfeed_op = []
     for shard_id in range(self._strategy.num_towers):
-      with ops.device('/device:CPU:0'):
+      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
         outfeed_op.extend(
             tpu_ops.outfeed_dequeue_tuple(
                 dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -835,7 +841,7 @@ class TPUFunction(object):
   def _test_model_compiles(self, tpu_model_ops):
     """Verifies that the given TPUModelOp can be compiled via XLA."""
     logging.info('Started compiling')
-    start_time = time.clock()
+    start_time = time.time()
 
     result = K.get_session().run(tpu_model_ops.compile_op)
     proto = tpu_compilation_result.CompilationResultProto()
@@ -844,7 +850,7 @@ class TPUFunction(object):
       raise RuntimeError('Compilation failed: {}'.format(
           proto.status_error_message))
 
-    end_time = time.clock()
+    end_time = time.time()
     logging.info('Finished compiling. Time elapsed: %s secs',
                  end_time - start_time)
 
@@ -940,7 +946,6 @@ class KerasTPUModel(models.Model):
     self._tpu_weights_initialized = False
 
     self._session = tpu_session(cluster_resolver)
-    self._graph = self._session.graph
 
     # If the input CPU model has already been compiled, compile our TPU model
     # immediately.
@@ -1015,7 +1020,8 @@ class KerasTPUModel(models.Model):
           'https://github.com/tensorflow/tpu/tree/master/models/experimental'
          '/keras')
     if callable(x):
-      with self.tpu_session() as sess:
+      with self.tpu_session() as sess,\
+          ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
         dataset = x()
         if steps_per_epoch is None:
           raise ValueError('When using tf.data as input to a model, you '
@@ -1189,7 +1195,7 @@ class KerasTPUModel(models.Model):
   @contextlib.contextmanager
   def tpu_session(self):
     """Yields a TPU session and sets it as the default Keras session."""
-    with self._graph.as_default():
+    with self._session.graph.as_default():
       default_session = K.get_session()
       # N.B. We have to call `K.set_session()` AND set our session as the
       # TF default. `K.get_session()` surprisingly does not return the value
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
new file mode 100644
index 0000000000..a423aeace7
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Distributed variable implementation for TPUs.
+
+N.B. This is an experimental feature that should only be used for Keras support.
+
+It is unsupported and will be removed in favor of Distribution Strategy soon.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+
+
+@contextlib.contextmanager
+def _handle_graph(handle):
+  with handle.graph.as_default():
+    yield
+
+
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while context is not None and not isinstance(
+      context, control_flow_ops.XLAControlFlowContext):
+    context = context.outer_context
+  return context
+
+
+class ReplicatedVariable(object):
+  """A replicated variable for use on TPUs.
+
+  When accessed inside a tpu.replicate() context, this variable acts as if it
+  is a single variable whose handle is a replicated input to the computation.
+
+  Outside a tpu.replicate() context currently this object has pretty murky
+  semantics, especially with respect to things such as
+  * initialization
+  * colocation.
+  """
+
+  def __init__(self, name, variables):
+    self._name = name
+    self._primary_var = variables[0]
+    self._vars = variables
+    self._cached_value = None
+    self._dtype = variables[0].dtype
+
+  @property
+  def handle(self):
+    tpu_context = _enclosing_tpu_context()
+    if tpu_context is None:
+      return self._primary_var.handle
+
+    return tpu_context.get_replicated_var_handle(self)
+
+  @contextlib.contextmanager
+  def _assign_dependencies(self):
+    """Makes assignments depend on the cached value, if any.
+
+    This prevents undefined behavior with reads not ordered wrt writes.
+
+    Yields:
+      None.
+    """
+    if self._cached_value is not None:
+      with ops.control_dependencies([self._cached_value]):
+        yield
+    else:
+      yield
+
+  @property
+  def initializer(self):
+    return control_flow_ops.group([v.initializer for v in self._vars])
+
+  @property
+  def graph(self):
+    return self._primary_var.graph
+
+  @property
+  def _shared_name(self):
+    return self._common_name
+
+  @property
+  def _unique_id(self):
+    return self._primary_var._unique_id  # pylint: disable=protected-access
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def dtype(self):
+    return self._primary_var.dtype
+
+  @property
+  def shape(self):
+    return self._primary_var.shape
+
+  def get_shape(self):
+    return self._primary_var.get_shape()
+
+  def to_proto(self, export_scope=None):
+    return self._primary_var.to_proto(export_scope=export_scope)
+
+  @property
+  def constraint(self):
+    return None
+
+  @property
+  def op(self):
+    return self.get().op
+
+  @property
+  def is_tensor_like(self):
+    return True
+
+  def _read_variable_op(self):
+    if _enclosing_tpu_context() is None:
+      return self._primary_var.read_value()
+    v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
+    return v
+
+  def read_value(self):
+    return self._read_variable_op()
+
+  def is_initialized(self, name=None):
+    return self._vars[0].is_initialized(name=name)
+
+  def __getitem__(self, *args):
+    return self.read_value().__getitem__(*args)
+
+  def assign(self, value, use_locking=None, name=None, read_value=False):
+    """Assign `value` to all replicas.
+
+    Outside of the tpu.rewrite context, assign explicitly to all replicas.
+    Inside of the tpu.rewrite context, assigns to the local replica.
+
+    Arguments:
+      value: Tensor to assign
+      use_locking: ignored
+      name: ignored
+      read_value: return the value from the assignment
+    Returns:
+      Assignment operation, or new value of the variable if `read_value` is
+      True
+    """
+    del use_locking
+    if _enclosing_tpu_context() is None:
+      assign_ops = []
+      with self._assign_dependencies():
+        for var in self._vars:
+          assign_ops.append(var.assign(value, use_locking=None, name=name))
+
+        if read_value:
+          with ops.control_dependencies(assign_ops):
+            return self.read_value()
+        else:
+          return control_flow_ops.group(assign_ops)
+
+    with _handle_graph(self.handle), self._assign_dependencies():
+      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+      assign_op = gen_resource_variable_ops.assign_variable_op(
+          self.handle, value_tensor, name=name)
+    if read_value:
+      return self._read_variable_op()
+    return assign_op
+
+  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+    del use_locking
+    with _handle_graph(self.handle), self._assign_dependencies():
+      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+          self.handle,
+          ops.convert_to_tensor(delta, dtype=self.dtype),
+          name=name)
+    if read_value:
+      return self._read_variable_op()
+    return assign_add_op
+
+  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+    del use_locking
+    with _handle_graph(self.handle), self._assign_dependencies():
+      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+          self.handle,
+          ops.convert_to_tensor(delta, dtype=self.dtype),
+          name=name)
+    if read_value:
+      return self._read_variable_op()
+    return assign_sub_op
+
+  def get(self):
+    return self._primary_var
+
+  def _should_act_as_resource_variable(self):
+    """Pass resource_variable_ops.is_resource_variable check."""
+    pass
+
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # pylint: disable=protected-access
+    if _enclosing_tpu_context() is None:
+      return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
+    # pylint: enable=protected-access
+    if dtype is not None and dtype != self.dtype:
+      return NotImplemented
+    if as_ref:
+      return self.handle
+    else:
+      return self.read_value()
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
+
+
+def replicated_fetch_function(var):
+  # pylint: disable=protected-access
+  return ([var._dense_var_to_tensor()], lambda v: v[0])
+  # pylint: enable=protected-access
+
+
+ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
+ops.register_dense_tensor_like_type(ReplicatedVariable)
+session_lib.register_session_run_conversion_functions(
+    ReplicatedVariable, replicated_fetch_function)
+
+
+def replicated_scope(num_replicas):
+  """Variable scope for constructing replicated variables."""
+
+  def _replicated_variable_getter(getter, name, *args, **kwargs):
+    """Getter that constructs replicated variables."""
+    collections = kwargs.pop("collections", None)
+    if collections is None:
+      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+    kwargs["collections"] = []
+
+    logging.info("Constructing replicated variable %s", name)
+    variables = []
+    index = {}
+    for i in range(num_replicas):
+      replica_name = "{}/{}".format(name, i)
+      with ops.device("device:TPU:{}".format(i)):
+        v = getter(*args, name=replica_name, **kwargs)
+      variables.append(v)
+      index[i] = v
+    result = ReplicatedVariable(name, variables)
+
+    g = ops.get_default_graph()
+    # If "trainable" is True, next_creator() will add the member variables
+    # to the TRAINABLE_VARIABLES collection, so we manually remove
+    # them and replace with the MirroredVariable. We can't set
+    # "trainable" to False for next_creator() since that causes functions
+    # like implicit_gradients to skip those variables.
+    if kwargs.get("trainable", True):
+      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+      for v in index.values():
+        if v in l:
+          l.remove(v)
+    g.add_to_collections(collections, result)
+
+    return result
+
+  return variable_scope.variable_scope(
+      "", custom_getter=_replicated_variable_getter)
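
From the Keras user's point of view the workflow is unchanged; the per-core replicas are created transparently when the model is converted for TPU execution. A hedged end-to-end sketch using the contrib-era entry points (`keras_to_tpu_model` and `TPUDistributionStrategy`; exact argument names may differ between 1.x releases):

```python
import tensorflow as tf

# Sketch of the contrib-era Keras/TPU flow this change feeds into.
# Assumes a reachable TPU; TPUClusterResolver('') discovers it from the environment.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver('')
strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.train.GradientDescentOptimizer(0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# keras_to_tpu_model clones the model onto the TPU; with this commit the
# clone's weights become ReplicatedVariables, one replica per TPU core.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
```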