author     Jianwei Xie <xiejw@google.com>                   2018-09-28 12:08:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-28 12:12:54 -0700
commit     e00954e8626c74b263b90527e0c020cfd64136b2
tree       a413c022f55de48bf1a5b6fe82d84861054b45c0 /tensorflow/contrib/tpu
parent     90aa10fcf5c80591b31988754e6221d6c2b8bbd0
Puts the Keras optimizer weights on device.
PiperOrigin-RevId: 214974535
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py        11
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py  53
2 files changed, 63 insertions(+), 1 deletion(-)
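
The change creates each Keras optimizer weight once per TPU core rather than once on the host. Reduced to its essentials, the placement loop added to keras_tpu_variables.py (full diff below) looks like the following sketch; make_per_core_copies is a hypothetical name used only for illustration, while ops.device, ResourceVariable, and as_dtype are the calls the patch itself uses:

from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops


def make_per_core_copies(value, num_replicas, dtype="float32"):
  # One ResourceVariable per TPU core: the explicit device scope pins each
  # copy to its core, so optimizer state (momentum, etc.) lives on-device.
  variables = []
  for i in range(num_replicas):
    with ops.device("device:TPU:{}".format(i)):
      variables.append(
          resource_variable_ops.ResourceVariable(
              value, dtype=dtypes_module.as_dtype(dtype)))
  return variables

The copies are then wrapped in a single ReplicatedVariable (defined earlier in keras_tpu_variables.py), so the rest of Keras sees one logical variable.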
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 956d0142a3..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -959,7 +959,16 @@ class TPUFunction(object):
 
     # Compute our outfeed depending on the execution mode
     if is_training:
-      self._cloned_model._make_train_function()
+      if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+        # For Keras optimizer, we try to place the variable weights on the TPU
+        # device. Keras creates optimizer variables (e.g. momentum values for
+        # the Momentum optimizer) when _make_train_function is invoked.
+        with keras_tpu_variables.replicated_variable_for_optimizer(
+            self._tpu_assignment.num_towers):
+          self._cloned_model._make_train_function()
+      else:
+        self._cloned_model._make_train_function()
+
       self._outfeed_spec = [
           tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
           for tensor in self._cloned_model.train_function.outputs
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..598da7418e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
 
 import contextlib
 
+import numpy as np
+
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 
 
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
 
   return variable_scope.variable_scope(
       "", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+  """Context manager for optimizer weights. Overrides K.variable."""
+  if num_replicas == 1:
+    yield
+    return
+
+  try:
+    old_v = backend.variable
+
+    def opt_variable(value, dtype=None, name=None, constraint=None):
+      """Instantiates a variable and returns it."""
+      if dtype is None:
+        dtype = backend.floatx()
+
+      variables = []
+      for i in range(num_replicas):
+        # Keras holds the variables in optimizer class instance, so the name
+        # does not matter here. ResourceVariable constructor will find a unique
+        # name (including name=None) for each replica.
+        with ops.device("device:TPU:{}".format(i)):
+          v = resource_variable_ops.ResourceVariable(
+              value,
+              dtype=dtypes_module.as_dtype(dtype),
+              name=name,
+              constraint=constraint)
+        variables.append(v)
+      name = "replicate_{}_{}".format("variable" if name is None else name,
+                                      ops.uid())
+      v = ReplicatedVariable(name, variables)
+
+      # pylint: disable=protected-access
+
+      if isinstance(value, np.ndarray):
+        v._keras_shape = value.shape
+      elif hasattr(value, "shape"):
+        v._keras_shape = backend.int_shape(value)
+      v._uses_learning_phase = False
+      backend.track_variable(v)
+      return v
+
+    backend.variable = opt_variable
+    yield
+
+  finally:
+    backend.variable = old_v
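
The second hunk is a standard temporary-monkey-patch context manager: save the real factory, install the override, and restore it in a finally block so the patch cannot leak even if variable creation raises. A minimal, dependency-free sketch of the same pattern (every name here is illustrative, not TensorFlow API):

import contextlib
import types

# Illustrative stand-in for tensorflow.python.keras.backend.
backend = types.SimpleNamespace(
    variable=lambda value, name=None: ("plain", value, name))


@contextlib.contextmanager
def override_variable_factory(num_replicas):
  """Temporarily replace backend.variable with a replicating factory."""
  if num_replicas == 1:
    yield  # Single replica: nothing to replicate.
    return

  old_v = backend.variable

  def replicated(value, name=None):
    # One logical variable backed by one copy per replica.
    copies = [("device:TPU:{}".format(i), value) for i in range(num_replicas)]
    return ("replicated", copies, name)

  try:
    backend.variable = replicated
    yield
  finally:
    backend.variable = old_v  # Always restore, even on error.


# Variables created inside the context are replicated; outside, plain again.
with override_variable_factory(8):
  kind, copies, _ = backend.variable(0.0, name="momentum")
assert kind == "replicated" and len(copies) == 8
assert backend.variable(0.0)[0] == "plain"

The real opt_variable additionally records _keras_shape on the wrapper and registers it with backend.track_variable, so Keras tracks the ReplicatedVariable exactly like a variable produced by the original factory.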