author     Jonathan Hseu <jhseu@google.com>                  2018-09-28 18:41:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-28 18:45:56 -0700
commit     d37f771cc5a208cdc88a50a65f491b3c06c9f262 (patch)
tree       1036470d10da26df9f5dcf897a74c78329fe57cc
parent     abd5c32c0fa6451e73b491affdd86d852a74177f (diff)
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                |   1
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py      | 175
-rw-r--r--  tensorflow/contrib/distribute/python/values.py            | 381
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py  |   2
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py                  |  11
-rw-r--r--  tensorflow/python/eager/backprop.py                       |   2
-rw-r--r--  tensorflow/python/estimator/estimator.py                  |   4
-rw-r--r--  tensorflow/python/estimator/util.py                       |   8
-rw-r--r--  tensorflow/python/training/optimizer.py                   |   5
-rw-r--r--  tensorflow/python/training/session_manager.py             |   5
10 files changed, 565 insertions, 29 deletions
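The change is easiest to read with the core pattern in mind: TPUStrategy now creates one component variable per TPU device and hides the per-device copies behind a single TPUMirroredVariable wrapper, so assignments keep every copy in sync while reads outside a per-device context fall back to a primary copy. The following plain-Python sketch only illustrates that device-to-variable map; MirroredContainer and make_mirrored are hypothetical stand-ins for illustration, not the TensorFlow classes added in this diff.

# Hypothetical, simplified sketch of the "mirrored variable" pattern used by
# TPUStrategy in this change; none of these names are real TensorFlow APIs.
class MirroredContainer(object):
  """Holds one component value per device; reads go to the primary copy."""

  def __init__(self, index):
    self._index = dict(index)              # device name -> component value
    self._primary_device = sorted(index)[0]

  def read(self, device=None):
    # Reads outside any per-device context fall back to the primary copy.
    return self._index[device or self._primary_device]

  def assign(self, value):
    # Cross-device assignment updates every component to keep them in sync.
    for d in self._index:
      self._index[d] = value


def make_mirrored(devices, initial_value):
  """Creates one component per device, all starting from the same value."""
  index = {d: initial_value for d in devices}
  return MirroredContainer(index)


if __name__ == "__main__":
  var = make_mirrored(["/device:TPU:0", "/device:TPU:1"], 0.0)
  var.assign(3.0)
  print(var.read())                        # 3.0, identical on every device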
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 422983dbef..cfb9d42a6f 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -28,6 +28,7 @@ py_library(
         "//tensorflow/python:device_util",
         "//tensorflow/python:distribute",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a6762e5e87..1b555482d3 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
 from tensorflow.contrib.tpu.python.tpu import training_loop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -37,9 +38,13 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util import nest
 
+_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE"
+
+
 def get_tpu_system_metadata(tpu_cluster_resolver):
   """Retrieves TPU system metadata given a TPUClusterResolver."""
   master = tpu_cluster_resolver.master()
@@ -56,6 +61,58 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
   return tpu_system_metadata
 
 
+# TODO(jhseu): Deduplicate with MirroredStrategy?
+def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args,
+                                  **kwargs):  # pylint: disable=g-missing-docstring
+  # Figure out what collections this variable should be added to.
+  # We'll add the TPUMirroredVariable to those collections instead.
+  collections = kwargs.pop("collections", None)
+  if collections is None:
+    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+  kwargs["collections"] = []
+
+  # TODO(jhseu): Should we have different behavior for different
+  # synchronization settings?
+
+  # Get aggregation value
+  # TODO(jhseu): Support aggregation in a tower context.
+  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+  if aggregation not in [
+      vs.VariableAggregation.NONE,
+      vs.VariableAggregation.SUM,
+      vs.VariableAggregation.MEAN,
+      vs.VariableAggregation.ONLY_FIRST_TOWER,
+  ]:
+    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
+                     .format(aggregation, kwargs["name"]))
+
+  # Ignore user-specified caching device, not needed for mirrored variables.
+  kwargs.pop("caching_device", None)
+
+  # TODO(josh11b,apassos): It would be better if variable initialization
+  # was never recorded on the tape instead of having to do this manually
+  # here.
+  with tape.stop_recording():
+    index = real_mirrored_creator(devices, *args, **kwargs)
+    result = values.TPUMirroredVariable(index, index[devices[0]], aggregation)
+
+  if not context.executing_eagerly():
+    g = ops.get_default_graph()
+    # If "trainable" is True, next_creator() will add the member variables
+    # to the TRAINABLE_VARIABLES collection, so we manually remove
+    # them and replace with the MirroredVariable. We can't set
+    # "trainable" to False for next_creator() since that causes functions
+    # like implicit_gradients to skip those variables.
+    if kwargs.get("trainable", True):
+      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+      for v in index.values():
+        l.remove(v)
+    g.add_to_collections(collections, result)
+  return result
+
+
+# TODO(jhseu): Stop inheriting from OneDeviceStrategy.
 class TPUStrategy(one_device_strategy.OneDeviceStrategy):
   """Experimental TPU distribution strategy implementation."""
 
@@ -82,6 +139,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     # TODO(sourabhbajaj): Change this from num_cores to metadata_override
     self._num_cores_override = num_cores
 
+    # TODO(jhseu): Switch to DeviceAssignment to support pods and model
+    # parallelism.
+    device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices)
+                  if "device:TPU:" in d.name}
+    self._device_index = values.PerDevice(device_map)
+    self._tpu_devices = sorted(device_map.keys())
+    # Only create variables for the number of towers we're running.
+    self._tpu_devices = self._tpu_devices[:self.num_towers]
+
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
     self.steps_per_run = steps_per_run
@@ -239,6 +305,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     return ctx
 
   def _call_for_each_tower(self, fn, *args, **kwargs):
+    # TODO(jhseu): Consider making it so call_for_each_tower implies that we're
+    # in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
     kwargs.pop('run_concurrently', None)
     with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
       return fn(*args, **kwargs)
@@ -248,7 +316,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
       # TODO(priyag): Add appopriate call here when eager is supported for TPUs.
       raise NotImplementedError('Eager mode not supported in TPUStrategy.')
     else:
-      return [tpu.initialize_system()]
+      # TODO(jhseu): We need this hack because DistributionStrategies must be
+      # pickleable for copy.deepcopy(). Remove when initialize_system goes away.
+      graph = ops.get_default_graph()
+      tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
+      if tpu_init:
+        return tpu_init
+      graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION,
+                              tpu.initialize_system())
+      return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
 
   def finalize(self):
     if context.executing_eagerly():
@@ -257,21 +333,53 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     else:
       return [tpu.shutdown_system()]
 
+  def _get_devices_from(self, colocate_with=None):
+    # TODO(jhseu): Change this when we support model parallelism.
+    return self._tpu_devices
+
+  def _create_variable(self, next_creator, *args, **kwargs):
+    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
+    colocate_with = kwargs.pop("colocate_with", None)
+    devices = self._get_devices_from(colocate_with)
+
+    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
+      index = {}
+      for i, d in enumerate(devices):
+        with ops.device(d):
+          if i > 0:
+            # Give replicas meaningful distinct names:
+            var0name = index[devices[0]].name.split(":")[0]
+            # We append a / to variable names created on towers with id > 0 to
+            # ensure that we ignore the name scope and instead use the given
+            # name as the absolute name of the variable.
+            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+            # Initialize replicas with the same value:
+            if context.executing_eagerly():
+              kwargs["initial_value"] = array_ops.identity(
+                  index[devices[0]].value())
+            else:
+              def initial_value_fn(device=d):
+                with ops.device(device):
+                  return array_ops.identity(index[devices[0]].initial_value)
+              kwargs["initial_value"] = initial_value_fn
+          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+            v = next_creator(*args, **kwargs)
+          assert not isinstance(v, values.TPUMirroredVariable)
+          index[d] = v
+      return index
+
+    return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
+                                         **kwargs)
+
   def _reduce(self, aggregation, value, destinations):
-    graph = ops.get_default_graph()
-    cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
-    # If we're inside the ReplicateContext, reduction should be done using
-    # CrossReplicaSum while outside we can directly use an add_n op.
-    while cf_context:
-      if isinstance(cf_context, tpu.TPUReplicateContext):
-        if aggregation == vs.VariableAggregation.MEAN:
-          # TODO(jhseu): Revisit once we support model-parallelism.
-          value *= (1. / self.num_towers)
-        elif aggregation != vs.VariableAggregation.SUM:
-          raise NotImplementedError(
-              'Currently only support sum & mean in TPUStrategy.')
-        return tpu_ops.cross_replica_sum(value)
-      cf_context = cf_context.outer_context
+    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+      if aggregation == vs.VariableAggregation.MEAN:
+        # TODO(jhseu): Revisit once we support model-parallelism.
+        value *= (1. / self.num_towers)
+      elif aggregation != vs.VariableAggregation.SUM:
+        raise NotImplementedError(
+            "Currently only support sum & mean in TPUStrategy.")
+      return tpu_ops.cross_replica_sum(value)
 
     # Validate that the destination is same as the host device
     # Note we don't do this when in replicate context as the reduction is
@@ -290,6 +398,35 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
       return output * (1. / len(value))
     return output
 
+  def _update(self, var, fn, *args, **kwargs):
+    # TODO(jhseu): Consider supporting grouped==False.
+    assert isinstance(var, values.TPUMirroredVariable)
+    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
+      return fn(var, *args, **kwargs)
+
+    # Otherwise, we revert to MirroredStrategy behavior and update each variable
+    # directly.
+    updates = {}
+    for d, v in var._index.items():  # pylint: disable=protected-access
+      name = "update_%d" % self._device_index.get(d)
+      with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+        # If args and kwargs are not mirrored, the value is returned as is.
+        updates[d] = fn(v,
+                        *values.select_device_mirrored(d, args),
+                        **values.select_device_mirrored(d, kwargs))
+
+    # Make a single control dependency to keep the variables mirrored. If one
+    # assignment is fetched, then run all assignments.
+    sorted_keys = sorted(updates.keys())
+    update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
+    for i, d in enumerate(sorted_keys):
+      updates[d] = update_tuple[i]
+    return values.regroup(updates, values.Mirrored)
+
+  def read_var(self, var):
+    assert isinstance(var, values.TPUMirroredVariable)
+    return var.read_value()
+
   def _unwrap(self, value):
     if isinstance(value, list):
       return value
@@ -323,6 +460,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
   def should_save_summary(self):
     return True
 
+  @property
+  def worker_devices(self):
+    return self._tpu_devices
+
+  @property
+  def parameter_devices(self):
+    return self._tpu_devices
+
   def get_host_cpu_device(self, host_id):
     if self._tpu_cluster_resolver.get_master() in ('', 'local'):
       return '/replica:0/task:0/device:CPU:0'
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 4955ded4d5..c18faeb67d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import weakref
 
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
 from tensorflow.contrib.distribute.python import prefetching_ops_v2
 from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
@@ -453,6 +456,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
                                         _tensor_conversion_mirrored)
 
 
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  tpu_context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while tpu_context is not None and not isinstance(
+      tpu_context, control_flow_ops.XLAControlFlowContext):
+    tpu_context = tpu_context.outer_context
+  return tpu_context
+
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+  """Holds a map from device to TPU variables whose values are kept in sync."""
+
+  def __init__(self, index, primary_var, aggregation):
+    # Use a weakref to make it easy to map from the contained values
+    # to the container without introducing a reference cycle.
+    for v in six.itervalues(index):
+      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
+    self._index = {device_util.canonicalize(key): value
+                   for key, value in six.iteritems(index)}
+    self._primary_var = primary_var
+    self._common_name = self._primary_var.name.split(":")[0]
+    self._aggregation = aggregation
+    # Needed for GradientTape
+    self._trainable = self._primary_var.trainable
+
+  def _get(self, device=None):
+    """Returns the value for the current device or raises a ValueError."""
+    if device is None:
+      tower_context = distribution_strategy_context.get_tower_context()
+      if tower_context:
+        device = tower_context.device
+      else:
+        device = distribute_lib.get_update_device()
+        if device is None:
+          return self._get_cross_tower()
+    device = device_util.canonicalize(device)
+    try:
+      return self._index[device]
+    except KeyError as e:
+      six.raise_from(
+          ValueError("Device %s not found in %s (current device %s)" %
+                     (device, self._index.keys(), device_util.current())), e)
+
+  # pylint: disable=multiple-statements
+  def __add__(self, o): return self.read_value() + o
+  def __radd__(self, o): return o + self.read_value()
+  def __sub__(self, o): return self.read_value() - o
+  def __rsub__(self, o): return o - self.read_value()
+  def __mul__(self, o): return self.read_value() * o
+  def __rmul__(self, o): return o * self.read_value()
+  def __truediv__(self, o): return self.read_value() / o
+  def __rtruediv__(self, o): return o / self.read_value()
+  def __floordiv__(self, o): return self.read_value() // o
+  def __rfloordiv__(self, o): return o // self.read_value()
+  def __mod__(self, o): return self.read_value() % o
+  def __rmod__(self, o): return o % self.read_value()
+  def __lt__(self, o): return self.read_value() < o
+  def __le__(self, o): return self.read_value() <= o
+  def __gt__(self, o): return self.read_value() > o
+  def __ge__(self, o): return self.read_value() >= o
+  def __and__(self, o): return self.read_value() & o
+  def __rand__(self, o): return o & self.read_value()
+  def __or__(self, o): return self.read_value() | o
+  def __ror__(self, o): return o | self.read_value()
+  def __xor__(self, o): return self.read_value() ^ o
+  def __rxor__(self, o): return o ^ self.read_value()
+  def __getitem__(self, o): return self.read_value()[o]
+  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+  def __rpow__(self, o): return pow(o, self.read_value())
+  def __invert__(self): return ~self.read_value()
+  def __neg__(self): return -self.read_value()
+  def __abs__(self): return abs(self.read_value())
+
+  def __div__(self, o):
+    try:
+      return self.read_value().__div__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rdiv__(self, o):
+    try:
+      return self.read_value().__rdiv__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __matmul__(self, o):
+    try:
+      return self.read_value().__matmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  def __rmatmul__(self, o):
+    try:
+      return self.read_value().__rmatmul__(o)
+    except AttributeError:
+      # See https://docs.python.org/3/library/constants.html#NotImplemented
+      return NotImplemented
+
+  @property
+  def handle(self):
+    # If we're in a tpu.rewrite(), return the replicated handle.
+    tpu_context = _enclosing_tpu_context()
+    if tpu_context is not None:
+      return tpu_context.get_replicated_var_handle(
+          self._common_name, nest.flatten(self._index))
+
+    device = distribute_lib.get_update_device()
+    if device is None:
+      return self._primary_var.handle
+    device = device_util.canonicalize(device)
+    try:
+      return self._index[device].handle
+    except KeyError as e:
+      six.raise_from(
+          ValueError("Device %s not found in %s (current device %s)" %
+                     (device, self._index.keys(), device_util.current())), e)
+
+  # The arguments to update() are automatically unwrapped so the update()
+  # function would normally see regular variables, not MirroredVariables.
+  # However, the update function can still operate on wrapped MirroredVariables
+  # through object members, captured arguments, etc. This is more likely in an
+  # update_non_slot() function (like OptimizerV2._finish), which can
+  # update several non-slot variables in one call.
+  def _assign_func(self, *args, **kwargs):
+    if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy":
+      raise ValueError("You may only assign to a TPUMirroredVariable within a "
+                       "TPUStrategy.")
+    f = kwargs.pop("f")
+    if distribution_strategy_context.get_cross_tower_context():
+      if _enclosing_tpu_context() is not None:
+        return distribution_strategy_context.get_distribution_strategy().update(
+            self, f, *args, **kwargs)
+
+      update_device = distribute_lib.get_update_device()
+      # We are calling update on the mirrored variable in cross tower context.
+      if update_device is not None:
+        # We are calling an assign function on the mirrored variable in cross
+        # tower context.
+        v = self._get(device=update_device)
+        return f(v, *args, **kwargs)
+
+      return distribution_strategy_context.get_distribution_strategy().update(
+          self, f, *args, **kwargs)
+    else:
+      _assert_tower_context()
+      # We are calling an assign function on the mirrored variable in tower
+      # context.
+      # We reduce the value we want to assign/add/sub. More details about how we
+      # handle the different use cases can be found in the _reduce method.
+      # We call the function on each of the mirrored variables with the reduced
+      # value.
+      if self._aggregation == vs.VariableAggregation.NONE:
+        raise ValueError("You must specify an aggregation method to update a "
+                         "TPUMirroredVariable in Tower Context.")
+
+      def merge_fn(strategy, value, *other_args, **other_kwargs):
+        return strategy.update(
+            self, f,
+            strategy.reduce(
+                aggregation=self._aggregation, value=value, destinations=self),
+            *other_args, **other_kwargs)
+
+      return distribution_strategy_context.get_tower_context().merge_call(
+          merge_fn, *args, **kwargs)
+
+  @contextlib.contextmanager
+  def _handle_graph(self, handle):
+    # Note: might have an eager tensor but not be executing eagerly when
+    # building functions.
+    if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+        or ops.has_default_graph()):
+      yield
+    else:
+      with handle.graph.as_default():
+        yield
+
+  @property
+  def trainable(self):
+    return self._trainable
+
+  def _read_variable_op(self, parent_op=None):
+    if self.trainable:
+      tape.variable_accessed(self)
+    if parent_op is not None:
+      with ops.control_dependencies([parent_op]):
+        return gen_resource_variable_ops.read_variable_op(
+            self.handle, self.dtype)
+
+    return gen_resource_variable_ops.read_variable_op(
+        self.handle, self.dtype)
+
+  def read_value(self):
+    return self._read_variable_op()
+
+  def assign_sub(self, *args, **kwargs):
+    def assign_sub_fn(var, delta, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_sub_variable_op(
+            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+  def assign_add(self, *args, **kwargs):
+    def assign_add_fn(var, delta, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_add_variable_op(
+            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+  def assign(self, *args, **kwargs):
+    def assign_fn(var, value, **kw):
+      name = kw.pop("name", None)
+      read_value = kw.pop("read_value", True)
+      with self._handle_graph(var.handle):
+        op = gen_resource_variable_ops.assign_variable_op(
+            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+            name=name)
+      if read_value:
+        return self._read_variable_op(parent_op=op)
+      return op
+
+    return self._assign_func(f=assign_fn, *args, **kwargs)
+
+  @property
+  def aggregation(self):
+    return self._aggregation
+
+  @property
+  def constraint(self):
+    return None
+
+  @property
+  def initializer(self):
+    return control_flow_ops.group(
+        [v.initializer for v in nest.flatten(self._index)])
+
+  @property
+  def graph(self):
+    return self._primary_var.graph
+
+  @property
+  def _shared_name(self):
+    return self._common_name
+
+  @property
+  def _unique_id(self):
+    return self._primary_var._unique_id  # pylint: disable=protected-access
+
+  @property
+  def name(self):
+    return self._primary_var.name
+
+  @property
+  def dtype(self):
+    return self._primary_var.dtype
+
+  @property
+  def shape(self):
+    return self._primary_var.shape
+
+  def get_shape(self):
+    return self._primary_var.get_shape()
+
+  def to_proto(self, export_scope=None):
+    return self._primary_var.to_proto(export_scope=export_scope)
+
+  def _get_cross_tower(self):
+    device = device_util.canonicalize(device_util.current())
+    if device in self._index:
+      return self._index[device]
+    return self._primary_var
+
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribution_strategy_context.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self._read_variable_op()
+
+  def _gather_saveables_for_checkpoint(self):
+    """Overrides CheckpointableBase method.
+
+    This allows both name-based and object-based save and restore of
+    MirroredVariables.
+
+    Returns:
+      A dictionary mapping attribute names to `SaveableObject` factories.
+ """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + # Needed to pass ResourceVariable checks. + @property + def op(self): + return self._primary_var.op + + @property + def _in_graph_mode(self): + return self._primary_var._in_graph_mode # pylint: disable=protected-access + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + # pylint: disable=protected-access + if _enclosing_tpu_context() is None: + return self._get()._dense_var_to_tensor(dtype, name, as_ref) + # pylint: enable=protected-access + if dtype is not None and dtype != self.dtype: + raise NotImplementedError + if as_ref: + return self.handle + else: + return self.read_value() + + def is_initialized(self, name=None): + """Identifies if all the component variables are initialized. + + Args: + name: Name of the final `logical_and` op. + + Returns: + The op that evaluates to True or False depending on if all the + component variables are initialized. + """ + # TODO(jhseu): Do we need TPU context implementation? + + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = nest.flatten(self._index) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For distributed variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. 
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+                                        _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
+
 class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
   """Class for defining how to restore a TowerLocalVariable."""
 
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 598da7418e..004b1012e5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -78,7 +78,7 @@ class ReplicatedVariable(object):
     if tpu_context is None:
       return self._primary_var.handle
 
-    return tpu_context.get_replicated_var_handle(self)
+    return tpu_context.get_replicated_var_handle(self._name, self._vars)
 
   @contextlib.contextmanager
   def _assign_dependencies(self):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 883e08bf47..11aaa1c66a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._pivot = pivot
     self._replicated_vars = {}
 
-  def get_replicated_var_handle(self, var):
+  def get_replicated_var_handle(self, name, vars_):
     """Returns a variable handle for replicated TPU variable 'var'.
 
     This is a method used by an experimental replicated variable
     implementation and is not intended as a public API.
 
     Args:
-      var: The replicated TPU variable.
+      name: The common name of the variable.
+      vars_: The replicated TPU variables.
 
     Returns:
       The handle of the TPU replicated input node.
""" - handle = self._replicated_vars.get(var) + handle = self._replicated_vars.get(name) if handle is not None: return handle @@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): saved_context = graph._get_control_flow_context() graph._set_control_flow_context(self.outer_context) handle = tpu_ops.tpu_replicated_input( - [v.handle for v in var._vars], name=var.name + "/handle") + [v.handle for v in vars_], name=name + "/handle") graph._set_control_flow_context(saved_context) # pylint: enable=protected-access - self._replicated_vars[var] = handle + self._replicated_vars[name] = handle return handle def report_unsupported_operations(self): diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 78f3198011..deac29111f 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) def _handle_or_self(x): """If x is ResourceVariable, return its handle, else x.""" - if isinstance(x, resource_variable_ops.ResourceVariable): + if resource_variable_ops.is_resource_variable(x): x = x.handle return x diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 34faf03bb0..e6d82f0db7 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -468,6 +468,10 @@ class Estimator(object): with ops.Graph().as_default(): if self._eval_distribution: + # We want to create the iterations variable outside the distribution + # scope as that is just stored on the host and mainly used to drive + # the loop and doesn't need to be a Mirrored/Device variable. + training.get_or_create_steps_per_run_variable() with self._eval_distribution.scope(): return _evaluate() else: diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index 31e4778e72..fb110c4b7b 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -22,7 +22,6 @@ from __future__ import print_function import os import time -from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training @@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook): self._finalize_fn = finalize_fn def begin(self): + # We only create the init ops, but don't run it. We rely on SessionManager + # to run it for us. self._init_ops = self._initialization_fn() self._finalize_ops = self._finalize_fn() - def after_create_session(self, session, coord): - logging.info('Initialize system') - session.run(self._init_ops, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - def end(self, session): logging.info('Finalize system.') session.run(self._finalize_ops) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index f004f3944a..30b0ed20c8 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -471,7 +471,10 @@ class Optimizer( if var_list is None: var_list = tape.watched_variables() - grads = tape.gradient(loss_value, var_list, grad_loss) + # TODO(jhseu): Figure out why GradientTape's gradients don't require loss + # to be executed. 
+      with ops.control_dependencies([loss_value]):
+        grads = tape.gradient(loss_value, var_list, grad_loss)
       return list(zip(grads, var_list))
 
     # Non-callable/Tensor loss case
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..5e4749f306 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util.tf_export import tf_export
 
@@ -182,6 +183,10 @@ class SessionManager(object):
     """
     self._target = master
     sess = session.Session(self._target, graph=self._graph, config=config)
+    # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+    sess.run(
+        distribution_strategy_context.get_distribution_strategy().initialize()
+    )
     if checkpoint_dir and checkpoint_filename_with_path:
       raise ValueError("Can not provide both checkpoint_dir and "
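For the reduction logic in TPUStrategy._reduce above, mean aggregation inside a replicated computation is expressed as a cross-replica sum of values that were first scaled by 1/num_towers. Below is a rough, self-contained sketch of that rule; cross_replica_sum is simulated by a plain Python sum over per-tower values, since the real op runs on the TPU and is not reproduced here.

# Hedged sketch of the reduction rule in TPUStrategy._reduce:
# MEAN is expressed as a SUM after scaling each tower's value by 1/num_towers.
def simulated_cross_replica_sum(per_tower_values):
  total = sum(per_tower_values)
  return [total for _ in per_tower_values]     # every tower sees the same sum


def reduce_on_towers(aggregation, per_tower_values):
  num_towers = len(per_tower_values)
  if aggregation == "mean":
    per_tower_values = [v * (1.0 / num_towers) for v in per_tower_values]
  elif aggregation != "sum":
    raise NotImplementedError("Only sum and mean are supported here.")
  return simulated_cross_replica_sum(per_tower_values)


if __name__ == "__main__":
  print(reduce_on_towers("mean", [2.0, 4.0]))   # [3.0, 3.0]
  print(reduce_on_towers("sum", [2.0, 4.0]))    # [6.0, 6.0]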