author    Jonathan Hseu <jhseu@google.com>             2018-09-28 18:41:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 18:45:56 -0700
commit    d37f771cc5a208cdc88a50a65f491b3c06c9f262
tree      1036470d10da26df9f5dcf897a74c78329fe57cc
parent    abd5c32c0fa6451e73b491affdd86d852a74177f
Move TPU variables to the TPU device in TPUStrategy.
PiperOrigin-RevId: 215027511
 tensorflow/contrib/distribute/python/BUILD                |   1
 tensorflow/contrib/distribute/python/tpu_strategy.py      | 175
 tensorflow/contrib/distribute/python/values.py            | 381
 tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py  |   2
 tensorflow/contrib/tpu/python/tpu/tpu.py                  |  11
 tensorflow/python/eager/backprop.py                       |   2
 tensorflow/python/estimator/estimator.py                  |   4
 tensorflow/python/estimator/util.py                       |   8
 tensorflow/python/training/optimizer.py                   |   5
 tensorflow/python/training/session_manager.py             |   5
 10 files changed, 565 insertions(+), 29 deletions(-)
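
At a high level, the change makes variables created under a TPUStrategy scope come into existence as TPUMirroredVariables whose components live on the TPU devices rather than on the host. A minimal usage sketch under the contrib-era API touched by this patch; the resolver address and the exact constructor arguments are assumptions for illustration, and real TPU hardware is required to actually run it:

    # Illustrative sketch only; constructor arguments and the TPU address are
    # assumptions, not taken verbatim from this patch. Requires a TPU worker.
    import tensorflow as tf

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu="grpc://10.240.1.2:8470")  # hypothetical TPU worker address
    strategy = tf.contrib.distribute.TPUStrategy(resolver, steps_per_run=100)

    with strategy.scope():
      # After this change, _create_variable intercepts the call below and
      # returns a TPUMirroredVariable with one component per TPU core, placed
      # on /device:TPU:i instead of on the host CPU.
      w = tf.get_variable("w", shape=[128, 10])
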
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 422983dbef..cfb9d42a6f 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -28,6 +28,7 @@ py_library(
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a6762e5e87..1b555482d3 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -37,9 +38,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
+_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE"
+
+
def get_tpu_system_metadata(tpu_cluster_resolver):
"""Retrieves TPU system metadata given a TPUClusterResolver."""
master = tpu_cluster_resolver.master()
@@ -56,6 +61,58 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
return tpu_system_metadata
+# TODO(jhseu): Deduplicate with MirroredStrategy?
+def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args,
+ **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the TPUMirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # TODO(jhseu): Should we have different behavior for different
+ # synchronization settings?
+
+ # Get aggregation value
+ # TODO(jhseu): Support aggregation in a tower context.
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in [
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER,
+ ]:
+ raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
+ .format(aggregation, kwargs["name"]))
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+ result = values.TPUMirroredVariable(index, index[devices[0]], aggregation)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+
+# TODO(jhseu): Stop inheriting from OneDeviceStrategy.
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
@@ -82,6 +139,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(sourabhbajaj): Change this from num_cores to metadata_override
self._num_cores_override = num_cores
+ # TODO(jhseu): Switch to DeviceAssignment to support pods and model
+ # parallelism.
+ device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices)
+ if "device:TPU:" in d.name}
+ self._device_index = values.PerDevice(device_map)
+ self._tpu_devices = sorted(device_map.keys())
+ # Only create variables for the number of towers we're running.
+ self._tpu_devices = self._tpu_devices[:self.num_towers]
+
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
@@ -239,6 +305,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return ctx
def _call_for_each_tower(self, fn, *args, **kwargs):
+ # TODO(jhseu): Consider making it so call_for_each_tower implies that we're
+ # in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
kwargs.pop('run_concurrently', None)
with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access
return fn(*args, **kwargs)
@@ -248,7 +316,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(priyag): Add appropriate call here when eager is supported for TPUs.
raise NotImplementedError('Eager mode not supported in TPUStrategy.')
else:
- return [tpu.initialize_system()]
+ # TODO(jhseu): We need this hack because DistributionStrategies must be
+ # pickleable for copy.deepcopy(). Remove when initialize_system goes away.
+ graph = ops.get_default_graph()
+ tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
+ if tpu_init:
+ return tpu_init
+ graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION,
+ tpu.initialize_system())
+ return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
def finalize(self):
if context.executing_eagerly():
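
The initialize() hunk above avoids holding the tpu.initialize_system() op on the strategy object, which would break copy.deepcopy(), by caching it in a graph collection instead. A TF 1.x sketch of that idiom, with an illustrative collection name and a tf.no_op standing in for the real initialization op:

    import tensorflow as tf

    _COLLECTION = "MY_INIT_OP"   # illustrative collection name

    def get_or_create_init_op():
      graph = tf.get_default_graph()
      cached = graph.get_collection(_COLLECTION)
      if cached:
        return cached[0]
      init_op = tf.no_op(name="expensive_init")   # stand-in for tpu.initialize_system()
      graph.add_to_collection(_COLLECTION, init_op)
      return init_op

    op_a = get_or_create_init_op()
    op_b = get_or_create_init_op()
    assert op_a is op_b   # repeated calls reuse the single cached op
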
@@ -257,21 +333,53 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
else:
return [tpu.shutdown_system()]
+ def _get_devices_from(self, colocate_with=None):
+ # TODO(jhseu): Change this when we support model parallelism.
+ return self._tpu_devices
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
+ index = {}
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+ # Initialize replicas with the same value:
+ if context.executing_eagerly():
+ kwargs["initial_value"] = array_ops.identity(
+ index[devices[0]].value())
+ else:
+ def initial_value_fn(device=d):
+ with ops.device(device):
+ return array_ops.identity(index[devices[0]].initial_value)
+ kwargs["initial_value"] = initial_value_fn
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+ assert not isinstance(v, values.TPUMirroredVariable)
+ index[d] = v
+ return index
+
+ return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
+
def _reduce(self, aggregation, value, destinations):
- graph = ops.get_default_graph()
- cf_context = graph._get_control_flow_context() # pylint: disable=protected-access
- # If we're inside the ReplicateContext, reduction should be done using
- # CrossReplicaSum while outside we can directly use an add_n op.
- while cf_context:
- if isinstance(cf_context, tpu.TPUReplicateContext):
- if aggregation == vs.VariableAggregation.MEAN:
- # TODO(jhseu): Revisit once we support model-parallelism.
- value *= (1. / self.num_towers)
- elif aggregation != vs.VariableAggregation.SUM:
- raise NotImplementedError(
- 'Currently only support sum & mean in TPUStrategy.')
- return tpu_ops.cross_replica_sum(value)
- cf_context = cf_context.outer_context
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ if aggregation == vs.VariableAggregation.MEAN:
+ # TODO(jhseu): Revisit once we support model-parallelism.
+ value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ "Currently only support sum & mean in TPUStrategy.")
+ return tpu_ops.cross_replica_sum(value)
# Validate that the destination is same as the host device
# Note we don't do this when in replicate context as the reduction is
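
In the rewritten _reduce, a mean across towers inside the replicate context is expressed as a cross-replica sum of locally scaled values: each tower multiplies its value by 1/num_towers, and summing the scaled values across replicas yields the mean. A plain-Python check of that arithmetic, with made-up per-tower values:

    num_towers = 8
    tower_values = [float(i) for i in range(num_towers)]    # hypothetical per-tower values
    scaled = [v * (1.0 / num_towers) for v in tower_values]
    cross_replica_sum = sum(scaled)                          # what cross_replica_sum would compute
    assert cross_replica_sum == sum(tower_values) / num_towers
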
@@ -290,6 +398,35 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
+ def _update(self, var, fn, *args, **kwargs):
+ # TODO(jhseu): Consider supporting grouped==False.
+ assert isinstance(var, values.TPUMirroredVariable)
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ return fn(var, *args, **kwargs)
+
+ # Otherwise, we revert to MirroredStrategy behavior and update each variable
+ # directly.
+ updates = {}
+ for d, v in var._index.items(): # pylint: disable=protected-access
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
+ updates[d] = fn(v,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+
+ # Make a single control dependency to keep the variables mirrored. If one
+ # assignment is fetched, then run all assignments.
+ sorted_keys = sorted(updates.keys())
+ update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
+ for i, d in enumerate(sorted_keys):
+ updates[d] = update_tuple[i]
+ return values.regroup(updates, values.Mirrored)
+
+ def read_var(self, var):
+ assert isinstance(var, values.TPUMirroredVariable)
+ return var.read_value()
+
def _unwrap(self, value):
if isinstance(value, list):
return value
@@ -323,6 +460,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def should_save_summary(self):
return True
+ @property
+ def worker_devices(self):
+ return self._tpu_devices
+
+ @property
+ def parameter_devices(self):
+ return self._tpu_devices
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
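
The _update method added to tpu_strategy.py keeps the per-device copies in lockstep by threading every per-device assignment through control_flow_ops.tuple, so fetching any one update runs them all. A minimal TF 1.x graph-mode sketch of that idiom; the variables and deltas are illustrative:

    import tensorflow as tf

    a = tf.Variable(1.0)
    b = tf.Variable(2.0)
    upd_a = tf.assign_add(a, 1.0)
    upd_b = tf.assign_add(b, 1.0)
    # tf.tuple (control_flow_ops.tuple) makes each returned tensor wait for all
    # inputs, so running upd_a_sync alone still executes the update to b.
    upd_a_sync, upd_b_sync = tf.tuple([upd_a, upd_b])

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(upd_a_sync)
      print(sess.run([a, b]))   # [2.0, 3.0]
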
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 4955ded4d5..c18faeb67d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
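
The _enclosing_tpu_context() helper added below walks the graph's control-flow context chain outward until it finds an XLA context (or runs out of contexts), which is how the rest of this file decides whether it is inside a tpu.rewrite(). A plain-Python sketch of that outer-context walk, with illustrative stand-in context objects:

    class Ctx(object):
      def __init__(self, outer=None, is_xla=False):
        self.outer_context = outer
        self.is_xla = is_xla

    def enclosing_xla_context(ctx):
      # Walk outward until an XLA context is found or the chain ends.
      while ctx is not None and not ctx.is_xla:
        ctx = ctx.outer_context
      return ctx

    inner = Ctx(outer=Ctx(outer=Ctx(is_xla=True)))
    assert enclosing_xla_context(inner).is_xla
    assert enclosing_xla_context(Ctx()) is None
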
@@ -453,6 +456,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
_tensor_conversion_mirrored)
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ tpu_context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while tpu_context is not None and not isinstance(
+ tpu_context, control_flow_ops.XLAControlFlowContext):
+ tpu_context = tpu_context.outer_context
+ return tpu_context
+
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+ """Holds a map from device to TPU variables whose values are kept in sync."""
+
+ def __init__(self, index, primary_var, aggregation):
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
+ self._index = {device_util.canonicalize(key): value
+ for key, value in six.iteritems(index)}
+ self._primary_var = primary_var
+ self._common_name = self._primary_var.name.split(":")[0]
+ self._aggregation = aggregation
+ # Needed for GradientTape
+ self._trainable = self._primary_var.trainable
+
+ def _get(self, device=None):
+ """Returns the value for the current device or raises a ValueError."""
+ if device is None:
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ device = tower_context.device
+ else:
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._get_cross_tower()
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device]
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self.read_value() + o
+ def __radd__(self, o): return o + self.read_value()
+ def __sub__(self, o): return self.read_value() - o
+ def __rsub__(self, o): return o - self.read_value()
+ def __mul__(self, o): return self.read_value() * o
+ def __rmul__(self, o): return o * self.read_value()
+ def __truediv__(self, o): return self.read_value() / o
+ def __rtruediv__(self, o): return o / self.read_value()
+ def __floordiv__(self, o): return self.read_value() // o
+ def __rfloordiv__(self, o): return o // self.read_value()
+ def __mod__(self, o): return self.read_value() % o
+ def __rmod__(self, o): return o % self.read_value()
+ def __lt__(self, o): return self.read_value() < o
+ def __le__(self, o): return self.read_value() <= o
+ def __gt__(self, o): return self.read_value() > o
+ def __ge__(self, o): return self.read_value() >= o
+ def __and__(self, o): return self.read_value() & o
+ def __rand__(self, o): return o & self.read_value()
+ def __or__(self, o): return self.read_value() | o
+ def __ror__(self, o): return o | self.read_value()
+ def __xor__(self, o): return self.read_value() ^ o
+ def __rxor__(self, o): return o ^ self.read_value()
+ def __getitem__(self, o): return self.read_value()[o]
+ def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+ def __rpow__(self, o): return pow(o, self.read_value())
+ def __invert__(self): return ~self.read_value()
+ def __neg__(self): return -self.read_value()
+ def __abs__(self): return abs(self.read_value())
+
+ def __div__(self, o):
+ try:
+ return self.read_value().__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self.read_value().__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self.read_value().__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self.read_value().__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ @property
+ def handle(self):
+ # If we're in a tpu.rewrite(), return the replicated handle.
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is not None:
+ return tpu_context.get_replicated_var_handle(
+ self._common_name, nest.flatten(self._index))
+
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._primary_var.handle
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device].handle
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
+ # update_non_slot() function (like OptimizerV2._finish), which can
+ # update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+ if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy":
+ raise ValueError("You may only assign to a TPUMirroredVariable within a "
+ "TPUStrategy.")
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ if _enclosing_tpu_context() is not None:
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+
+ update_device = distribute_lib.get_update_device()
+ # We are calling update on the mirrored variable in cross tower context.
+ if update_device is not None:
+ # We are calling an assign function on the mirrored variable in cross
+ # tower context.
+ v = self._get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ _assert_tower_context()
+ # We are calling an assign function on the mirrored variable in tower
+ # context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function on each of the mirrored variables with the reduced
+ # value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "TPUMirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ @contextlib.contextmanager
+ def _handle_graph(self, handle):
+ # Note: might have an eager tensor but not be executing eagerly when
+ # building functions.
+ if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+ or ops.has_default_graph()):
+ yield
+ else:
+ with handle.graph.as_default():
+ yield
+
+ @property
+ def trainable(self):
+ return self._trainable
+
+ def _read_variable_op(self, parent_op=None):
+ if self.trainable:
+ tape.variable_accessed(self)
+ if parent_op is not None:
+ with ops.control_dependencies([parent_op]):
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def assign_sub(self, *args, **kwargs):
+ def assign_sub_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_sub_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ def assign_add_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_add_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ def assign_fn(var, value, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_variable_op(
+ var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group(
+ [v.initializer for v in nest.flatten(self._index)])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._primary_var.name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return self._primary_var
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribution_strategy_context.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self._read_variable_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Overrides CheckpointableBase method.
+
+ This allows both name-based and object-based save and restore of
+ MirroredVariables.
+
+ Returns:
+ A dictionary mapping attribute names to `SaveableObject` factories.
+ """
+ def _saveable_factory(name=self._common_name):
+ return _MirroredSaveable(self, self._primary_var, name)
+ return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ # Needed to pass ResourceVariable checks.
+ @property
+ def op(self):
+ return self._primary_var.op
+
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ raise NotImplementedError
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+ The op that evaluates to True or False depending on if all the
+ component variables are initialized.
+ """
+ # TODO(jhseu): Do we need TPU context implementation?
+
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = nest.flatten(self._index)
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For distributed variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+ _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 598da7418e..004b1012e5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -78,7 +78,7 @@ class ReplicatedVariable(object):
if tpu_context is None:
return self._primary_var.handle
- return tpu_context.get_replicated_var_handle(self)
+ return tpu_context.get_replicated_var_handle(self._name, self._vars)
@contextlib.contextmanager
def _assign_dependencies(self):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 883e08bf47..11aaa1c66a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
- def get_replicated_var_handle(self, var):
+ def get_replicated_var_handle(self, name, vars_):
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
and is not intended as a public API.
Args:
- var: The replicated TPU variable.
+ name: The common name of the variable.
+ vars_: The replicated TPU variables.
Returns:
The handle of the TPU replicated input node.
"""
- handle = self._replicated_vars.get(var)
+ handle = self._replicated_vars.get(name)
if handle is not None:
return handle
@@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
- [v.handle for v in var._vars], name=var.name + "/handle")
+ [v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
- self._replicated_vars[var] = handle
+ self._replicated_vars[name] = handle
return handle
def report_unsupported_operations(self):
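
The get_replicated_var_handle change keys the per-context handle cache by the variable's common name (a string) plus an explicit list of component variables, instead of by the wrapper object itself; this lets TPUMirroredVariable reuse the mechanism without having to be hashable or of a specific type. A plain-Python sketch of a name-keyed handle cache, with a tuple standing in for the tpu_replicated_input node:

    _replicated_handles = {}

    def get_replicated_handle(name, component_handles):
      handle = _replicated_handles.get(name)
      if handle is None:
        handle = tuple(component_handles)   # stand-in for tpu_replicated_input(...)
        _replicated_handles[name] = handle
      return handle

    h1 = get_replicated_handle("w", ["handle_tpu0", "handle_tpu1"])
    h2 = get_replicated_handle("w", ["handle_tpu0", "handle_tpu1"])
    assert h1 is h2   # second lookup hits the name-keyed cache
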
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 78f3198011..deac29111f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
"""If x is ResourceVariable, return its handle, else x."""
- if isinstance(x, resource_variable_ops.ResourceVariable):
+ if resource_variable_ops.is_resource_variable(x):
x = x.handle
return x
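
The one-line backprop.py change swaps an isinstance() check for the duck-typed resource_variable_ops.is_resource_variable(), so wrappers such as TPUMirroredVariable, which define _should_act_as_resource_variable and expose .handle, are unwrapped to their handle during gradient computation. A plain-Python approximation of that duck-typed check; the helper below is illustrative, not the library implementation:

    def looks_like_resource_variable(obj):
      # Illustrative approximation of a duck-typed check.
      return (hasattr(obj, "_should_act_as_resource_variable")
              and hasattr(obj, "handle"))

    class FakeMirroredVar(object):
      handle = "replicated_handle"
      def _should_act_as_resource_variable(self):
        pass

    def handle_or_self(x):
      return x.handle if looks_like_resource_variable(x) else x

    assert handle_or_self(FakeMirroredVar()) == "replicated_handle"
    assert handle_or_self(3.0) == 3.0
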
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 34faf03bb0..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -468,6 +468,10 @@ class Estimator(object):
with ops.Graph().as_default():
if self._eval_distribution:
+ # We want to create the iterations variable outside the distribution
+ # scope as that is just stored on the host and mainly used to drive
+ # the loop and doesn't need to be a Mirrored/Device variable.
+ training.get_or_create_steps_per_run_variable()
with self._eval_distribution.scope():
return _evaluate()
else:
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import os
import time
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
self._finalize_fn = finalize_fn
def begin(self):
+ # We only create the init ops here, but don't run them. We rely on
+ # SessionManager to run them for us.
self._init_ops = self._initialization_fn()
self._finalize_ops = self._finalize_fn()
- def after_create_session(self, session, coord):
- logging.info('Initialize system')
- session.run(self._init_ops,
- options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
def end(self, session):
logging.info('Finalize system.')
session.run(self._finalize_ops)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f004f3944a..30b0ed20c8 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
if var_list is None:
var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
+ # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+ # to be executed.
+ with ops.control_dependencies([loss_value]):
+ grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..5e4749f306 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -182,6 +183,10 @@ class SessionManager(object):
"""
self._target = master
sess = session.Session(self._target, graph=self._graph, config=config)
+ # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+ sess.run(
+ distribution_strategy_context.get_distribution_strategy().initialize()
+ )
if checkpoint_dir and checkpoint_filename_with_path:
raise ValueError("Can not provide both checkpoint_dir and "