diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 39 |
1 files changed, 22 insertions, 17 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 5fd4c9de69..8548a86421 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -56,7 +57,7 @@ class DistributedValues(object): def get(self, device=None): """Returns the value for the current device or raises a ValueError.""" if device is None: - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context: device = tower_context.device else: @@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate): # We want cross-tower code that does some var.op.X calls # to work (even if the current device isn't in self.devices), but # other uses of var.op in a cross-tower context to fail. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return DistributedVarOp(self._primary_var.op.name, self._primary_var.op.graph, self._primary_var.op.type) return self.get().op def read_value(self): - return distribute_lib.get_distribution_strategy().read_var(self) + return distribution_strategy_context.get_distribution_strategy().read_var( + self) def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" @@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # update several non-slot variables in one call. def _assign_func(self, *args, **kwargs): f = kwargs.pop("f") - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): update_device = distribute_lib.get_update_device() # We are calling update on the mirrored variable in cross tower context. if update_device is not None: @@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored, v = self.get(device=update_device) return f(v, *args, **kwargs) - return distribute_lib.get_distribution_strategy().update( + return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: _assert_tower_context() @@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored, aggregation=self._aggregation, value=value, destinations=self), *other_args, **other_kwargs) - return distribute_lib.get_tower_context().merge_call(merge_fn, *args, - **kwargs) + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) @@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._primary_var._as_graph_element() return self.get()._as_graph_element() @@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().read_var( + return distribution_strategy_context.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def _assert_tower_context(): - if not distribute_lib.get_tower_context(): + if not distribution_strategy_context.get_tower_context(): raise RuntimeError( "Tower-local variables may only be assigned in a tower context.") @@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. @@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._get_cross_tower() return self.get()._as_graph_element() @@ -994,12 +996,12 @@ class MultiStepContext(object): outputs as already reduced or not. """ - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._last_step_outputs_aggregations[name] = aggregation if aggregation is variables_lib.VariableAggregation.NONE: self._last_step_outputs[name] = output else: - distribution = distribute_lib.get_distribution_strategy() + distribution = distribution_strategy_context.get_distribution_strategy() self._last_step_outputs[name] = distribution.reduce( aggregation, output, destinations="/device:CPU:0") else: @@ -1011,7 +1013,9 @@ class MultiStepContext(object): # context object, so it's more robust to set it only once (even if all # the towers are trying to set the same value). self._last_step_outputs_aggregations[name] = aggregation - distribute_lib.get_tower_context().merge_call(merge_fn, output) + + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) @property def non_tensor_outputs(self): @@ -1020,14 +1024,15 @@ class MultiStepContext(object): def set_non_tensor_output(self, name, output): """Set `output` with `name` to be captured as a non tensor output.""" - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._non_tensor_outputs[name] = output else: def merge_fn(distribution, value): # NOTE(priyag): For non tensor outputs, we simply return all the values # in a list as aggregation doesn't make sense on non tensors. self._non_tensor_outputs[name] = distribution.unwrap(value) - distribute_lib.get_tower_context().merge_call(merge_fn, output) + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) def value_container(val): |