diff options
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/keras_support.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 956d0142a3..696656e840 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -959,7 +959,16 @@ class TPUFunction(object): # Compute our outfeed depending on the execution mode if is_training: - self._cloned_model._make_train_function() + if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer): + # For Keras optimizer, we try to place the variable weights on the TPU + # device. Keras creates optimizer variables (e.g. momentum values for + # the Momentum optimizer) when _make_train_function is invoked. + with keras_tpu_variables.replicated_variable_for_optimizer( + self._tpu_assignment.num_towers): + self._cloned_model._make_train_function() + else: + self._cloned_model._make_train_function() + self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in self._cloned_model.train_function.outputs |