Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  112
 1 file changed, 82 insertions(+), 30 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 95390041f4..47dcf679c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -30,10 +30,11 @@ from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
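
This hunk drops `state_ops` (the assign helpers below now call methods on the variables themselves) and pulls in `variable_scope` for its `VariableAggregation` enum. A minimal sketch of the three enum values the rest of the patch switches on:

    from tensorflow.python.ops import variable_scope as vs

    # Aggregation modes referenced by the hunks below:
    vs.VariableAggregation.NONE  # no aggregation; tower-context updates raise
    vs.VariableAggregation.SUM   # per-tower updates are summed across devices
    vs.VariableAggregation.MEAN  # per-tower updates are averaged across devices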
@@ -77,6 +78,13 @@ class DistributedValues(object):
def devices(self):
return list(self._index.keys())
+ @property
+  def is_tensor_like(self):
+    """True if every component value is a tensor, per tensor_util.is_tensor."""
+ for v in self._index.values():
+ if not tensor_util.is_tensor(v):
+ return False
+ return True
+
def __str__(self):
return "%s:%s" % (self.__class__.__name__, self._index)
@@ -196,10 +204,43 @@ class DistributedVariable(DistributedDelegate):
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+    # tf.keras uses this attribute to track which variables have been
+    # initialized: when it gets the default session, it initializes all
+    # uninitialized variables. `_keras_initialized` must be a real member of
+    # DistributedVariable, since otherwise attribute lookup would fall through
+    # `__getattr__` and delegate to a component variable.
+ self._keras_initialized = False
super(DistributedVariable, self).__init__(index)
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+      The op that evaluates to True or False depending on whether all the
+      component variables are initialized.
+ """
+    # We have to convert self._index.values() to a `list`, because when we
+    # use `model_to_estimator` to run tf.keras models, self._index.values()
+    # has type `dict_values` rather than `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+    # We leave the last value out of the loop so that the final `logical_and`
+    # op can be given the name the user passed to `is_initialized`. For
+    # distributed variables, the `is_initialized` op is that final
+    # `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
@property
def initializer(self):
+    # Return a single op that groups the initializers of all the component
+    # variables of this distributed variable.
return control_flow_ops.group([v.initializer for v in self._index.values()])
@property
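
To see what `is_initialized` builds, here is a standalone sketch of the same chaining logic; `all_initialized` is a hypothetical helper, not part of the patch:

    import tensorflow as tf

    def all_initialized(components, name=None):
      # Chain logical_ands over the components, giving the user-supplied
      # name only to the final op, exactly as is_initialized() does above.
      result = components[0].is_initialized()
      for v in components[1:-1]:
        result = tf.logical_and(result, v.is_initialized())
      return tf.logical_and(result, components[-1].is_initialized(), name=name)

    v = [tf.Variable(1.0), tf.Variable(2.0), tf.Variable(3.0)]
    check = all_initialized(v, name="all_vars_ready")  # scalar bool tensor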
@@ -290,13 +331,13 @@ class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are kept in sync."""
- def __init__(self, index, primary_var, aggregation_method=None):
+ def __init__(self, index, primary_var, aggregation):
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
- self._aggregation_method = aggregation_method
+ self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
# The arguments to update() are automatically unwrapped so the update()
@@ -319,34 +360,42 @@ class MirroredVariable(DistributedVariable, Mirrored,
return distribute_lib.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
+ _assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
# context.
# We reduce the value we want to assign/add/sub. More details about how we
# handle the different use cases can be found in the _reduce method.
# We call the function on each of the mirrored variables with the reduced
# value.
- if not self._aggregation_method:
+ if self._aggregation == vs.VariableAggregation.NONE:
raise ValueError("You must specify an aggregation method to update a "
"MirroredVariable in Tower Context.")
- def merge_fn(strategy, value):
- return strategy.update(self,
- f,
- strategy.reduce(
- method_string=self._aggregation_method,
- value=value,
- destinations=self))
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
**kwargs)
def assign_sub(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
def assign_add(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
def assign(self, *args, **kwargs):
- return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
device = device_util.canonicalize(device_util.current())
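
In tower context the assign is funneled through `merge_call`: the per-tower values are reduced according to the variable's own aggregation, and only then applied to every component via `strategy.update`. A condensed sketch of that flow, using the file's `distribute_lib` import (the helper name is hypothetical):

    def tower_assign(var, f, value):
      # Mirrors MirroredVariable._assign_func in tower context.
      def merge_fn(strategy, v):
        reduced = strategy.reduce(
            aggregation=var.aggregation, value=v, destinations=var)
        return strategy.update(var, f, reduced)
      return distribute_lib.get_tower_context().merge_call(merge_fn, value)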
@@ -408,14 +457,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into all variables."""
tensor, = restored_tensors
- # To preserve the sum across save and restore, we have to divide the
- # total across all devices when restoring a variable that was summed
- # when saving.
- if self._tower_local_variable.reduce_method == "sum":
- tensor *= 1. / len(self._tower_local_variable.devices)
- return control_flow_ops.group([
- _assign_on_device(d, v, tensor)
- for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access
+ return self._tower_local_variable.assign(tensor)
def _assert_tower_context():
@@ -428,9 +470,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
- def __init__(self, index, primary_var, reduce_method):
+ def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
- self._reduce_method = reduce_method
+ self._aggregation = aggregation
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -442,18 +484,29 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- _assert_tower_context()
- return self.get().assign(*args, **kwargs)
+ if distribute_lib.get_cross_tower_context():
+ # To preserve the sum across save and restore, we have to divide the
+ # total across all devices when restoring a variable that was summed
+ # when saving.
+ tensor = args[0]
+ if self._aggregation == vs.VariableAggregation.SUM:
+ tensor *= 1. / len(self.devices)
+ return control_flow_ops.group(
+ [_assign_on_device(d, v, tensor)
+ for d, v in six.iteritems(self._index)])
+ else:
+ _assert_tower_context()
+ return self.get().assign(*args, **kwargs)
@property
- def reduce_method(self):
- return self._reduce_method
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
- if self._reduce_method == "mean":
+ if self._aggregation == vs.VariableAggregation.MEAN:
return total * (1./ len(all_components))
return total
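
The cross-tower branch of `assign` is what keeps the simplified `restore` above correct: for a SUM-aggregated variable the checkpoint holds the total, so the restored value is split evenly across devices. A worked example, assuming two devices holding 3.0 and 5.0:

    # components before save:            [3.0, 5.0]
    # saved tensor (_get_cross_tower):   3.0 + 5.0 = 8.0
    # restore via assign():              each device gets 8.0 / 2 = 4.0
    # sum after restore:                 4.0 + 4.0 = 8.0  (preserved)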
@@ -929,4 +982,3 @@ class MultiStepContext(object):
assert o.dtype == i.dtype, (
"Dtype {} of left {} doesn't match dtype {} of right {}.".
format(o.dtype, o, i.dtype, i))
-