aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-09-28 12:08:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 12:12:54 -0700
commite00954e8626c74b263b90527e0c020cfd64136b2 (patch)
treea413c022f55de48bf1a5b6fe82d84861054b45c0 /tensorflow/contrib/tpu
parent90aa10fcf5c80591b31988754e6221d6c2b8bbd0 (diff)
Puts the keras optimizer weights on device.
PiperOrigin-RevId: 214974535
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py11
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py53
2 files changed, 63 insertions, 1 deletion
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 956d0142a3..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -959,7 +959,16 @@ class TPUFunction(object):
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For Keras optimizer, we try to place the variable weights on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..598da7418e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+      # Keras holds the variables in the optimizer class instance, so the name
+ # does not matter here. ResourceVariable constructor will find a unique
+ # name (including name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v