path: root/tensorflow/contrib/distribute/python/values.py
author    Priya Gupta <priyag@google.com>  2018-08-14 11:22:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-14 11:31:36 -0700
commit    77fabbeabb5b9061d8c606050c1ea79aec990c03 (patch)
tree      1495d6acb396eebd40c703b891a4f2e7437a8532 /tensorflow/contrib/distribute/python/values.py
parent    cea262e16a004d73295259c42f21e2655da3df13 (diff)
1. Move the distribution strategy context utility methods to a separate file with few dependencies. This allows importing them in places that previously would have created circular dependencies, since the original file imported many things.
2. Move the stack used by the distribution strategy context onto the graph. This allows different strategies to be used in different graphs (e.g. in train and eval). Fixes #21412 and #21180.
PiperOrigin-RevId: 208680454
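A minimal sketch of the pattern change in (1), assuming a TF 1.x build that contains the new tensorflow.python.training.distribution_strategy_context module introduced here; the helper function itself is hypothetical and only shows how callers now reach the tower context through the lightweight module instead of distribute_lib:

# Hypothetical helper, for illustration only. The import and the
# get_tower_context() call are the ones this commit switches values.py to.
from tensorflow.python.training import distribution_strategy_context

def _current_tower_device(default="/device:CPU:0"):
  # Inside a tower (per-replica) computation, get_tower_context() returns a
  # TowerContext whose .device identifies the current tower; in a
  # cross-tower context it returns None.
  tower_context = distribution_strategy_context.get_tower_context()
  if tower_context is not None:
    return tower_context.device
  # Cross-tower context: fall back to a caller-chosen default device.
  return default

And a rough sketch of what (2) enables: since the strategy stack now lives on the graph instead of in a global, separate train and eval graphs can each carry their own strategy. The specific strategies and device list below are assumptions about the surrounding setup (tf.contrib.distribute as it existed around this commit), not part of the change itself:

import tensorflow as tf  # assumes a TF 1.x build with tf.contrib.distribute

train_graph = tf.Graph()
eval_graph = tf.Graph()

with train_graph.as_default():
  # Train graph: mirror variables across two (assumed) GPUs.
  train_strategy = tf.contrib.distribute.MirroredStrategy(
      devices=["/device:GPU:0", "/device:GPU:1"])
  with train_strategy.scope():
    pass  # build training ops here

with eval_graph.as_default():
  # Eval graph: a different strategy in a different graph, which previously
  # conflicted with the global strategy stack (#21412, #21180).
  eval_strategy = tf.contrib.distribute.OneDeviceStrategy("/device:CPU:0")
  with eval_strategy.scope():
    pass  # build eval ops here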
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  39
1 file changed, 22 insertions, 17 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 5fd4c9de69..8548a86421 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -56,7 +57,7 @@ class DistributedValues(object):
def get(self, device=None):
"""Returns the value for the current device or raises a ValueError."""
if device is None:
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context:
device = tower_context.device
else:
@@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate):
# We want cross-tower code that does some var.op.X calls
# to work (even if the current device isn't in self.devices), but
# other uses of var.op in a cross-tower context to fail.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return DistributedVarOp(self._primary_var.op.name,
self._primary_var.op.graph,
self._primary_var.op.type)
return self.get().op
def read_value(self):
- return distribute_lib.get_distribution_strategy().read_var(self)
+ return distribution_strategy_context.get_distribution_strategy().read_var(
+ self)
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# update several non-slot variables in one call.
def _assign_func(self, *args, **kwargs):
f = kwargs.pop("f")
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
# We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
@@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribute_lib.get_distribution_strategy().update(
+ return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
_assert_tower_context()
@@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored,
aggregation=self._aggregation, value=value, destinations=self),
*other_args, **other_kwargs)
- return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
- **kwargs)
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
@@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._primary_var._as_graph_element()
return self.get()._as_graph_element()
@@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().read_var(
+ return distribution_strategy_context.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def _assert_tower_context():
- if not distribute_lib.get_tower_context():
+ if not distribution_strategy_context.get_tower_context():
raise RuntimeError(
"Tower-local variables may only be assigned in a tower context.")
@@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
@@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._get_cross_tower()
return self.get()._as_graph_element()
@@ -994,12 +996,12 @@ class MultiStepContext(object):
outputs as already reduced or not.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._last_step_outputs_aggregations[name] = aggregation
if aggregation is variables_lib.VariableAggregation.NONE:
self._last_step_outputs[name] = output
else:
- distribution = distribute_lib.get_distribution_strategy()
+ distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(
aggregation, output, destinations="/device:CPU:0")
else:
@@ -1011,7 +1013,9 @@ class MultiStepContext(object):
# context object, so it's more robust to set it only once (even if all
# the towers are trying to set the same value).
self._last_step_outputs_aggregations[name] = aggregation
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
@property
def non_tensor_outputs(self):
@@ -1020,14 +1024,15 @@ class MultiStepContext(object):
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as aggregation doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
def value_container(val):