author    Jonathan Hseu <jhseu@google.com>  2018-09-28 18:41:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 18:45:56 -0700
commit    d37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch)
tree      1036470d10da26df9f5dcf897a74c78329fe57cc /tensorflow/contrib/distribute/python/values.py
parent    abd5c32c0fa6451e73b491affdd86d852a74177f (diff)
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  381
1 file changed, 381 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 4955ded4d5..c18faeb67d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
@@ -453,6 +456,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
_tensor_conversion_mirrored)
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ tpu_context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while tpu_context is not None and not isinstance(
+ tpu_context, control_flow_ops.XLAControlFlowContext):
+ tpu_context = tpu_context.outer_context
+ return tpu_context
+
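
_enclosing_tpu_context walks the control-flow context chain outward until it reaches the XLAControlFlowContext installed by tpu.replicate(), or returns None. A minimal, self-contained sketch of the same outward walk, with hypothetical Ctx/XLACtx classes standing in for TensorFlow's context objects:

    class Ctx(object):
      def __init__(self, outer=None):
        self.outer_context = outer

    class XLACtx(Ctx):  # stand-in for XLAControlFlowContext
      pass

    def enclosing_xla(ctx):
      # Walk outward until an XLACtx is found or the chain ends.
      while ctx is not None and not isinstance(ctx, XLACtx):
        ctx = ctx.outer_context
      return ctx

    root = XLACtx()
    assert enclosing_xla(Ctx(outer=Ctx(outer=root))) is root
    assert enclosing_xla(Ctx()) is None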
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+ """Holds a map from device to TPU variables whose values are kept in sync."""
+
+ def __init__(self, index, primary_var, aggregation):
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
+ self._index = {device_util.canonicalize(key): value
+ for key, value in six.iteritems(index)}
+ self._primary_var = primary_var
+ self._common_name = self._primary_var.name.split(":")[0]
+ self._aggregation = aggregation
+ # Needed for GradientTape
+ self._trainable = self._primary_var.trainable
+
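The weakref in the constructor is what keeps the container/value relationship acyclic: each component variable points back at its TPUMirroredVariable without keeping it alive. A standalone sketch of the pattern, with hypothetical Container/Item classes:

    import weakref

    class Item(object):
      pass

    class Container(object):
      def __init__(self, items):
        for item in items:
          # Back-pointer via weakref: items never keep the container
          # alive, so container -> item -> container is not a cycle.
          item._container = weakref.ref(self)
        self._items = list(items)

    items = [Item(), Item()]
    c = Container(items)
    assert items[0]._container() is c      # dereference the weakref
    del c                                  # CPython frees it immediately
    assert items[0]._container() is None   # the back-pointer went dead
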
+ def _get(self, device=None):
+ """Returns the value for the current device or raises a ValueError."""
+ if device is None:
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ device = tower_context.device
+ else:
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._get_cross_tower()
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device]
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self.read_value() + o
+ def __radd__(self, o): return o + self.read_value()
+ def __sub__(self, o): return self.read_value() - o
+ def __rsub__(self, o): return o - self.read_value()
+ def __mul__(self, o): return self.read_value() * o
+ def __rmul__(self, o): return o * self.read_value()
+ def __truediv__(self, o): return self.read_value() / o
+ def __rtruediv__(self, o): return o / self.read_value()
+ def __floordiv__(self, o): return self.read_value() // o
+ def __rfloordiv__(self, o): return o // self.read_value()
+ def __mod__(self, o): return self.read_value() % o
+ def __rmod__(self, o): return o % self.read_value()
+ def __lt__(self, o): return self.read_value() < o
+ def __le__(self, o): return self.read_value() <= o
+ def __gt__(self, o): return self.read_value() > o
+ def __ge__(self, o): return self.read_value() >= o
+ def __and__(self, o): return self.read_value() & o
+ def __rand__(self, o): return o & self.read_value()
+ def __or__(self, o): return self.read_value() | o
+ def __ror__(self, o): return o | self.read_value()
+ def __xor__(self, o): return self.read_value() ^ o
+ def __rxor__(self, o): return o ^ self.read_value()
+ def __getitem__(self, o): return self.read_value()[o]
+ def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+ def __rpow__(self, o): return pow(o, self.read_value())
+ def __invert__(self): return ~self.read_value()
+ def __neg__(self): return -self.read_value()
+ def __abs__(self): return abs(self.read_value())
+
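Every overload above funnels through read_value(), so arithmetic on the wrapped variable behaves exactly like arithmetic on a freshly read tensor. A toy sketch of the same delegation pattern (hypothetical ReadDelegate class, not TF API):

    class ReadDelegate(object):
      """Delegates arithmetic to a snapshot from read_value()."""

      def __init__(self, value):
        self._value = value

      def read_value(self):
        return self._value

      def __add__(self, o): return self.read_value() + o
      def __radd__(self, o): return o + self.read_value()

    d = ReadDelegate(41)
    assert d + 1 == 42   # __add__  -> read_value() + 1
    assert 1 + d == 42   # __radd__ -> 1 + read_value()
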
+ def __div__(self, o):
+ try:
+ return self.read_value().__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self.read_value().__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self.read_value().__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self.read_value().__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
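Returning NotImplemented rather than raising lets Python fall back to the other operand's reflected method, which is why __div__/__matmul__ catch AttributeError this way. A standalone sketch of that dispatch protocol:

    class Left(object):
      def __matmul__(self, other):
        return NotImplemented  # decline; Python tries other.__rmatmul__

    class Right(object):
      def __rmatmul__(self, other):
        return "handled by Right.__rmatmul__"

    assert Left() @ Right() == "handled by Right.__rmatmul__"
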
+ @property
+ def handle(self):
+ # If we're in a tpu.rewrite(), return the replicated handle.
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is not None:
+ return tpu_context.get_replicated_var_handle(
+ self._common_name, nest.flatten(self._index))
+
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._primary_var.handle
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device].handle
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
+ # update_non_slot() function (like OptimizerV2._finish), which can
+ # update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+    strategy = distribution_strategy_context.get_distribution_strategy()
+    if strategy.__class__.__name__ != "TPUStrategy":
+ raise ValueError("You may only assign to a TPUMirroredVariable within a "
+ "TPUStrategy.")
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ if _enclosing_tpu_context() is not None:
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+
+ update_device = distribute_lib.get_update_device()
+ # We are calling update on the mirrored variable in cross tower context.
+ if update_device is not None:
+ # We are calling an assign function on the mirrored variable in cross
+ # tower context.
+ v = self._get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ _assert_tower_context()
+ # We are calling an assign function on the mirrored variable in tower
+ # context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function on each of the mirrored variables with the reduced
+ # value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "TPUMirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
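In tower context the assign is routed through merge_call: the per-tower values are first reduced according to the variable's aggregation, and the single reduced result is then applied to every mirrored copy. A schematic of that reduce-then-update flow in plain Python (hypothetical helper, not the TF API):

    def schematic_assign(mirrored_copies, per_tower_values, aggregation):
      # Step 1: reduce the per-tower values into a single update.
      if aggregation == "sum":
        reduced = sum(per_tower_values)
      elif aggregation == "mean":
        reduced = sum(per_tower_values) / len(per_tower_values)
      else:
        raise ValueError("an aggregation method must be specified")
      # Step 2: apply the one reduced value to every mirrored copy.
      for copy in mirrored_copies:
        copy["value"] = reduced
      return reduced

    copies = [{"value": 0.0}, {"value": 0.0}]
    schematic_assign(copies, [1.0, 3.0], "mean")
    assert all(c["value"] == 2.0 for c in copies)
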
+ @contextlib.contextmanager
+ def _handle_graph(self, handle):
+ # Note: might have an eager tensor but not be executing eagerly when
+ # building functions.
+ if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+ or ops.has_default_graph()):
+ yield
+ else:
+ with handle.graph.as_default():
+ yield
+
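_handle_graph follows the standard contextlib.contextmanager pattern: the yield marks where the caller's block runs, optionally wrapped in handle.graph.as_default(). A generic, runnable sketch of a conditionally entered context (hypothetical helpers):

    import contextlib

    @contextlib.contextmanager
    def tagged(label, log):
      log.append("enter " + label)
      yield
      log.append("exit " + label)

    @contextlib.contextmanager
    def maybe(cm_factory, condition):
      # Enter the inner context only when `condition` holds; the caller's
      # block runs either way, mirroring _handle_graph's two branches.
      if condition:
        with cm_factory():
          yield
      else:
        yield

    log = []
    with maybe(lambda: tagged("graph", log), condition=True):
      log.append("body")
    assert log == ["enter graph", "body", "exit graph"]
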
+ @property
+ def trainable(self):
+ return self._trainable
+
+ def _read_variable_op(self, parent_op=None):
+ if self.trainable:
+ tape.variable_accessed(self)
+ if parent_op is not None:
+ with ops.control_dependencies([parent_op]):
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ def read_value(self):
+ return self._read_variable_op()
+
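The parent_op control dependency in _read_variable_op guarantees that the read observes the preceding assign. A minimal TF 1.x graph-mode sketch of the same read-after-write ordering, using an ordinary resource variable:

    import tensorflow as tf  # TF 1.x graph-mode semantics assumed

    g = tf.Graph()
    with g.as_default():
      v = tf.get_variable("v", shape=[], use_resource=True,
                          initializer=tf.zeros_initializer())
      assign_op = v.assign_add(1.0)
      # Ordered after the assign, so the read must see the new value.
      with tf.control_dependencies([assign_op]):
        read = v.read_value()

    with tf.Session(graph=g) as sess:
      sess.run(v.initializer)
      print(sess.run(read))  # 1.0
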
+ def assign_sub(self, *args, **kwargs):
+ def assign_sub_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_sub_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ def assign_add_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_add_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ def assign_fn(var, value, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_variable_op(
+ var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
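All three assign helpers accept a read_value keyword: with the default True they return the post-update read, with False just the update op. Schematic usage, assuming var is a TPUMirroredVariable in a valid TPUStrategy update context:

    # Returns a tensor holding the value right after the update:
    new_value = var.assign(3.0)
    # Returns only the update op; no read is issued:
    update_op = var.assign_add(1.0, read_value=False)
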
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group(
+ [v.initializer for v in nest.flatten(self._index)])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._primary_var.name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return self._primary_var
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribution_strategy_context.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self._read_variable_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Overrides CheckpointableBase method.
+
+ This allows both name-based and object-based save and restore of
+ MirroredVariables.
+
+ Returns:
+ A dictionary mapping attribute names to `SaveableObject` factories.
+ """
+ def _saveable_factory(name=self._common_name):
+ return _MirroredSaveable(self, self._primary_var, name)
+ return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
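Returning factories rather than ready-made saveables lets the checkpointing machinery construct each SaveableObject lazily, under a name it controls. A schematic of how such a factory map might be consumed (hypothetical driver loop, not the TF API):

    def build_saveables(checkpointable_obj):
      saveables = []
      factories = checkpointable_obj._gather_saveables_for_checkpoint()
      for attribute_name, factory in factories.items():
        # Each factory constructs a SaveableObject on demand, so nothing
        # is materialized until a save or restore actually happens.
        saveables.append(factory())
      return saveables
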
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ # Needed to pass ResourceVariable checks.
+ @property
+ def op(self):
+ return self._primary_var.op
+
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ raise NotImplementedError
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+ The op that evaluates to True or False depending on if all the
+ component variables are initialized.
+ """
+ # TODO(jhseu): Do we need TPU context implementation?
+
+    # Flatten `self._index` into a `list` first: when `model_to_estimator`
+    # runs tf.keras models, `self._index.values()` is a `dict_values` object
+    # rather than a `list`, so it cannot be indexed directly.
+    values_list = nest.flatten(self._index)
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For distributed variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
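The loop stops one element early on purpose, so that the final logical_and is the op that carries the caller-supplied name. The same chaining-with-named-tail pattern in plain Python (hypothetical land() standing in for math_ops.logical_and):

    def chained_and(flags, name=None):
      def land(a, b, name=None):  # stand-in for math_ops.logical_and
        return a and b
      result = flags[0]
      for f in flags[1:-1]:
        result = land(result, f)
      # Only the final combination carries the caller-supplied name.
      return land(result, flags[-1], name=name)

    assert chained_and([True, True, True], name="is_initialized") is True
    assert chained_and([True, False, True]) is False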
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+ _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
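Registering the conversion function is what lets a TPUMirroredVariable appear anywhere a Tensor is expected: ops.convert_to_tensor looks up the type and invokes _dense_var_to_tensor. A toy sketch of the same registration mechanism (hypothetical Boxed wrapper, not part of the patch):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops

    class Boxed(object):  # hypothetical wrapper type
      def __init__(self, value):
        self._value = value

    def _convert_boxed(value, dtype=None, name=None, as_ref=False):
      del as_ref  # this toy wrapper has no ref semantics
      return constant_op.constant(value._value, dtype=dtype, name=name)

    ops.register_tensor_conversion_function(Boxed, _convert_boxed)
    # ops.convert_to_tensor(Boxed(3.0)) now yields an ordinary Tensor.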
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""