diff options
author | 2018-10-05 09:54:40 -0700 | |
---|---|---|
committer | 2018-10-05 09:59:41 -0700 | |
commit | 8b7c789e7401fe56b4f648a04f675a3cb69119e5 (patch) | |
tree | 774d7ef4eaf5169620ae26325cc819bf47d072f2 /tensorflow/contrib/tpu | |
parent | 5a43e01ef0f8cb86d836a4d1c08a246630e26f8c (diff) |
- Don't set tpu optimizer parameter variable during weight initialization if the optimizer isn't set, e.g. loading weights and then predict.
- Add load_weights for `KerasTpuModel`.
PiperOrigin-RevId: 215920993
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/keras_support.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index a3a7fd8bb0..af183b3232 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1998,6 +1998,9 @@ class KerasTPUModel(models.Model): logging.info('Setting weights on TPU model.') cloned_model.set_weights(weights) + if self._tpu_model.optimizer is None: + # tpu_model may not be compiled, e.g., loading weights and then predict. + return for k, v in six.iteritems(cpu_optimizer_config): opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): @@ -2052,6 +2055,10 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(weights) self._tpu_weights_initialized = False + def load_weights(self, filepath, by_name=False): + self._cpu_model.load_weights(filepath, by_name) + self._tpu_weights_initialized = False + # pylint: disable=bad-continuation def _validate_shapes(model): |