path: root/tensorflow/contrib/distribute/python/values.py
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  |  39
1 file changed, 22 insertions(+), 17 deletions(-)
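
In short, this change re-homes the distribution context accessors (get_tower_context, get_cross_tower_context, get_distribution_strategy) from distribute_lib into the new distribution_strategy_context module; note that distribute_lib.get_update_device() stays where it was. A before/after sketch of a typical call site (illustrative only, not an excerpt from the diff):

    # Before: context accessors were reached through distribute_lib.
    from tensorflow.python.training import distribute as distribute_lib
    ctx = distribute_lib.get_tower_context()

    # After: the same accessors live in distribution_strategy_context.
    from tensorflow.python.training import distribution_strategy_context
    ctx = distribution_strategy_context.get_tower_context()
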
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 5fd4c9de69..8548a86421 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -56,7 +57,7 @@ class DistributedValues(object):
def get(self, device=None):
"""Returns the value for the current device or raises a ValueError."""
if device is None:
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context:
device = tower_context.device
else:
@@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate):
# We want cross-tower code that does some var.op.X calls
# to work (even if the current device isn't in self.devices), but
# other uses of var.op in a cross-tower context to fail.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return DistributedVarOp(self._primary_var.op.name,
self._primary_var.op.graph,
self._primary_var.op.type)
return self.get().op
def read_value(self):
- return distribute_lib.get_distribution_strategy().read_var(self)
+ return distribution_strategy_context.get_distribution_strategy().read_var(
+ self)
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# update several non-slot variables in one call.
def _assign_func(self, *args, **kwargs):
f = kwargs.pop("f")
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
# We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
@@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribute_lib.get_distribution_strategy().update(
+ return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
_assert_tower_context()
@@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored,
aggregation=self._aggregation, value=value, destinations=self),
*other_args, **other_kwargs)
- return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
- **kwargs)
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
@@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._primary_var._as_graph_element()
return self.get()._as_graph_element()
@@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().read_var(
+ return distribution_strategy_context.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def _assert_tower_context():
- if not distribute_lib.get_tower_context():
+ if not distribution_strategy_context.get_tower_context():
raise RuntimeError(
"Tower-local variables may only be assigned in a tower context.")
@@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
@@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._get_cross_tower()
return self.get()._as_graph_element()
@@ -994,12 +996,12 @@ class MultiStepContext(object):
outputs as already reduced or not.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._last_step_outputs_aggregations[name] = aggregation
if aggregation is variables_lib.VariableAggregation.NONE:
self._last_step_outputs[name] = output
else:
- distribution = distribute_lib.get_distribution_strategy()
+ distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(
aggregation, output, destinations="/device:CPU:0")
else:
@@ -1011,7 +1013,9 @@ class MultiStepContext(object):
# context object, so it's more robust to set it only once (even if all
# the towers are trying to set the same value).
self._last_step_outputs_aggregations[name] = aggregation
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
@property
def non_tensor_outputs(self):
@@ -1020,14 +1024,15 @@ class MultiStepContext(object):
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as aggregation doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
def value_container(val):
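
As a usage note, the post-change pattern looks like the following minimal sketch. read_current_value is a hypothetical helper written for illustration, assuming the TF 1.x contrib layout shown in the imports above; it only uses calls that appear in this diff (DistributedValues.get, TowerContext.device, DistributionStrategy.read_var):

    from tensorflow.python.training import distribution_strategy_context

    def read_current_value(distributed_var):
      # Inside a tower context, read the component for the current device.
      tower_context = distribution_strategy_context.get_tower_context()
      if tower_context:
        return distributed_var.get(device=tower_context.device)
      # In a cross-tower context, read through the current strategy,
      # mirroring DistributedVariable.read_value() above.
      strategy = distribution_strategy_context.get_distribution_strategy()
      return strategy.read_var(distributed_var)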