author     Jonathan Hseu <jhseu@google.com>                 2018-09-28 18:41:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 18:45:56 -0700
commit     d37f771cc5a208cdc88a50a65f491b3c06c9f262
tree       1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/contrib/distribute/python/values.py
parent     abd5c32c0fa6451e73b491affdd86d852a74177f
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
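The class added by this change keeps one variable per TPU device and makes the container usable like a single tensor by forwarding reads and Python operators to the per-device copy via read_value(). As a rough, plain-Python sketch of that container/delegation pattern only (MirroredValueSketch and PerDeviceValue are hypothetical names; nothing below is code from this commit):

```python
import weakref


class PerDeviceValue(object):
  """Toy stand-in for one device's copy of a variable."""

  def __init__(self, value):
    self.value = value
    self.container = None  # Filled in by the owning container below.


class MirroredValueSketch(object):
  """Maps device names to per-device values and acts like a single value."""

  def __init__(self, index, primary_device):
    # A weakref back to the container avoids a reference cycle, mirroring
    # the _mirrored_container weakref used by the real class.
    for value in index.values():
      value.container = weakref.ref(self)
    self._index = dict(index)
    self._primary_device = primary_device

  def _get(self, device=None):
    device = device or self._primary_device
    try:
      return self._index[device]
    except KeyError:
      raise ValueError("Device %s not found in %s" % (device, list(self._index)))

  def read_value(self, device=None):
    return self._get(device).value

  # Arithmetic is forwarded to the current value, the way TPUMirroredVariable
  # forwards its operators to read_value().
  def __add__(self, other):
    return self.read_value() + other

  def __radd__(self, other):
    return other + self.read_value()


index = {"/device:TPU:0": PerDeviceValue(1.0), "/device:TPU:1": PerDeviceValue(1.0)}
mirrored = MirroredValueSketch(index, primary_device="/device:TPU:0")
print(mirrored + 2.0)  # 3.0 via __add__ -> read_value()
print(2.0 + mirrored)  # 3.0 via __radd__
```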
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  381
1 file changed, 381 insertions(+), 0 deletions(-)
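The first helper added in the diff below, _enclosing_tpu_context(), walks the graph's control-flow contexts outward until it reaches an XLAControlFlowContext or runs out of contexts. A minimal sketch of that outward walk, using hypothetical stand-in classes rather than TensorFlow's control-flow machinery:

```python
class Context(object):
  """Stand-in for a control-flow context linked to its enclosing context."""

  def __init__(self, outer=None):
    self.outer_context = outer


class XLAContext(Context):
  """Stand-in for control_flow_ops.XLAControlFlowContext."""


def enclosing_xla_context(ctx):
  """Returns the nearest enclosing XLAContext, or None if there is none."""
  while ctx is not None and not isinstance(ctx, XLAContext):
    ctx = ctx.outer_context
  return ctx


assert enclosing_xla_context(Context(Context())) is None
assert isinstance(enclosing_xla_context(Context(XLAContext())), XLAContext)
```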
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 4955ded4d5..c18faeb67d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import weakref
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
 from tensorflow.contrib.distribute.python import prefetching_ops_v2
 from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
@@ -453,6 +456,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
                                         _tensor_conversion_mirrored)
 
 
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  tpu_context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while tpu_context is not None and not isinstance(
+      tpu_context, control_flow_ops.XLAControlFlowContext):
+    tpu_context = tpu_context.outer_context
+  return tpu_context
+
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+  """Holds a map from device to TPU variables whose values are kept in sync."""
+
+  def __init__(self, index, primary_var, aggregation):
+    # Use a weakref to make it easy to map from the contained values
+    # to the container without introducing a reference cycle.
+    for v in six.itervalues(index):
+      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
+    self._index = {device_util.canonicalize(key): value
+                   for key, value in six.iteritems(index)}
+    self._primary_var = primary_var
+    self._common_name = self._primary_var.name.split(":")[0]
+    self._aggregation = aggregation
+    # Needed for GradientTape
+    self._trainable = self._primary_var.trainable
+
+  def _get(self, device=None):
+    """Returns the value for the current device or raises a ValueError."""
+    if device is None:
+      tower_context = distribution_strategy_context.get_tower_context()
+      if tower_context:
+        device = tower_context.device
+      else:
+        device = distribute_lib.get_update_device()
+        if device is None:
+          return self._get_cross_tower()
+    device = device_util.canonicalize(device)
+    try:
+      return self._index[device]
+    except KeyError as e:
+      six.raise_from(
+          ValueError("Device %s not found in %s (current device %s)" %
+                     (device, self._index.keys(), device_util.current())), e)
+
+  # pylint: disable=multiple-statements
+  def __add__(self, o): return self.read_value() + o
+  def __radd__(self, o): return o + self.read_value()
+  def __sub__(self, o): return self.read_value() - o
+  def __rsub__(self, o): return o - self.read_value()
+  def __mul__(self, o): return self.read_value() * o
+  def __rmul__(self, o): return o * self.read_value()
+  def __truediv__(self, o): return self.read_value() / o
+  def __rtruediv__(self, o): return o / self.read_value()
+  def __floordiv__(self, o): return self.read_value() // o
+  def __rfloordiv__(self, o): return o // self.read_value()
+  def __mod__(self, o): return self.read_value() % o
+  def __rmod__(self, o): return o % self.read_value()
+  def __lt__(self, o): return self.read_value() < o
+  def __le__(self, o): return self.read_value() <= o
+  def __gt__(self, o): return self.read_value() > o
+  def __ge__(self, o): return self.read_value() >= o
+  def __and__(self, o): return self.read_value() & o
+  def __rand__(self, o): return o & self.read_value()
+  def __or__(self, o): return self.read_value() | o
+  def __ror__(self, o): return o | self.read_value()
+  def __xor__(self, o): return self.read_value() ^ o
+  def __rxor__(self, o): return o ^ self.read_value()
+  def __getitem__(self, o): return self.read_value()[o]
+  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+  def __rpow__(self, o): return pow(o, self.read_value())
+  def __invert__(self): return ~self.read_value()
+  def __neg__(self): return -self.read_value()
+  def __abs__(self): return abs(self.read_value())
+
+  def __div__(self, o):
+    try:
+      return self.read_value().__div__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rdiv__(self, o):
+    try:
+      return self.read_value().__rdiv__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __matmul__(self, o):
+    try:
+      return self.read_value().__matmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rmatmul__(self, o):
+    try:
+      return self.read_value().__rmatmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  @property
+  def handle(self):
+    # If we're in a tpu.rewrite(), return the replicated handle.
+    tpu_context = _enclosing_tpu_context()
+    if tpu_context is not None:
+      return tpu_context.get_replicated_var_handle(
+          self._common_name, nest.flatten(self._index))
+
+    device = distribute_lib.get_update_device()
+    if device is None:
+      return self._primary_var.handle
+    device = device_util.canonicalize(device)
+    try:
+      return self._index[device].handle
+    except KeyError as e:
+      six.raise_from(
+          ValueError("Device %s not found in %s (current device %s)" %
+                     (device, self._index.keys(), device_util.current())), e)
+
+  # The arguments to update() are automatically unwrapped so the update()
+  # function would normally see regular variables, not MirroredVariables.
+  # However, the update function can still operate on wrapped MirroredVariables
+  # through object members, captured arguments, etc. This is more likely in an
+  # update_non_slot() function (like OptimizerV2._finish), which can
+  # update several non-slot variables in one call.
+  def _assign_func(self, *args, **kwargs):
+    if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy":
+      raise ValueError("You may only assign to a TPUMirroredVariable within a "
+                       "TPUStrategy.")
+    f = kwargs.pop("f")
+    if distribution_strategy_context.get_cross_tower_context():
+      if _enclosing_tpu_context() is not None:
+        return distribution_strategy_context.get_distribution_strategy().update(
+            self, f, *args, **kwargs)
+
+      update_device = distribute_lib.get_update_device()
+      # We are calling update on the mirrored variable in cross tower context.
+      if update_device is not None:
+        # We are calling an assign function on the mirrored variable in cross
+        # tower context.
+        v = self._get(device=update_device)
+        return f(v, *args, **kwargs)
+
+      return distribution_strategy_context.get_distribution_strategy().update(
+          self, f, *args, **kwargs)
+    else:
+      _assert_tower_context()
+      # We are calling an assign function on the mirrored variable in tower
+      # context.
+      # We reduce the value we want to assign/add/sub. More details about how we
+      # handle the different use cases can be found in the _reduce method.
+      # We call the function on each of the mirrored variables with the reduced
+      # value.
+      if self._aggregation == vs.VariableAggregation.NONE:
+        raise ValueError("You must specify an aggregation method to update a "
+                         "TPUMirroredVariable in Tower Context.")
+
+      def merge_fn(strategy, value, *other_args, **other_kwargs):
+        return strategy.update(
+            self, f,
+            strategy.reduce(
+                aggregation=self._aggregation, value=value, destinations=self),
+            *other_args, **other_kwargs)
+
+      return distribution_strategy_context.get_tower_context().merge_call(
+          merge_fn, *args, **kwargs)
+
+  @contextlib.contextmanager
+  def _handle_graph(self, handle):
+    # Note: might have an eager tensor but not be executing eagerly when
+    # building functions.
+    if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+        or ops.has_default_graph()):
+      yield
+    else:
+      with handle.graph.as_default():
+        yield
+
+  @property
+  def trainable(self):
+    return self._trainable
+
+  def _read_variable_op(self, parent_op=None):
+    if self.trainable:
+      tape.variable_accessed(self)
+    if parent_op is not None:
+      with ops.control_dependencies([parent_op]):
+        return gen_resource_variable_ops.read_variable_op(
+            self.handle, self.dtype)
+
+    return gen_resource_variable_ops.read_variable_op(
+        self.handle, self.dtype)
+
+  def read_value(self):
+    return self._read_variable_op()
+
+  def assign_sub(self, *args, **kwargs):
+    def assign_sub_fn(var, delta, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_sub_variable_op(
+            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    def assign_add_fn(var, delta, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_add_variable_op(
+            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+  def assign(self, *args, **kwargs):
+    def assign_fn(var, value, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_variable_op(
+            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_fn, *args, **kwargs)
+
+  @property
+  def aggregation(self):
+    return self._aggregation
+
+  @property
+  def constraint(self):
+    return None
+
+  @property
+  def initializer(self):
+    return control_flow_ops.group(
+        [v.initializer for v in nest.flatten(self._index)])
+
+  @property
+  def graph(self):
+    return self._primary_var.graph
+
+  @property
+  def _shared_name(self):
+    return self._common_name
+
+  @property
+  def _unique_id(self):
+    return self._primary_var._unique_id  # pylint: disable=protected-access
+
+  @property
+  def name(self):
+    return self._primary_var.name
+
+  @property
+  def dtype(self):
+    return self._primary_var.dtype
+
+  @property
+  def shape(self):
+    return self._primary_var.shape
+
+  def get_shape(self):
+    return self._primary_var.get_shape()
+
+  def to_proto(self, export_scope=None):
+    return self._primary_var.to_proto(export_scope=export_scope)
+
+  def _get_cross_tower(self):
+    device = device_util.canonicalize(device_util.current())
+    if device in self._index:
+      return self._index[device]
+    return self._primary_var
+
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribution_strategy_context.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self._read_variable_op()
+
+  def _gather_saveables_for_checkpoint(self):
+    """Overrides CheckpointableBase method.
+
+    This allows both name-based and object-based save and restore of
+    MirroredVariables.
+
+    Returns:
+      A dictionary mapping attribute names to `SaveableObject` factories.
+    """
+    def _saveable_factory(name=self._common_name):
+      return _MirroredSaveable(self, self._primary_var, name)
+    return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+  def _should_act_as_resource_variable(self):
+    """Pass resource_variable_ops.is_resource_variable check."""
+    pass
+
+  # Needed to pass ResourceVariable checks.
+  @property
+  def op(self):
+    return self._primary_var.op
+
+  @property
+  def _in_graph_mode(self):
+    return self._primary_var._in_graph_mode  # pylint: disable=protected-access
+
+  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+    """Converts a variable to a tensor."""
+    # pylint: disable=protected-access
+    if _enclosing_tpu_context() is None:
+      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+    # pylint: enable=protected-access
+    if dtype is not None and dtype != self.dtype:
+      raise NotImplementedError
+    if as_ref:
+      return self.handle
+    else:
+      return self.read_value()
+
+  def is_initialized(self, name=None):
+    """Identifies if all the component variables are initialized.
+
+    Args:
+      name: Name of the final `logical_and` op.
+
+    Returns:
+      The op that evaluates to True or False depending on if all the
+      component variables are initialized.
+    """
+    # TODO(jhseu): Do we need TPU context implementation?
+
+    # We have to cast the self._index.values() to a `list` because when we
+    # use `model_to_estimator` to run tf.keras models, self._index.values() is
+    # of type `dict_values` and not `list`.
+    values_list = nest.flatten(self._index)
+    result = values_list[0].is_initialized()
+    # We iterate through the list of values except the last one to allow us to
+    # name the final `logical_and` op the same name that is passed by the user
+    # to the `is_initialized` op. For distributed variables, the
+    # `is_initialized` op is a `logical_and` op.
+    for v in values_list[1:-1]:
+      result = math_ops.logical_and(result, v.is_initialized())
+    result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+                                  name=name)
+    return result
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+                                        _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
+
 class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a TowerLocalVariable."""